Merge branch 'main' into litellm_aws_kms_fixes
@@ -65,6 +65,7 @@ jobs:
 pip install "pydantic==2.7.1"
 pip install "diskcache==5.6.1"
 pip install "Pillow==10.3.0"
+pip install "ijson==3.2.3"
 - save_cache:
 paths:
 - ./venv

@@ -126,6 +127,7 @@ jobs:
 pip install jinja2
 pip install tokenizers
 pip install openai
+pip install ijson
 - run:
 name: Run tests
 command: |

@@ -180,6 +182,7 @@ jobs:
 pip install numpydoc
 pip install prisma
 pip install fastapi
+pip install ijson
 pip install "httpx==0.24.1"
 pip install "gunicorn==21.2.0"
 pip install "anyio==3.7.1"

@@ -202,6 +205,7 @@ jobs:
 -e REDIS_PORT=$REDIS_PORT \
 -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
 -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
+-e MISTRAL_API_KEY=$MISTRAL_API_KEY \
 -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
 -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
.github/dependabot.yaml (new file)

@@ -0,0 +1,10 @@
+version: 2
+updates:
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      interval: "daily"
+    groups:
+      github-actions:
+        patterns:
+          - "*"
.github/workflows/ghcr_deploy.yml

@@ -25,6 +25,11 @@ jobs:
 if: github.repository == 'BerriAI/litellm'
 runs-on: ubuntu-latest
 steps:
+  -
+    name: Checkout
+    uses: actions/checkout@v4
+    with:
+      ref: ${{ github.event.inputs.commit_hash }}
 -
 name: Set up QEMU
 uses: docker/setup-qemu-action@v3

@@ -41,12 +46,14 @@ jobs:
 name: Build and push
 uses: docker/build-push-action@v5
 with:
+  context: .
 push: true
 tags: litellm/litellm:${{ github.event.inputs.tag || 'latest' }}
 -
 name: Build and push litellm-database image
 uses: docker/build-push-action@v5
 with:
+  context: .
 push: true
 file: Dockerfile.database
 tags: litellm/litellm-database:${{ github.event.inputs.tag || 'latest' }}

@@ -54,6 +61,7 @@ jobs:
 name: Build and push litellm-spend-logs image
 uses: docker/build-push-action@v5
 with:
+  context: .
 push: true
 file: ./litellm-js/spend-logs/Dockerfile
 tags: litellm/litellm-spend_logs:${{ github.event.inputs.tag || 'latest' }}

@@ -68,6 +76,8 @@ jobs:
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
+  with:
+    ref: ${{ github.event.inputs.commit_hash }}
 # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
 - name: Log in to the Container registry
 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1

@@ -92,7 +102,7 @@ jobs:
 - name: Build and push Docker image
 uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8
 with:
-  context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
+  context: .
 push: true
 tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
 labels: ${{ steps.meta.outputs.labels }}

@@ -106,6 +116,8 @@ jobs:
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
+  with:
+    ref: ${{ github.event.inputs.commit_hash }}

 - name: Log in to the Container registry
 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1

@@ -128,7 +140,7 @@ jobs:
 - name: Build and push Database Docker image
 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
 with:
-  context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
+  context: .
 file: Dockerfile.database
 push: true
 tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}

@@ -143,6 +155,8 @@ jobs:
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
+  with:
+    ref: ${{ github.event.inputs.commit_hash }}

 - name: Log in to the Container registry
 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1

@@ -165,7 +179,7 @@ jobs:
 - name: Build and push Database Docker image
 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
 with:
-  context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
+  context: .
 file: ./litellm-js/spend-logs/Dockerfile
 push: true
 tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}

@@ -176,6 +190,8 @@ jobs:
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
+  with:
+    ref: ${{ github.event.inputs.commit_hash }}

 - name: Log in to the Container registry
 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
@@ -1,4 +1,19 @@
 repos:
+- repo: local
+  hooks:
+  - id: mypy
+    name: mypy
+    entry: python3 -m mypy --ignore-missing-imports
+    language: system
+    types: [python]
+    files: ^litellm/
+  - id: isort
+    name: isort
+    entry: isort
+    language: system
+    types: [python]
+    files: litellm/.*\.py
+    exclude: ^litellm/__init__.py$
 - repo: https://github.com/psf/black
 rev: 24.2.0
 hooks:

@@ -16,11 +31,10 @@ repos:
 name: Check if files match
 entry: python3 ci_cd/check_files_match.py
 language: system
-- repo: local
-  hooks:
-  - id: mypy
-    name: mypy
-    entry: python3 -m mypy --ignore-missing-imports
-    language: system
-    types: [python]
-    files: ^litellm/
+# - id: check-file-length
+#   name: Check file length
+#   entry: python check_file_length.py
+#   args: ["10000"] # set your desired maximum number of lines
+#   language: python
+#   files: litellm/.*\.py
+#   exclude: ^litellm/tests/
check_file_length.py (new file)

@@ -0,0 +1,28 @@
+import sys
+
+
+def check_file_length(max_lines, filenames):
+    bad_files = []
+    for filename in filenames:
+        with open(filename, "r") as file:
+            lines = file.readlines()
+            if len(lines) > max_lines:
+                bad_files.append((filename, len(lines)))
+    return bad_files
+
+
+if __name__ == "__main__":
+    max_lines = int(sys.argv[1])
+    filenames = sys.argv[2:]
+
+    bad_files = check_file_length(max_lines, filenames)
+    if bad_files:
+        bad_files.sort(
+            key=lambda x: x[1], reverse=True
+        ) # Sort files by length in descending order
+        for filename, length in bad_files:
+            print(f"{filename}: {length} lines")
+
+        sys.exit(1)
+    else:
+        sys.exit(0)
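For reference, a minimal sketch of exercising the new helper directly; the file list is illustrative, and this assumes the script is importable from the repo root (the commented-out pre-commit hook in the previous hunk wires up the CLI form with `args: ["10000"]` and a non-zero exit when any file is too long):

```python
# Minimal sketch: call the helper from the new script directly (illustrative file list).
from check_file_length import check_file_length

too_long = check_file_length(10000, ["litellm/main.py", "litellm/utils.py"])
for filename, length in too_long:
    print(f"{filename}: {length} lines")  # only files over the limit are returned
```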
@@ -162,7 +162,7 @@ def completion(

 - `function`: *object* - Required.

-- `tool_choice`: *string or object (optional)* - Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function.
+- `tool_choice`: *string or object (optional)* - Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via `{"type: "function", "function": {"name": "my_function"}}` forces the model to call that function.

 - `none` is the default when no functions are present. `auto` is the default if functions are present.
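A minimal sketch of the object form of `tool_choice` described above, as it would be passed to `litellm.completion` (the model name and tool schema are illustrative):

```python
from litellm import completion

# Force the model to call a specific function via the object form of `tool_choice`
# (tool schema and model name are illustrative).
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
```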
@@ -1,90 +0,0 @@
-import Image from '@theme/IdealImage';
-import QueryParamReader from '../../src/components/queryParamReader.js'
-
-# [Beta] Monitor Logs in Production
-
-:::note
-
-This is in beta. Expect frequent updates, as we improve based on your feedback.
-
-:::
-
-LiteLLM provides an integration to let you monitor logs in production.
-
-👉 Jump to our sample LiteLLM Dashboard: https://admin.litellm.ai/
-
-
-<Image img={require('../../img/alt_dashboard.png')} alt="Dashboard" />
-
-## Debug your first logs
-<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_OpenAI.ipynb">
-  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
-</a>
-
-
-### 1. Get your LiteLLM Token
-
-Go to [admin.litellm.ai](https://admin.litellm.ai/) and copy the code snippet with your unique token
-
-<Image img={require('../../img/hosted_debugger_usage_page.png')} alt="Usage" />
-
-### 2. Set up your environment
-
-**Add it to your .env**
-
-```python
-import os
-
-os.env["LITELLM_TOKEN"] = "e24c4c06-d027-4c30-9e78-18bc3a50aebb" # replace with your unique token
-
-```
-
-**Turn on LiteLLM Client**
-```python
-import litellm
-litellm.client = True
-```
-
-### 3. Make a normal `completion()` call
-```python
-import litellm
-from litellm import completion
-import os
-
-# set env variables
-os.environ["LITELLM_TOKEN"] = "e24c4c06-d027-4c30-9e78-18bc3a50aebb" # replace with your unique token
-os.environ["OPENAI_API_KEY"] = "openai key"
-
-litellm.use_client = True # enable logging dashboard
-messages = [{ "content": "Hello, how are you?","role": "user"}]
-
-# openai call
-response = completion(model="gpt-3.5-turbo", messages=messages)
-```
-
-Your `completion()` call print with a link to your session dashboard (https://admin.litellm.ai/<your_unique_token>)
-
-In the above case it would be: [`admin.litellm.ai/e24c4c06-d027-4c30-9e78-18bc3a50aebb`](https://admin.litellm.ai/e24c4c06-d027-4c30-9e78-18bc3a50aebb)
-
-Click on your personal dashboard link. Here's how you can find it 👇
-
-<Image img={require('../../img/dash_output.png')} alt="Dashboard" />
-
-[👋 Tell us if you need better privacy controls](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version?month=2023-08)
-
-### 3. Review request log
-
-Oh! Looks like our request was made successfully. Let's click on it and see exactly what got sent to the LLM provider.
-
-
-
-
-Ah! So we can see that this request was made to a **Baseten** (see litellm_params > custom_llm_provider) for a model with ID - **7qQNLDB** (see model). The message sent was - `"Hey, how's it going?"` and the response received was - `"As an AI language model, I don't have feelings or emotions, but I can assist you with your queries. How can I assist you today?"`
-
-<Image img={require('../../img/dashboard_log.png')} alt="Dashboard Log Row" />
-
-:::info
-
-🎉 Congratulations! You've successfully debugger your first log!
-
-:::
@@ -2,6 +2,15 @@ import Image from '@theme/IdealImage';

 # Athina

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
 [Athina](https://athina.ai/) is an evaluation framework and production monitoring platform for your LLM-powered app. Athina is designed to enhance the performance and reliability of AI applications through real-time monitoring, granular analytics, and plug-and-play evaluations.

 <Image img={require('../../img/athina_dashboard.png')} />
@@ -1,5 +1,14 @@
 # Greenscale - Track LLM Spend and Responsible Usage

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
 [Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).

 ## Getting Started
@@ -1,4 +1,13 @@
 # Helicone Tutorial

+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
 [Helicone](https://helicone.ai/) is an open source observability platform that proxies your OpenAI traffic and provides you key insights into your spend, latency and usage.

 ## Use Helicone to log requests across all LLM Providers (OpenAI, Azure, Anthropic, Cohere, Replicate, PaLM)
@@ -1,6 +1,6 @@
 import Image from '@theme/IdealImage';

-# Langfuse - Logging LLM Input/Output
+# 🔥 Langfuse - Logging LLM Input/Output

 LangFuse is open Source Observability & Analytics for LLM Apps
 Detailed production traces and a granular view on quality, cost and latency

@@ -122,10 +122,12 @@ response = completion(
     metadata={
         "generation_name": "ishaan-test-generation", # set langfuse Generation Name
         "generation_id": "gen-id22", # set langfuse Generation ID
+        "parent_observation_id": "obs-id9" # set langfuse Parent Observation ID
         "version": "test-generation-version" # set langfuse Generation Version
         "trace_user_id": "user-id2", # set langfuse Trace User ID
         "session_id": "session-1", # set langfuse Session ID
         "tags": ["tag1", "tag2"], # set langfuse Tags
+        "trace_name": "new-trace-name" # set langfuse Trace Name
         "trace_id": "trace-id22", # set langfuse Trace ID
         "trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
         "trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)

@@ -147,9 +149,10 @@ print(response)
 You can also pass `metadata` as part of the request header with a `langfuse_*` prefix:

 ```shell
-curl --location 'http://0.0.0.0:4000/chat/completions' \
+curl --location --request POST 'http://0.0.0.0:4000/chat/completions' \
     --header 'Content-Type: application/json' \
-    --header 'langfuse_trace_id: trace-id22' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'langfuse_trace_id: trace-id2' \
     --header 'langfuse_trace_user_id: user-id2' \
     --header 'langfuse_trace_metadata: {"key":"value"}' \
     --data '{

@@ -190,9 +193,10 @@ The following parameters can be updated on a continuation of a trace by passing

 #### Generation Specific Parameters

 * `generation_id` - Identifier for the generation, auto-generated by default
 * `generation_name` - Identifier for the generation, auto-generated by default
-* `prompt` - Langfuse prompt object used for the generation, defaults to None
+* `parent_observation_id` - Identifier for the parent observation, defaults to `None`
+* `prompt` - Langfuse prompt object used for the generation, defaults to `None`

 Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation.
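A minimal sketch of passing the newly documented `parent_observation_id` (plus a trace name) through `metadata`; the ids, names, and model shown are illustrative:

```python
import os
import litellm

os.environ["LANGFUSE_PUBLIC_KEY"] = ""
os.environ["LANGFUSE_SECRET_KEY"] = ""
os.environ["OPENAI_API_KEY"] = ""

litellm.success_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
    metadata={
        "generation_name": "docs-example-generation",  # langfuse Generation Name
        "parent_observation_id": "obs-id9",            # langfuse Parent Observation ID
        "trace_name": "docs-example-trace",            # langfuse Trace Name
    },
)
```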
@@ -1,6 +1,16 @@
 import Image from '@theme/IdealImage';

 # Langsmith - Logging LLM Input/Output

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
+
 An all-in-one developer platform for every step of the application lifecycle
 https://smith.langchain.com/
@@ -1,6 +1,6 @@
 import Image from '@theme/IdealImage';

-# Logfire - Logging LLM Input/Output
+# 🔥 Logfire - Logging LLM Input/Output

 Logfire is open Source Observability & Analytics for LLM Apps
 Detailed production traces and a granular view on quality, cost and latency

@@ -14,10 +14,14 @@ join our [discord](https://discord.gg/wuPM9dRgDw)

 ## Pre-Requisites

-Ensure you have run `pip install logfire` for this integration
+Ensure you have installed the following packages to use this integration

 ```shell
-pip install logfire litellm
+pip install litellm
+
+pip install opentelemetry-api==1.25.0
+pip install opentelemetry-sdk==1.25.0
+pip install opentelemetry-exporter-otlp==1.25.0
 ```

 ## Quick Start

@@ -25,8 +29,7 @@ pip install logfire litellm
 Get your Logfire token from [Logfire](https://logfire.pydantic.dev/)

 ```python
-litellm.success_callback = ["logfire"]
-litellm.failure_callback = ["logfire"] # logs errors to logfire
+litellm.callbacks = ["logfire"]
 ```

 ```python
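Putting the new single-callback wiring together, a minimal sketch (the env var names and model are assumptions for illustration):

```python
import os
import litellm

os.environ["LOGFIRE_TOKEN"] = ""      # write token from https://logfire.pydantic.dev/
os.environ["OPENAI_API_KEY"] = ""

litellm.callbacks = ["logfire"]       # replaces the separate success/failure callbacks

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)
```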
@@ -1,5 +1,13 @@
 # Lunary - Logging and tracing LLM input/output

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
 [Lunary](https://lunary.ai/) is an open-source AI developer platform providing observability, prompt management, and evaluation tools for AI developers.

 <video controls width='900' >
@@ -1,5 +1,16 @@
+import Image from '@theme/IdealImage';
+
 # Promptlayer Tutorial

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
+
 Promptlayer is a platform for prompt engineers. Log OpenAI requests. Search usage history. Track performance. Visually manage prompt templates.

 <Image img={require('../../img/promptlayer.png')} />
docs/my-website/docs/observability/raw_request_response.md (new file)

@@ -0,0 +1,46 @@
+import Image from '@theme/IdealImage';
+
+# Raw Request/Response Logging
+
+See the raw request/response sent by LiteLLM in your logging provider (OTEL/Langfuse/etc.).
+
+**on SDK**
+```python
+# pip install langfuse
+import litellm
+import os
+
+# log raw request/response
+litellm.log_raw_request_response = True
+
+# from https://cloud.langfuse.com/
+os.environ["LANGFUSE_PUBLIC_KEY"] = ""
+os.environ["LANGFUSE_SECRET_KEY"] = ""
+# Optional, defaults to https://cloud.langfuse.com
+os.environ["LANGFUSE_HOST"] # optional
+
+# LLM API Keys
+os.environ['OPENAI_API_KEY']=""
+
+# set langfuse as a callback, litellm will send the data to langfuse
+litellm.success_callback = ["langfuse"]
+
+# openai call
+response = litellm.completion(
+    model="gpt-3.5-turbo",
+    messages=[
+        {"role": "user", "content": "Hi 👋 - i'm openai"}
+    ]
+)
+```
+
+**on Proxy**
+
+```yaml
+litellm_settings:
+  log_raw_request_response: True
+```
+
+**Expected Log**
+
+<Image img={require('../../img/raw_request_log.png')}/>
@@ -1,5 +1,14 @@
 import Image from '@theme/IdealImage';

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
+
 # Sentry - Log LLM Exceptions
 [Sentry](https://sentry.io/) provides error monitoring for production. LiteLLM can add breadcrumbs and send exceptions to Sentry with this integration
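A minimal sketch of the wiring this integration describes (the DSN env var and model are assumptions for illustration):

```python
import os
import litellm

os.environ["SENTRY_DSN"] = ""          # Sentry DSN for your project
os.environ["OPENAI_API_KEY"] = ""

litellm.failure_callback = ["sentry"]  # LLM exceptions get reported to Sentry

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)
```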
@@ -1,4 +1,12 @@
 # Supabase Tutorial

+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
 [Supabase](https://supabase.com/) is an open source Firebase alternative.
 Start your project with a Postgres database, Authentication, instant APIs, Edge Functions, Realtime subscriptions, Storage, and Vector embeddings.
@@ -1,6 +1,16 @@
 import Image from '@theme/IdealImage';

 # Weights & Biases - Logging LLM Input/Output

+
+:::tip
+
+This is community maintained. Please make an issue if you run into a bug
+https://github.com/BerriAI/litellm
+
+:::
+
+
 Weights & Biases helps AI developers build better models faster https://wandb.ai

 <Image img={require('../../img/wandb.png')} />
@@ -11,7 +11,7 @@ LiteLLM supports

 :::info

-Anthropic API fails requests when `max_tokens` are not passed. Due to this litellm passes `max_tokens=4096` when no `max_tokens` are passed
+Anthropic API fails requests when `max_tokens` are not passed. Due to this litellm passes `max_tokens=4096` when no `max_tokens` are passed.

 :::
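A minimal sketch of overriding that default explicitly (the model name and token budget are illustrative):

```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = ""

# Pass max_tokens yourself to override the 4096 default litellm applies for Anthropic.
response = completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=1024,
)
```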
@@ -229,17 +229,6 @@ assert isinstance(

 ```

-### Setting `anthropic-beta` Header in Requests
-
-Pass the the `extra_headers` param to litellm, All headers will be forwarded to Anthropic API
-
-```python
-response = completion(
-    model="anthropic/claude-3-opus-20240229",
-    messages=messages,
-    tools=tools,
-)
-```
-
 ### Forcing Anthropic Tool Use
@@ -68,6 +68,7 @@ response = litellm.completion(

 | Model Name | Function Call |
 |------------------|----------------------------------------|
+| gpt-4o | `completion('azure/<your deployment name>', messages)` |
 | gpt-4 | `completion('azure/<your deployment name>', messages)` |
 | gpt-4-0314 | `completion('azure/<your deployment name>', messages)` |
 | gpt-4-0613 | `completion('azure/<your deployment name>', messages)` |

@@ -85,7 +86,8 @@ response = litellm.completion(
 ## Azure OpenAI Vision Models
 | Model Name | Function Call |
 |-----------------------|-----------------------------------------------------------------|
-| gpt-4-vision | `response = completion(model="azure/<your deployment name>", messages=messages)` |
+| gpt-4-vision | `completion(model="azure/<your deployment name>", messages=messages)` |
+| gpt-4o | `completion('azure/<your deployment name>', messages)` |

 #### Usage
 ```python
@@ -3,53 +3,155 @@ import TabItem from '@theme/TabItem';

 # Azure AI Studio

-**Ensure the following:**
-1. The API Base passed ends in the `/v1/` prefix
-example:
-```python
-api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
-```
-
-2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** Need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
+LiteLLM supports all models on Azure AI Studio

 ## Usage

 <Tabs>
 <TabItem value="sdk" label="SDK">

+### ENV VAR
 ```python
-import litellm
-response = litellm.completion(
-    model="azure/command-r-plus",
-    api_base="<your-deployment-base>/v1/"
-    api_key="eskk******"
-    messages=[{"role": "user", "content": "What is the meaning of life?"}],
+import os
+os.environ["AZURE_AI_API_KEY"] = ""
+os.environ["AZURE_AI_API_BASE"] = ""
+```
+
+### Example Call
+
+```python
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
+
+# azure ai call
+response = completion(
+    model="azure_ai/command-r-plus",
+    messages = [{ "content": "Hello, how are you?","role": "user"}]
 )
 ```

 </TabItem>
 <TabItem value="proxy" label="PROXY">

-## Sample Usage - LiteLLM Proxy
-
 1. Add models to your config.yaml

 ```yaml
 model_list:
-  - model_name: mistral
-    litellm_params:
-      model: azure/mistral-large-latest
-      api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
-      api_key: JGbKodRcTp****
   - model_name: command-r-plus
     litellm_params:
-      model: azure/command-r-plus
-      api_key: os.environ/AZURE_COHERE_API_KEY
-      api_base: os.environ/AZURE_COHERE_API_BASE
+      model: azure_ai/command-r-plus
+      api_key: os.environ/AZURE_AI_API_KEY
+      api_base: os.environ/AZURE_AI_API_BASE
 ```

+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Send Request to LiteLLM Proxy Server
+
+<Tabs>
+
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+    base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+)
+
+response = client.chat.completions.create(
+    model="command-r-plus",
+    messages = [
+        {
+            "role": "system",
+            "content": "Be a good human!"
+        },
+        {
+            "role": "user",
+            "content": "What do you know about earth?"
+        }
+    ]
+)
+
+print(response)
+```
+
+</TabItem>
+
+<TabItem value="curl" label="curl">
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+    "model": "command-r-plus",
+    "messages": [
+        {
+            "role": "system",
+            "content": "Be a good human!"
+        },
+        {
+            "role": "user",
+            "content": "What do you know about earth?"
+        }
+    ],
+}'
+```
+</TabItem>
+
+</Tabs>
+
+
+</TabItem>
+
+</Tabs>
+
+## Passing additional params - max_tokens, temperature
+See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
+
+```python
+# !pip install litellm
+from litellm import completion
+import os
+## set ENV variables
+os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
+os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
+
+# command r plus call
+response = completion(
+    model="azure_ai/command-r-plus",
+    messages = [{ "content": "Hello, how are you?","role": "user"}],
+    max_tokens=20,
+    temperature=0.5
+)
+```
+
+**proxy**
+
+```yaml
+model_list:
+  - model_name: command-r-plus
+    litellm_params:
+      model: azure_ai/command-r-plus
+      api_key: os.environ/AZURE_AI_API_KEY
+      api_base: os.environ/AZURE_AI_API_BASE
+      max_tokens: 20
+      temperature: 0.5
+```
+
+
 2. Start the proxy

 ```bash

@@ -103,9 +205,6 @@ response = litellm.completion(

 </Tabs>

-</TabItem>
-</Tabs>
-
 ## Function Calling

 <Tabs>

@@ -115,8 +214,8 @@ response = litellm.completion(
 from litellm import completion

 # set env
-os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
-os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
+os.environ["AZURE_AI_API_KEY"] = "your-api-key"
+os.environ["AZURE_AI_API_BASE"] = "your-api-base"

 tools = [
     {

@@ -141,9 +240,7 @@ tools = [
 messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]

 response = completion(
-    model="azure/mistral-large-latest",
-    api_base=os.getenv("AZURE_MISTRAL_API_BASE")
-    api_key=os.getenv("AZURE_MISTRAL_API_KEY")
+    model="azure_ai/mistral-large-latest",
     messages=messages,
     tools=tools,
     tool_choice="auto",

@@ -206,10 +303,12 @@ curl http://0.0.0.0:4000/v1/chat/completions \

 ## Supported Models

+LiteLLM supports **ALL** azure ai models. Here's a few examples:

 | Model Name | Function Call |
 |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
-| Cohere ommand-r | `completion(model="azure/command-r", messages)` |
+| Cohere command-r | `completion(model="azure/command-r", messages)` |
 | mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
@@ -144,16 +144,135 @@ print(response)
 </TabItem>
 </Tabs>

+## Set temperature, top p, etc.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+response = completion(
+    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+    messages=[{ "content": "Hello, how are you?","role": "user"}],
+    temperature=0.7,
+    top_p=1
+)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+**Set on yaml**
+
+```yaml
+model_list:
+  - model_name: bedrock-claude-v1
+    litellm_params:
+      model: bedrock/anthropic.claude-instant-v1
+      temperature: <your-temp>
+      top_p: <your-top-p>
+```
+
+**Set on request**
+
+```python
+
+import openai
+client = openai.OpenAI(
+    api_key="anything",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
+    {
+        "role": "user",
+        "content": "this is a test request, write a short poem"
+    }
+],
+temperature=0.7,
+top_p=1
+)
+
+print(response)
+
+```
+
+</TabItem>
+</Tabs>
+
+## Pass provider-specific params
+
+If you pass a non-openai param to litellm, we'll assume it's provider-specific and send it as a kwarg in the request body. [See more](../completion/input.md#provider-specific-params)
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+response = completion(
+    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+    messages=[{ "content": "Hello, how are you?","role": "user"}],
+    top_k=1 # 👈 PROVIDER-SPECIFIC PARAM
+)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+**Set on yaml**
+
+```yaml
+model_list:
+  - model_name: bedrock-claude-v1
+    litellm_params:
+      model: bedrock/anthropic.claude-instant-v1
+      top_k: 1 # 👈 PROVIDER-SPECIFIC PARAM
+```
+
+**Set on request**
+
+```python
+
+import openai
+client = openai.OpenAI(
+    api_key="anything",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
+    {
+        "role": "user",
+        "content": "this is a test request, write a short poem"
+    }
+],
+temperature=0.7,
+extra_body={
+    "top_k": 1 # 👈 PROVIDER-SPECIFIC PARAM
+}
+)
+
+print(response)
+
+```
+
+</TabItem>
+</Tabs>
+
 ## Usage - Function Calling

-:::info
+LiteLLM uses Bedrock's Converse API for making tool calls

-Claude returns it's output as an XML Tree. [Here is how we translate it](https://github.com/BerriAI/litellm/blob/49642a5b00a53b1babc1a753426a8afcac85dbbe/litellm/llms/prompt_templates/factory.py#L734).
-
-You can see the raw response via `response._hidden_params["original_response"]`.
-
-Claude hallucinates, e.g. returning the list param `value` as `<value>\n<item>apple</item>\n<item>banana</item>\n</value>` or `<value>\n<list>\n<item>apple</item>\n<item>banana</item>\n</list>\n</value>`.
-:::

 ```python
 from litellm import completion

@@ -361,47 +480,6 @@ response = completion(
 )
 ```

-### Passing an external BedrockRuntime.Client as a parameter - Completion()
-Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
-
-Create a client from session credentials:
-```python
-import boto3
-from litellm import completion
-
-bedrock = boto3.client(
-    service_name="bedrock-runtime",
-    region_name="us-east-1",
-    aws_access_key_id="",
-    aws_secret_access_key="",
-    aws_session_token="",
-)
-
-response = completion(
-    model="bedrock/anthropic.claude-instant-v1",
-    messages=[{ "content": "Hello, how are you?","role": "user"}],
-    aws_bedrock_client=bedrock,
-)
-```
-
-Create a client from AWS profile in `~/.aws/config`:
-```python
-import boto3
-from litellm import completion
-
-dev_session = boto3.Session(profile_name="dev-profile")
-bedrock = dev_session.client(
-    service_name="bedrock-runtime",
-    region_name="us-east-1",
-)
-
-response = completion(
-    model="bedrock/anthropic.claude-instant-v1",
-    messages=[{ "content": "Hello, how are you?","role": "user"}],
-    aws_bedrock_client=bedrock,
-)
-```
-
 ### SSO Login (AWS Profile)
 - Set `AWS_PROFILE` environment variable
 - Make bedrock completion call

@@ -464,6 +542,56 @@ response = completion(
 )
 ```

+
+### Passing an external BedrockRuntime.Client as a parameter - Completion()
+
+:::warning
+
+This is a deprecated flow. Boto3 is not async. And boto3.client does not let us make the http call through httpx. Pass in your aws params through the method above 👆. [See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues)
+
+:::
+
+Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
+
+Create a client from session credentials:
+```python
+import boto3
+from litellm import completion
+
+bedrock = boto3.client(
+    service_name="bedrock-runtime",
+    region_name="us-east-1",
+    aws_access_key_id="",
+    aws_secret_access_key="",
+    aws_session_token="",
+)
+
+response = completion(
+    model="bedrock/anthropic.claude-instant-v1",
+    messages=[{ "content": "Hello, how are you?","role": "user"}],
+    aws_bedrock_client=bedrock,
+)
+```
+
+Create a client from AWS profile in `~/.aws/config`:
+```python
+import boto3
+from litellm import completion
+
+dev_session = boto3.Session(profile_name="dev-profile")
+bedrock = dev_session.client(
+    service_name="bedrock-runtime",
+    region_name="us-east-1",
+)
+
+response = completion(
+    model="bedrock/anthropic.claude-instant-v1",
+    messages=[{ "content": "Hello, how are you?","role": "user"}],
+    aws_bedrock_client=bedrock,
+)
+```
+
+
 ## Provisioned throughput models
 To use provisioned throughput Bedrock models pass
 - `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
@@ -1,10 +1,13 @@
-# 🆕 Clarifai
+# Clarifai
 Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.

+:::warning
+
+Streaming is not yet supported when using clarifai and litellm. Tracking support here: https://github.com/BerriAI/litellm/issues/4162
+
+:::
+
 ## Pre-Requisites

-`pip install clarifai`
-
 `pip install litellm`

 ## Required Environment Variables

@@ -12,6 +15,7 @@ To obtain your Clarifai Personal access token follow this [link](https://docs.cl

 ```python
 os.environ["CLARIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT

 ```

 ## Usage

@@ -68,7 +72,7 @@ Example Usage - Note: liteLLM supports all models deployed on Clarifai
 | clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
 | clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |

-## Mistal LLMs
+## Mistral LLMs
 | Model Name | Function Call |
 |---------------------------------------------|------------------------------------------------------------------------|
 | clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
255
docs/my-website/docs/providers/codestral.md
Normal file
|
@ -0,0 +1,255 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# Codestral API [Mistral AI]
|
||||||
|
|
||||||
|
Codestral is available in select code-completion plugins but can also be queried directly. See the documentation for more details.
|
||||||
|
|
||||||
|
## API Key
|
||||||
|
```python
|
||||||
|
# env variable
|
||||||
|
os.environ['CODESTRAL_API_KEY']
|
||||||
|
```
|
||||||
|
|
||||||
|
## FIM / Completions
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createFIMCompletion
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="no-streaming" label="No Streaming">
|
||||||
|
|
||||||
|
#### Sample Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
os.environ['CODESTRAL_API_KEY']
|
||||||
|
|
||||||
|
response = await litellm.atext_completion(
|
||||||
|
model="text-completion-codestral/codestral-2405",
|
||||||
|
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
|
||||||
|
suffix="return True", # optional
|
||||||
|
temperature=0, # optional
|
||||||
|
top_p=1, # optional
|
||||||
|
max_tokens=10, # optional
|
||||||
|
min_tokens=10, # optional
|
||||||
|
seed=10, # optional
|
||||||
|
stop=["return"], # optional
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Expected Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "b41e0df599f94bc1a46ea9fcdbc2aabe",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1589478378,
|
||||||
|
"model": "codestral-latest",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "\n assert is_odd(1)\n assert",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "length"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
"completion_tokens": 7,
|
||||||
|
"total_tokens": 12
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="stream" label="Streaming">
|
||||||
|
|
||||||
|
#### Sample Usage - Streaming
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
os.environ['CODESTRAL_API_KEY']
|
||||||
|
|
||||||
|
response = await litellm.atext_completion(
|
||||||
|
model="text-completion-codestral/codestral-2405",
|
||||||
|
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
|
||||||
|
suffix="return True", # optional
|
||||||
|
temperature=0, # optional
|
||||||
|
top_p=1, # optional
|
||||||
|
stream=True,
|
||||||
|
seed=10, # optional
|
||||||
|
stop=["return"], # optional
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Expected Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "726025d3e2d645d09d475bb0d29e3640",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1718659669,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "This",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "codestral-2405",
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
### Supported Models
|
||||||
|
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|----------------|--------------------------------------------------------------|
|
||||||
|
| Codestral Latest | `completion(model="text-completion-codestral/codestral-latest", messages)` |
|
||||||
|
| Codestral 2405 | `completion(model="text-completion-codestral/codestral-2405", messages)`|
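
The FIM examples above call the async `litellm.atext_completion` API with a top-level `await`. If you prefer a synchronous call, here is a minimal sketch using `litellm.text_completion`, assuming it accepts the same FIM-style params shown above and the placeholder key below:

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"  # placeholder key

# same FIM prompt as above, but using the synchronous client
response = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",  # optional
    max_tokens=10,         # optional
)
print(response)
```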
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Chat Completions
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createChatCompletion
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="no-streaming" label="No Streaming">
|
||||||
|
|
||||||
|
#### Sample Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
os.environ['CODESTRAL_API_KEY'] = "your-api-key" # placeholder - set your Codestral API key
|
||||||
|
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="codestral/codestral-latest",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey, how's it going?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
temperature=0.0, # optional
|
||||||
|
top_p=1, # optional
|
||||||
|
max_tokens=10, # optional
|
||||||
|
safe_prompt=False, # optional
|
||||||
|
seed=12, # optional
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Expected Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": "codestral/codestral-latest",
|
||||||
|
"system_fingerprint": None,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "\n\nHello there, how may I assist you today?",
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 9,
|
||||||
|
"completion_tokens": 12,
|
||||||
|
"total_tokens": 21
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="stream" label="Streaming">
|
||||||
|
|
||||||
|
#### Sample Usage - Streaming
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
os.environ['CODESTRAL_API_KEY'] = "your-api-key" # placeholder - set your Codestral API key
|
||||||
|
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="codestral/codestral-latest",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey, how's it going?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stream=True, # optional
|
||||||
|
temperature=0.0, # optional
|
||||||
|
top_p=1, # optional
|
||||||
|
max_tokens=10, # optional
|
||||||
|
safe_prompt=False, # optional
|
||||||
|
seed=12, # optional
|
||||||
|
)
|
||||||
|
async for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Expected Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id":"chatcmpl-123",
|
||||||
|
"object":"chat.completion.chunk",
|
||||||
|
"created":1694268190,
|
||||||
|
"model": "codestral/codestral-latest",
|
||||||
|
"system_fingerprint": None,
|
||||||
|
"choices":[
|
||||||
|
{
|
||||||
|
"index":0,
|
||||||
|
"delta":{"role":"assistant","content":"gm"},
|
||||||
|
"logprobs":null,
|
||||||
|
" finish_reason":null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
### Supported Models
|
||||||
|
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|----------------|--------------------------------------------------------------|
|
||||||
|
| Codestral Latest | `completion(model="codestral/codestral-latest", messages)` |
|
||||||
|
| Codestral 2405 | `completion(model="codestral/codestral-2405", messages)`|
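
Note that the chat examples above use `await` at the top level; outside a notebook or async framework you'd wrap the call in an event loop. A minimal runnable sketch, reusing the same call and the placeholder key assumption:

```python
import asyncio
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"  # placeholder key

async def main():
    # same call as the "No Streaming" chat example above
    response = await litellm.acompletion(
        model="codestral/codestral-latest",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        max_tokens=10,  # optional
    )
    print(response)

asyncio.run(main())
```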
|
|
@ -125,11 +125,12 @@ See all litellm.completion supported params [here](../completion/input.md#transl
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
import os
|
import os
|
||||||
## set ENV variables
|
## set ENV variables
|
||||||
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
os.environ["DATABRICKS_API_KEY"] = "databricks key"
|
||||||
|
os.environ["DATABRICKS_API_BASE"] = "databricks api base"
|
||||||
|
|
||||||
# predibae llama-3 call
|
# databricks dbrx call
|
||||||
response = completion(
|
response = completion(
|
||||||
model="predibase/llama3-8b-instruct",
|
model="databricks/databricks-dbrx-instruct",
|
||||||
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
temperature=0.5
|
temperature=0.5
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
# DeepInfra
|
# DeepInfra
|
||||||
https://deepinfra.com/
|
https://deepinfra.com/
|
||||||
|
|
||||||
|
:::tip
|
||||||
|
|
||||||
|
**We support ALL DeepInfra models, just set `model=deepinfra/<any-model-on-deepinfra>` as a prefix when sending litellm requests**
|
||||||
|
|
||||||
|
:::
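
For example, a minimal sketch calling one of the models from the table below with the `deepinfra/` prefix, assuming the `DEEPINFRA_API_KEY` environment variable from the API Key section and a placeholder value:

```python
import os
from litellm import completion

os.environ["DEEPINFRA_API_KEY"] = "your-api-key"  # placeholder key

# any model hosted on DeepInfra works with the deepinfra/ prefix
response = completion(
    model="deepinfra/meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)
```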
|
||||||
|
|
||||||
|
|
||||||
## API Key
|
## API Key
|
||||||
```python
|
```python
|
||||||
# env variable
|
# env variable
|
||||||
|
@ -38,13 +45,11 @@ for chunk in response:
|
||||||
## Chat Models
|
## Chat Models
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|------------------|--------------------------------------|
|
|------------------|--------------------------------------|
|
||||||
|
| meta-llama/Meta-Llama-3-8B-Instruct | `completion(model="deepinfra/meta-llama/Meta-Llama-3-8B-Instruct", messages)` |
|
||||||
|
| meta-llama/Meta-Llama-3-70B-Instruct | `completion(model="deepinfra/meta-llama/Meta-Llama-3-70B-Instruct", messages)` |
|
||||||
| meta-llama/Llama-2-70b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-70b-chat-hf", messages)` |
|
| meta-llama/Llama-2-70b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-70b-chat-hf", messages)` |
|
||||||
| meta-llama/Llama-2-7b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-7b-chat-hf", messages)` |
|
| meta-llama/Llama-2-7b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-7b-chat-hf", messages)` |
|
||||||
| meta-llama/Llama-2-13b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-13b-chat-hf", messages)` |
|
| meta-llama/Llama-2-13b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-13b-chat-hf", messages)` |
|
||||||
| codellama/CodeLlama-34b-Instruct-hf | `completion(model="deepinfra/codellama/CodeLlama-34b-Instruct-hf", messages)` |
|
| codellama/CodeLlama-34b-Instruct-hf | `completion(model="deepinfra/codellama/CodeLlama-34b-Instruct-hf", messages)` |
|
||||||
| mistralai/Mistral-7B-Instruct-v0.1 | `completion(model="deepinfra/mistralai/Mistral-7B-Instruct-v0.1", messages)` |
|
| mistralai/Mistral-7B-Instruct-v0.1 | `completion(model="deepinfra/mistralai/Mistral-7B-Instruct-v0.1", messages)` |
|
||||||
| jondurbin/airoboros-l2-70b-gpt4-1.4.1 | `completion(model="deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", messages)` |
|
| jondurbin/airoboros-l2-70b-gpt4-1.4.1 | `completion(model="deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,52 @@ response = completion(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Tool Calling
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
# set env
|
||||||
|
os.environ["GEMINI_API_KEY"] = ".."
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
# Add any assertions, here to check response args
|
||||||
|
print(response)
|
||||||
|
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||||
|
assert isinstance(
|
||||||
|
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
# Gemini-Pro-Vision
|
# Gemini-Pro-Vision
|
||||||
LiteLLM Supports the following image types passed in `url`
|
LiteLLM Supports the following image types passed in `url`
|
||||||
- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg
|
- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
# Groq
|
# Groq
|
||||||
https://groq.com/
|
https://groq.com/
|
||||||
|
|
||||||
**We support ALL Groq models, just set `groq/` as a prefix when sending completion requests**
|
:::tip
|
||||||
|
|
||||||
|
**We support ALL Groq models, just set `model=groq/<any-model-on-groq>` as a prefix when sending litellm requests**
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
## API Key
|
## API Key
|
||||||
```python
|
```python
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# OpenAI (Text Completion)
|
# OpenAI (Text Completion)
|
||||||
|
|
||||||
LiteLLM supports OpenAI text completion models
|
LiteLLM supports OpenAI text completion models
|
||||||
|
|
|
@ -208,7 +208,7 @@ print(response)
|
||||||
|
|
||||||
Instead of using the `custom_llm_provider` arg to specify which provider you're using (e.g. together ai), you can just pass the provider name as part of the model name, and LiteLLM will parse it out.
|
Instead of using the `custom_llm_provider` arg to specify which provider you're using (e.g. together ai), you can just pass the provider name as part of the model name, and LiteLLM will parse it out.
|
||||||
|
|
||||||
Expected format: <custom_llm_provider>/<model_name>
|
Expected format: `<custom_llm_provider>/<model_name>`
|
||||||
|
|
||||||
e.g. completion(model="together_ai/togethercomputer/Llama-2-7B-32K-Instruct", ...)
|
e.g. completion(model="together_ai/togethercomputer/Llama-2-7B-32K-Instruct", ...)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,152 @@ import TabItem from '@theme/TabItem';
|
||||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
## 🆕 `vertex_ai_beta/` route
|
||||||
|
|
||||||
|
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai_beta/gemini-pro",
|
||||||
|
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
vertex_credentials=vertex_credentials_json
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### **System Message**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai_beta/gemini-pro",
|
||||||
|
messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}],
|
||||||
|
vertex_credentials=vertex_credentials_json
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### **Function Calling**
|
||||||
|
|
||||||
|
Force Gemini to make tool calls with `tool_choice="required"`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||||
|
},
|
||||||
|
# User asks for their name and weather in San Francisco
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is your name and can you tell me the weather?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "vertex_ai_beta/gemini-1.5-pro-preview-0514"),
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"tool_choice": "required",
|
||||||
|
"vertex_credentials": vertex_credentials_json
|
||||||
|
}
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
print(completion(**data))
|
||||||
|
```
|
||||||
|
|
||||||
|
### **JSON Schema**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
import json
|
||||||
|
|
||||||
|
## GET CREDENTIALS
|
||||||
|
file_path = 'path/to/vertex_ai_service_account.json'
|
||||||
|
|
||||||
|
# Load the JSON file
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
vertex_credentials = json.load(file)
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": """
|
||||||
|
List 5 popular cookie recipes.
|
||||||
|
|
||||||
|
Using this JSON schema:
|
||||||
|
|
||||||
|
Recipe = {"recipe_name": str}
|
||||||
|
|
||||||
|
Return a `list[Recipe]`
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
completion(model="vertex_ai_beta/gemini-1.5-flash-preview-0514", messages=messages, response_format={ "type": "json_object" })
|
||||||
|
```
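
Since `response_format={ "type": "json_object" }` asks Gemini for JSON output, a small follow-up sketch for parsing the result, assuming the call above is assigned to `response`:

```python
import json

# the model's JSON output comes back as a string in the message content
raw = response.choices[0].message.content
recipes = json.loads(raw)
print(recipes)
```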
|
||||||
|
|
||||||
## Pre-requisites
|
## Pre-requisites
|
||||||
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
||||||
* Authentication:
|
* Authentication:
|
||||||
|
@ -140,7 +286,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
|
||||||
|
|
||||||
```python
|
```python
|
||||||
response = completion(
|
response = completion(
|
||||||
model="gemini/gemini-pro",
|
model="vertex_ai/gemini-pro",
|
||||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||||
safety_settings=[
|
safety_settings=[
|
||||||
{
|
{
|
||||||
|
@ -363,8 +509,8 @@ response = completion(
|
||||||
## Gemini 1.5 Pro (and Vision)
|
## Gemini 1.5 Pro (and Vision)
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|------------------|--------------------------------------|
|
|------------------|--------------------------------------|
|
||||||
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-1.5-pro', messages)` |
|
||||||
| gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
| gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-flash-preview-0514', messages)` |
|
||||||
| gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` |
|
| gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -449,6 +595,54 @@ print(response)
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
## Usage - Function Calling
|
||||||
|
|
||||||
|
LiteLLM supports Function Calling for Vertex AI gemini models.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
# set env
|
||||||
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ".."
|
||||||
|
os.environ["VERTEX_AI_PROJECT"] = ".."
|
||||||
|
os.environ["VERTEX_AI_LOCATION"] = ".."
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai/gemini-pro-vision",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
# Add any assertions, here to check response args
|
||||||
|
print(response)
|
||||||
|
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||||
|
assert isinstance(
|
||||||
|
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Chat Models
|
## Chat Models
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|
@ -500,6 +694,8 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| text-embedding-004 | `embedding(model="vertex_ai/text-embedding-004", input)` |
|
||||||
|
| text-multilingual-embedding-002 | `embedding(model="vertex_ai/text-multilingual-embedding-002", input)` |
|
||||||
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
|
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
|
||||||
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
|
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
|
||||||
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
||||||
|
@ -508,6 +704,29 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
||||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||||
|
|
||||||
|
### Advanced Use `task_type` and `title` (Vertex Specific Params)
|
||||||
|
|
||||||
|
👉 `task_type` and `title` are vertex specific params
|
||||||
|
|
||||||
|
LiteLLM Supported Vertex Specific Params
|
||||||
|
|
||||||
|
```python
|
||||||
|
auto_truncate: Optional[bool] = None
|
||||||
|
task_type: Optional[Literal["RETRIEVAL_QUERY","RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]] = None
|
||||||
|
title: Optional[str] = None # The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Usage with LiteLLM**
|
||||||
|
```python
|
||||||
|
response = litellm.embedding(
|
||||||
|
model="vertex_ai/text-embedding-004",
|
||||||
|
input=["good morning from litellm", "gm"]
|
||||||
|
task_type = "RETRIEVAL_DOCUMENT",
|
||||||
|
dimensions=1,
|
||||||
|
auto_truncate=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Image Generation Models
|
## Image Generation Models
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
|
@ -607,6 +826,3 @@ s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
# 🚨 Alerting / Webhooks
|
# 🚨 Alerting / Webhooks
|
||||||
|
|
||||||
Get alerts for:
|
Get alerts for:
|
||||||
|
@ -15,6 +17,11 @@ Get alerts for:
|
||||||
- **Spend** Weekly & Monthly spend per Team, Tag
|
- **Spend** Weekly & Monthly spend per Team, Tag
|
||||||
|
|
||||||
|
|
||||||
|
Works across:
|
||||||
|
- [Slack](#quick-start)
|
||||||
|
- [Discord](#advanced---using-discord-webhooks)
|
||||||
|
- [Microsoft Teams](#advanced---using-ms-teams-webhooks)
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
Set up a slack alert channel to receive alerts from proxy.
|
Set up a slack alert channel to receive alerts from proxy.
|
||||||
|
@ -25,41 +32,33 @@ Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||||
|
|
||||||
You can also use Discord Webhooks, see [here](#using-discord-webhooks)
|
You can also use Discord Webhooks, see [here](#using-discord-webhooks)
|
||||||
|
|
||||||
### Step 2: Update config.yaml
|
|
||||||
|
|
||||||
- Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
|
Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
|
||||||
- Just for testing purposes, let's save a bad key to our proxy.
|
|
||||||
|
```bash
|
||||||
|
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/<>/<>/<>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Setup Proxy
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
model_list:
|
|
||||||
model_name: "azure-model"
|
|
||||||
litellm_params:
|
|
||||||
model: "azure/gpt-35-turbo"
|
|
||||||
api_key: "my-bad-key" # 👈 bad key
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
||||||
|
|
||||||
environment_variables:
|
|
||||||
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
|
|
||||||
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Start proxy
|
||||||
### Step 3: Start proxy
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ litellm --config /path/to/config.yaml
|
$ litellm --config /path/to/config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing Alerting is Setup Correctly
|
|
||||||
|
|
||||||
Make a GET request to `/health/services`, expect to see a test slack alert in your provided webhook slack channel
|
### Step 3: Test it!
|
||||||
|
|
||||||
```shell
|
|
||||||
curl -X GET 'http://localhost:4000/health/services?service=slack' \
|
```bash
|
||||||
-H 'Authorization: Bearer sk-1234'
|
curl -X GET 'http://0.0.0.0:4000/health/services?service=slack' \
|
||||||
|
-H 'Authorization: Bearer sk-1234'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced - Redacting Messages from Alerts
|
## Advanced - Redacting Messages from Alerts
|
||||||
|
@ -77,7 +76,34 @@ litellm_settings:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Advanced - Add Metadata to alerts
|
||||||
|
|
||||||
|
Add alerting metadata to proxy calls for debugging.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [],
|
||||||
|
extra_body={
|
||||||
|
"metadata": {
|
||||||
|
"alerting_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
<Image img={require('../../img/alerting_metadata.png')}/>
|
||||||
|
|
||||||
## Advanced - Opting into specific alert types
|
## Advanced - Opting into specific alert types
|
||||||
|
|
||||||
|
@ -108,6 +134,48 @@ AlertType = Literal[
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Advanced - Using MS Teams Webhooks
|
||||||
|
|
||||||
|
MS Teams provides a slack compatible webhook url that you can use for alerting
|
||||||
|
|
||||||
|
##### Quick Start
|
||||||
|
|
||||||
|
1. [Get a webhook url](https://learn.microsoft.com/en-us/microsoftteams/platform/webhooks-and-connectors/how-to/add-incoming-webhook?tabs=newteams%2Cdotnet#create-an-incoming-webhook) for your Microsoft Teams channel
|
||||||
|
|
||||||
|
2. Add it to your .env
|
||||||
|
|
||||||
|
```bash
|
||||||
|
SLACK_WEBHOOK_URL="https://berriai.webhook.office.com/webhookb2/...6901/IncomingWebhook/b55fa0c2a48647be8e6effedcd540266/e04b1092-4a3e-44a2-ab6b-29a0a4854d1d"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Add it to your litellm config
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
model_name: "azure-model"
|
||||||
|
litellm_params:
|
||||||
|
model: "azure/gpt-35-turbo"
|
||||||
|
api_key: "my-bad-key" # 👈 bad key
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
alerting: ["slack"]
|
||||||
|
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run health check!
|
||||||
|
|
||||||
|
Call the proxy `/health/services` endpoint to test if your alerting connection is correctly setup.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:4000/health/services?service=slack' \
|
||||||
|
--header 'Authorization: Bearer sk-1234'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
<Image img={require('../../img/ms_teams_alerting.png')}/>
|
||||||
|
|
||||||
## Advanced - Using Discord Webhooks
|
## Advanced - Using Discord Webhooks
|
||||||
|
|
||||||
Discord provides a slack compatible webhook url that you can use for alerting
|
Discord provides a slack compatible webhook url that you can use for alerting
|
||||||
|
@ -139,7 +207,6 @@ environment_variables:
|
||||||
SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
|
SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
|
||||||
```
|
```
|
||||||
|
|
||||||
That's it ! You're ready to go !
|
|
||||||
|
|
||||||
## Advanced - [BETA] Webhooks for Budget Alerts
|
## Advanced - [BETA] Webhooks for Budget Alerts
|
||||||
|
|
||||||
|
|
|
@ -252,6 +252,31 @@ $ litellm --config /path/to/config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Multiple OpenAI Organizations
|
||||||
|
|
||||||
|
Add all openai models across all OpenAI organizations with just 1 model definition
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- model_name: "*"
|
||||||
|
litellm_params:
|
||||||
|
model: openai/*
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
organization:
|
||||||
|
- org-1
|
||||||
|
- org-2
|
||||||
|
- org-3
|
||||||
|
```
|
||||||
|
|
||||||
|
LiteLLM will automatically create separate deployments for each org.
|
||||||
|
|
||||||
|
Confirm this via
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:4000/v1/model/info' \
|
||||||
|
--header 'Authorization: Bearer ${LITELLM_KEY}' \
|
||||||
|
--data ''
|
||||||
|
```
|
||||||
|
|
||||||
## Load Balancing
|
## Load Balancing
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
|
|
@ -427,4 +427,23 @@ model_list:
|
||||||
|
|
||||||
## Custom Input/Output Pricing
|
## Custom Input/Output Pricing
|
||||||
|
|
||||||
👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to set up custom pricing for your models
|
👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to set up custom pricing for your models
|
||||||
|
|
||||||
|
## ✨ Custom k,v pairs
|
||||||
|
|
||||||
|
Log specific key,value pairs as part of the metadata for a spend log
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Logging specific key,value pairs in spend logs metadata is an enterprise feature. [See here](./enterprise.md#tracking-spend-with-custom-metadata)
|
||||||
|
|
||||||
|
:::
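
The general pattern (covered in detail in the cost-tracking docs linked above) is to pass `spend_logs_metadata` inside the request `metadata`, e.g. with the OpenAI Python client pointed at the proxy:

```python
import openai

client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this is a test request, write a short poem"}],
    extra_body={
        "metadata": {
            "spend_logs_metadata": {"hello": "world"}  # 👈 logged on the spend log
        }
    },
)
print(response)
```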
|
||||||
|
|
||||||
|
|
||||||
|
## ✨ Custom Tags
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags)
|
||||||
|
|
||||||
|
:::
|
|
@ -1,5 +1,6 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
# 🐳 Docker, Deploying LiteLLM Proxy
|
# 🐳 Docker, Deploying LiteLLM Proxy
|
||||||
|
|
||||||
|
@ -26,7 +27,7 @@ docker-compose up
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
|
||||||
<TabItem value="basic" label="Basic">
|
<TabItem value="basic" label="Basic (No DB)">
|
||||||
|
|
||||||
### Step 1. CREATE config.yaml
|
### Step 1. CREATE config.yaml
|
||||||
|
|
||||||
|
@ -97,7 +98,13 @@ docker run ghcr.io/berriai/litellm:main-latest --port 8002 --num_workers 8
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
<TabItem value="terraform" label="Terraform">
|
||||||
|
|
||||||
|
s/o [Nicholas Cecere](https://www.linkedin.com/in/nicholas-cecere-24243549/) for his LiteLLM User Management Terraform
|
||||||
|
|
||||||
|
👉 [Go here for Terraform](https://github.com/ncecere/terraform-litellm-user-mgmt)
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
<TabItem value="base-image" label="use litellm as a base image">
|
<TabItem value="base-image" label="use litellm as a base image">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
@ -379,6 +386,7 @@ kubectl port-forward service/litellm-service 4000:4000
|
||||||
Your OpenAI proxy server is now running on `http://0.0.0.0:4000`.
|
Your OpenAI proxy server is now running on `http://0.0.0.0:4000`.
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
<TabItem value="helm-deploy" label="Helm">
|
<TabItem value="helm-deploy" label="Helm">
|
||||||
|
|
||||||
|
|
||||||
|
@ -424,7 +432,6 @@ If you need to set your litellm proxy config.yaml, you can find this in [values.
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
<TabItem value="helm-oci" label="Helm OCI Registry (GHCR)">
|
<TabItem value="helm-oci" label="Helm OCI Registry (GHCR)">
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
@ -537,7 +544,9 @@ ghcr.io/berriai/litellm-database:main-latest --config your_config.yaml
|
||||||
|
|
||||||
## Advanced Deployment Settings
|
## Advanced Deployment Settings
|
||||||
|
|
||||||
### Customization of the server root path
|
### 1. Customization of the server root path (custom Proxy base url)
|
||||||
|
|
||||||
|
💥 Use this when you want to serve LiteLLM on a custom base url path like `https://localhost:4000/api/v1`
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
|
||||||
|
@ -548,9 +557,29 @@ In a Kubernetes deployment, it's possible to utilize a shared DNS to host multip
|
||||||
Customize the root path to eliminate the need for employing multiple DNS configurations during deployment.
|
Customize the root path to eliminate the need for employing multiple DNS configurations during deployment.
|
||||||
|
|
||||||
👉 Set `SERVER_ROOT_PATH` in your .env and this will be set as your server root path
|
👉 Set `SERVER_ROOT_PATH` in your .env and this will be set as your server root path
|
||||||
|
```
|
||||||
|
export SERVER_ROOT_PATH="/api/v1"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 1. Run Proxy with `SERVER_ROOT_PATH` set in your env**
|
||||||
|
|
||||||
### Setting SSL Certification
|
```shell
|
||||||
|
docker run --name litellm-proxy \
|
||||||
|
-e DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname> \
|
||||||
|
-e SERVER_ROOT_PATH="/api/v1" \
|
||||||
|
-p 4000:4000 \
|
||||||
|
ghcr.io/berriai/litellm-database:main-latest --config your_config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
After running the proxy you can access it on `http://0.0.0.0:4000/api/v1/` (since we set `SERVER_ROOT_PATH="/api/v1"`)
|
||||||
|
|
||||||
|
**Step 2. Verify Running on correct path**
|
||||||
|
|
||||||
|
<Image img={require('../../img/custom_root_path.png')} />
|
||||||
|
|
||||||
|
**That's it**, that's all you need to run the proxy on a custom root path
|
||||||
|
|
||||||
|
### 2. Setting SSL Certification
|
||||||
|
|
||||||
Use this if you need to set SSL certificates for your on-prem litellm proxy
|
Use this if you need to set SSL certificates for your on-prem litellm proxy
|
||||||
|
|
||||||
|
@ -646,7 +675,7 @@ Once the stack is created, get the DatabaseURL of the Database resource, copy th
|
||||||
#### 3. Connect to the EC2 Instance and deploy litellm on the EC2 container
|
#### 3. Connect to the EC2 Instance and deploy litellm on the EC2 container
|
||||||
From the EC2 console, connect to the instance created by the stack (e.g., using SSH).
|
From the EC2 console, connect to the instance created by the stack (e.g., using SSH).
|
||||||
|
|
||||||
Run the following command, replacing <database_url> with the value you copied in step 2
|
Run the following command, replacing `<database_url>` with the value you copied in step 2
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker run --name litellm-proxy \
|
docker run --name litellm-proxy \
|
||||||
|
|
|
@ -5,6 +5,7 @@ import Image from '@theme/IdealImage';
|
||||||
Send an Email to your users when:
|
Send an Email to your users when:
|
||||||
- A Proxy API Key is created for them
|
- A Proxy API Key is created for them
|
||||||
- Their API Key crosses its Budget
|
- Their API Key crosses its Budget
|
||||||
|
- All Team members of a LiteLLM Team -> when the team crosses it's budget
|
||||||
|
|
||||||
<Image img={require('../../img/email_notifs.png')} style={{ width: '500px' }}/>
|
<Image img={require('../../img/email_notifs.png')} style={{ width: '500px' }}/>
|
||||||
|
|
||||||
|
|
|
@ -205,6 +205,146 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Tracking Spend with custom metadata
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
|
||||||
|
- Virtual Keys & a database should be set up, see [virtual keys](https://docs.litellm.ai/docs/proxy/virtual_keys)
|
||||||
|
|
||||||
|
#### Usage - /chat/completions requests with special spend logs metadata
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
Set `extra_body={"metadata": { }}` to `metadata` you want to pass
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={
|
||||||
|
"metadata": {
|
||||||
|
"spend_logs_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="Curl" label="Curl Request">
|
||||||
|
|
||||||
|
Pass `metadata` as part of the request body
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what llm are you"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"spend_logs_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="langchain" label="Langchain">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
)
|
||||||
|
from langchain.schema import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
chat = ChatOpenAI(
|
||||||
|
openai_api_base="http://0.0.0.0:4000",
|
||||||
|
model = "gpt-3.5-turbo",
|
||||||
|
temperature=0.1,
|
||||||
|
extra_body={
|
||||||
|
"metadata": {
|
||||||
|
"spend_logs_metadata": {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
SystemMessage(
|
||||||
|
content="You are a helpful assistant that im using to make a test request to."
|
||||||
|
),
|
||||||
|
HumanMessage(
|
||||||
|
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
response = chat(messages)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
#### Viewing Spend w/ custom metadata
|
||||||
|
|
||||||
|
#### `/spend/logs` Request Format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X GET "http://0.0.0.0:4000/spend/logs?request_id=<your-call-id" \ # e.g.: chatcmpl-9ZKMURhVYSi9D6r6PJ9vLcayIK0Vm
|
||||||
|
-H "Authorization: Bearer sk-1234"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `/spend/logs` Response Format
|
||||||
|
```bash
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"request_id": "chatcmpl-9ZKMURhVYSi9D6r6PJ9vLcayIK0Vm",
|
||||||
|
"call_type": "acompletion",
|
||||||
|
"metadata": {
|
||||||
|
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
|
||||||
|
"user_api_key_alias": null,
|
||||||
|
"spend_logs_metadata": { # 👈 LOGGED CUSTOM METADATA
|
||||||
|
"hello": "world"
|
||||||
|
},
|
||||||
|
"user_api_key_team_id": null,
|
||||||
|
"user_api_key_user_id": "116544810872468347480",
|
||||||
|
"user_api_key_team_alias": null
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Enforce Required Params for LLM Requests
|
## Enforce Required Params for LLM Requests
|
||||||
Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params.
|
Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params.
|
||||||
|
|
||||||
|
|
|
@ -606,6 +606,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
|
||||||
** 🎉 Expect to see this trace logged in your OTEL collector**
|
** 🎉 Expect to see this trace logged in your OTEL collector**
|
||||||
|
|
||||||
|
### Context propagation across Services `Traceparent HTTP Header`
|
||||||
|
|
||||||
|
❓ Use this when you want to **pass information about the incoming request in a distributed tracing system**
|
||||||
|
|
||||||
|
✅ Key change: Pass the **`traceparent` header** in your requests. [Read more about traceparent headers here](https://uptrace.dev/opentelemetry/opentelemetry-traceparent.html#what-is-traceparent-header)
|
||||||
|
```curl
|
||||||
|
traceparent: 00-80e1afed08e019fc1110464cfa66635c-7a085853722dc6d2-01
|
||||||
|
```
|
||||||
|
Example Usage
|
||||||
|
1. Make Request to LiteLLM Proxy with `traceparent` header
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
example_traceparent = f"00-80e1afed08e019fc1110464cfa66635c-02e80198930058d4-01"
|
||||||
|
extra_headers = {
|
||||||
|
"traceparent": example_traceparent
|
||||||
|
}
|
||||||
|
_trace_id = example_traceparent.split("-")[1]
|
||||||
|
|
||||||
|
print("EXTRA HEADERS: ", extra_headers)
|
||||||
|
print("Trace ID: ", _trace_id)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="llama3",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "this is a test request, write a short poem"}
|
||||||
|
],
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# EXTRA HEADERS: {'traceparent': '00-80e1afed08e019fc1110464cfa66635c-02e80198930058d4-01'}
|
||||||
|
# Trace ID: 80e1afed08e019fc1110464cfa66635c
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Lookup Trace ID on OTEL Logger
|
||||||
|
|
||||||
|
Search for Trace=`80e1afed08e019fc1110464cfa66635c` on your OTEL Collector
|
||||||
|
|
||||||
|
<Image img={require('../../img/otel_parent.png')} />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Model Management
|
# Model Management
|
||||||
Add new models + Get model info without restarting proxy.
|
Add new models + Get model info without restarting proxy.
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
# LiteLLM Proxy Performance
|
# LiteLLM Proxy Performance
|
||||||
|
|
||||||
### Throughput - 30% Increase
|
### Throughput - 30% Increase
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Grafana, Prometheus metrics [BETA]
|
# 📈 Prometheus metrics [BETA]
|
||||||
|
|
||||||
LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
|
LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
|
||||||
|
|
||||||
|
@ -54,6 +54,13 @@ http://localhost:4000/metrics
|
||||||
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
|
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
|
||||||
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
|
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
|
||||||
|
|
||||||
|
### Budget Metrics
|
||||||
|
| Metric Name | Description |
|
||||||
|
|----------------------|--------------------------------------|
|
||||||
|
| `litellm_remaining_team_budget_metric` | Remaining Budget for Team (A team created on LiteLLM) |
|
||||||
|
| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
|
||||||
|
|
||||||
|
|
||||||
## Monitor System Health
|
## Monitor System Health
|
||||||
|
|
||||||
To monitor the health of litellm adjacent services (redis / postgres), do:
|
To monitor the health of litellm adjacent services (redis / postgres), do:
|
||||||
|
|
|
@ -155,9 +155,7 @@ response = client.chat.completions.create(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
extra_body={
|
extra_body={
|
||||||
"metadata": {
|
"fallbacks": ["gpt-3.5-turbo"]
|
||||||
"fallbacks": ["gpt-3.5-turbo"]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -180,9 +178,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
"content": "what llm are you"
|
"content": "what llm are you"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"fallbacks": ["gpt-3.5-turbo"]
|
||||||
"fallbacks": ["gpt-3.5-turbo"]
|
|
||||||
}
|
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
@ -204,9 +200,7 @@ chat = ChatOpenAI(
|
||||||
openai_api_base="http://0.0.0.0:4000",
|
openai_api_base="http://0.0.0.0:4000",
|
||||||
model="zephyr-beta",
|
model="zephyr-beta",
|
||||||
extra_body={
|
extra_body={
|
||||||
"metadata": {
|
"fallbacks": ["gpt-3.5-turbo"]
|
||||||
"fallbacks": ["gpt-3.5-turbo"]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -415,6 +409,28 @@ print(response)
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
### Content Policy Fallbacks
|
||||||
|
|
||||||
|
Fallback across providers (e.g. from Azure OpenAI to Anthropic) if you hit content policy violation errors.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo-small
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: "2023-07-01-preview"
|
||||||
|
|
||||||
|
- model_name: claude-opus
|
||||||
|
litellm_params:
|
||||||
|
model: claude-3-opus-20240229
|
||||||
|
api_key: os.environ/ANTHROPIC_API_KEY
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
content_policy_fallbacks: [{"gpt-3.5-turbo-small": ["claude-opus"]}]
|
||||||
|
```
|
||||||
|
|
||||||
### EU-Region Filtering (Pre-Call Checks)
|
### EU-Region Filtering (Pre-Call Checks)
|
||||||
|
|
||||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||||
|
|
|
@ -123,4 +123,18 @@ LiteLLM Enterprise: Enable [SSO login](./ui.md#setup-ssoauth-for-ui)
|
||||||
4. User can now create their own keys
|
4. User can now create their own keys
|
||||||
|
|
||||||
|
|
||||||
<Image img={require('../../img/ui_self_serve_create_key.png')} style={{ width: '800px', height: 'auto' }} />
|
<Image img={require('../../img/ui_self_serve_create_key.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
## Advanced
|
||||||
|
### Setting custom logout URLs
|
||||||
|
|
||||||
|
Set `PROXY_LOGOUT_URL` in your .env if you want users to get redirected to a specific URL when they click logout
|
||||||
|
|
||||||
|
```
|
||||||
|
export PROXY_LOGOUT_URL="https://www.google.com"
|
||||||
|
```
|
||||||
|
|
||||||
|
<Image img={require('../../img/ui_logout.png')} style={{ width: '400px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
|
123
docs/my-website/docs/proxy/team_budgets.md
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# 💰 Setting Team Budgets
|
||||||
|
|
||||||
|
Track spend, set budgets for your Internal Team
|
||||||
|
|
||||||
|
## Setting Monthly Team Budgets
|
||||||
|
|
||||||
|
### 1. Create a team
|
||||||
|
- Set `max_budget=0.000000001` ($ value the team is allowed to spend)
|
||||||
|
- Set `budget_duration="1d"` (How frequently the budget should update)
|
||||||
|
|
||||||
|
|
||||||
|
Create a new team and set `max_budget` and `budget_duration`
|
||||||
|
```shell
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/team/new' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"team_alias": "QA Prod Bot",
|
||||||
|
"max_budget": 0.000000001,
|
||||||
|
"budget_duration": "1d"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response
|
||||||
|
```shell
|
||||||
|
{
|
||||||
|
"team_alias": "QA Prod Bot",
|
||||||
|
"team_id": "de35b29e-6ca8-4f47-b804-2b79d07aa99a",
|
||||||
|
"max_budget": 0.0001,
|
||||||
|
"budget_duration": "1d",
|
||||||
|
"budget_reset_at": "2024-06-14T22:48:36.594000Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Possible values for `budget_duration`
|
||||||
|
|
||||||
|
| `budget_duration` | When Budget will reset |
|
||||||
|
| --- | --- |
|
||||||
|
| `budget_duration="1s"` | every 1 second |
|
||||||
|
| `budget_duration="1m"` | every 1 min |
|
||||||
|
| `budget_duration="1h"` | every 1 hour |
|
||||||
|
| `budget_duration="1d"` | every 1 day |
|
||||||
|
| `budget_duration="1mo"` | every 1 month |
|
||||||
|
|
||||||
|
|
||||||
|
### 2. Create a key for the `team`
|
||||||
|
|
||||||
|
Create a key for `team_id="de35b29e-6ca8-4f47-b804-2b79d07aa99a"` from Step 1
|
||||||
|
|
||||||
|
💡 **The budget set for Team="QA Prod Bot" will apply to requests made with this key**
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"team_id": "de35b29e-6ca8-4f47-b804-2b79d07aa99a"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response
|
||||||
|
|
||||||
|
```shell
|
||||||
|
{"team_id":"de35b29e-6ca8-4f47-b804-2b79d07aa99a", "key":"sk-5qtncoYjzRcxMM4bDRktNQ"}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 3. Test It
|
||||||
|
|
||||||
|
Use the key from step 2 and run this Request twice
|
||||||
|
```shell
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Authorization: Bearer sk-5qtncoYjzRcxMM4bDRktNQ' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d ' {
|
||||||
|
"model": "llama3",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hi"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
On the 2nd response - expect to see the following exception
|
||||||
|
|
||||||
|
```shell
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "Budget has been exceeded! Current cost: 3.5e-06, Max budget: 1e-09",
|
||||||
|
"type": "auth_error",
|
||||||
|
"param": null,
|
||||||
|
"code": 400
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced
|
||||||
|
|
||||||
|
### Prometheus metrics for `remaining_budget`
|
||||||
|
|
||||||
|
[More info about Prometheus metrics here](https://docs.litellm.ai/docs/proxy/prometheus)
|
||||||
|
|
||||||
|
You'll need the following in your proxy config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
litellm_settings:
|
||||||
|
success_callback: ["prometheus"]
|
||||||
|
failure_callback: ["prometheus"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Expect to see this metric on prometheus to track the Remaining Budget for the team
|
||||||
|
|
||||||
|
```shell
|
||||||
|
litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-6ca8-4f47-b804-2b79d07aa99a"} 9.699999999999992e-06
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -77,6 +77,28 @@ litellm_settings:
|
||||||
|
|
||||||
#### Step 2: Setup Oauth Client
|
#### Step 2: Setup Oauth Client
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
<TabItem value="okta" label="Okta SSO">
|
||||||
|
|
||||||
|
1. Add Okta credentials to your .env
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GENERIC_CLIENT_ID = "<your-okta-client-id>"
|
||||||
|
GENERIC_CLIENT_SECRET = "<your-okta-client-secret>"
|
||||||
|
GENERIC_AUTHORIZATION_ENDPOINT = "<your-okta-domain>/authorize" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/authorize
|
||||||
|
GENERIC_TOKEN_ENDPOINT = "<your-okta-domain>/token" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/oauth/token
|
||||||
|
GENERIC_USERINFO_ENDPOINT = "<your-okta-domain>/userinfo" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/userinfo
|
||||||
|
```
|
||||||
|
|
||||||
|
You can get your domain specific auth/token/userinfo endpoints at `<YOUR-OKTA-DOMAIN>/.well-known/openid-configuration`
|
||||||
|
|
||||||
|
2. Add proxy url as callback_url on Okta
|
||||||
|
|
||||||
|
On Okta, add the 'callback_url' as `<proxy_base_url>/sso/callback`
|
||||||
|
|
||||||
|
|
||||||
|
<Image img={require('../../img/okta_callback_url.png')} />
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
<TabItem value="google" label="Google SSO">
|
<TabItem value="google" label="Google SSO">
|
||||||
|
|
||||||
- Create a new Oauth 2.0 Client on https://console.cloud.google.com/
|
- Create a new Oauth 2.0 Client on https://console.cloud.google.com/
|
||||||
|
@ -115,7 +137,6 @@ MICROSOFT_TENANT="5a39737
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
<TabItem value="Generic" label="Generic SSO Provider">
|
<TabItem value="Generic" label="Generic SSO Provider">
|
||||||
|
|
||||||
A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code
|
A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code
|
||||||
|
|
|
@ -63,7 +63,7 @@ You can:
|
||||||
- Add budgets to Teams
|
- Add budgets to Teams
|
||||||
|
|
||||||
|
|
||||||
#### **Add budgets to users**
|
#### **Add budgets to teams**
|
||||||
```shell
|
```shell
|
||||||
curl --location 'http://localhost:4000/team/new' \
|
curl --location 'http://localhost:4000/team/new' \
|
||||||
--header 'Authorization: Bearer <your-master-key>' \
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
|
@ -102,6 +102,22 @@ curl --location 'http://localhost:4000/team/new' \
|
||||||
"budget_reset_at": null
|
"budget_reset_at": null
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### **Add budget duration to teams**
|
||||||
|
|
||||||
|
`budget_duration`: Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
|
|
||||||
|
```
|
||||||
|
curl 'http://0.0.0.0:4000/team/new' \
|
||||||
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"team_alias": "my-new-team_4",
|
||||||
|
"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}],
|
||||||
|
"budget_duration": 10s,
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="per-team-member" label="For Team Members">
|
<TabItem value="per-team-member" label="For Team Members">
|
||||||
|
|
||||||
|
@ -397,6 +413,52 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
### Reset Budgets
|
||||||
|
|
||||||
|
Reset budgets across keys/internal users/teams/customers
|
||||||
|
|
||||||
|
`budget_duration`: Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="users" label="Internal Users">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl 'http://0.0.0.0:4000/user/new' \
|
||||||
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"max_budget": 10,
|
||||||
|
"budget_duration": 10s, # 👈 KEY CHANGE
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="keys" label="Keys">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"max_budget": 10,
|
||||||
|
"budget_duration": 10s, # 👈 KEY CHANGE
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="teams" label="Teams">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl 'http://0.0.0.0:4000/team/new' \
|
||||||
|
--header 'Authorization: Bearer <your-master-key>' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"max_budget": 10,
|
||||||
|
"budget_duration": 10s, # 👈 KEY CHANGE
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
## Set Rate Limits
|
## Set Rate Limits
|
||||||
|
|
||||||
You can set:
|
You can set:
|
||||||
|
|
|
@ -790,85 +790,205 @@ If the error is a context window exceeded error, fall back to a larger model gro
|
||||||
|
|
||||||
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
|
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
|
||||||
|
|
||||||
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad.
|
You can also set `default_fallbacks`, in case a specific model group is misconfigured / bad.
|
||||||
|
|
||||||
|
There are 3 types of fallbacks:
|
||||||
|
- `content_policy_fallbacks`: For litellm.ContentPolicyViolationError - LiteLLM maps content policy violation errors across providers [**See Code**](https://github.com/BerriAI/litellm/blob/89a43c872a1e3084519fb9de159bf52f5447c6c4/litellm/utils.py#L8495C27-L8495C54)
|
||||||
|
- `context_window_fallbacks`: For litellm.ContextWindowExceededErrors - LiteLLM maps context window error messages across providers [**See Code**](https://github.com/BerriAI/litellm/blob/89a43c872a1e3084519fb9de159bf52f5447c6c4/litellm/utils.py#L8469)
|
||||||
|
- `fallbacks`: For all remaining errors - e.g. litellm.RateLimitError
|
||||||
|
|
||||||
|
**Content Policy Violation Fallback**
|
||||||
|
|
||||||
|
Key change:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import Router
|
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
|
||||||
model_list = [
|
|
||||||
{ # list of model deployments
|
|
||||||
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "azure/chatgpt-v-2",
|
|
||||||
"api_key": "bad-key",
|
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
|
||||||
"api_base": os.getenv("AZURE_API_BASE")
|
|
||||||
},
|
|
||||||
"tpm": 240000,
|
|
||||||
"rpm": 1800
|
|
||||||
},
|
|
||||||
{ # list of model deployments
|
|
||||||
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "azure/chatgpt-v-2",
|
|
||||||
"api_key": "bad-key",
|
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
|
||||||
"api_base": os.getenv("AZURE_API_BASE")
|
|
||||||
},
|
|
||||||
"tpm": 240000,
|
|
||||||
"rpm": 1800
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "azure/chatgpt-functioncalling",
|
|
||||||
"api_key": "bad-key",
|
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
|
||||||
"api_base": os.getenv("AZURE_API_BASE")
|
|
||||||
},
|
|
||||||
"tpm": 240000,
|
|
||||||
"rpm": 1800
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo", # openai model name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
"tpm": 1000000,
|
|
||||||
"rpm": 9000
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo-16k", # openai model name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "gpt-3.5-turbo-16k",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
"tpm": 1000000,
|
|
||||||
"rpm": 9000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
router = Router(model_list=model_list,
|
|
||||||
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
|
||||||
default_fallbacks=["gpt-3.5-turbo-16k"],
|
|
||||||
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
|
|
||||||
set_verbose=True)
|
|
||||||
|
|
||||||
|
|
||||||
user_message = "Hello, whats the weather in San Francisco??"
|
|
||||||
messages = [{"content": user_message, "role": "user"}]
|
|
||||||
|
|
||||||
# normal fallback call
|
|
||||||
response = router.completion(model="azure/gpt-3.5-turbo", messages=messages)
|
|
||||||
|
|
||||||
# context window fallback call
|
|
||||||
response = router.completion(model="azure/gpt-3.5-turbo-context-fallback", messages=messages)
|
|
||||||
|
|
||||||
print(f"response: {response}")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "claude-2",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": Exception("content filtering policy"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "my-fallback-model",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": "This works!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
|
||||||
|
# fallbacks=[..], # [OPTIONAL]
|
||||||
|
# context_window_fallbacks=[..], # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = router.completion(
|
||||||
|
model="claude-2",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
In your proxy config.yaml just add this line 👇
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
router_settings:
|
||||||
|
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
```
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING on http://0.0.0.0:4000
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
**Context Window Exceeded Fallback**
|
||||||
|
|
||||||
|
Key change:
|
||||||
|
|
||||||
|
```python
|
||||||
|
context_window_fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "claude-2",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": Exception("prompt is too long"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "my-fallback-model",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": "This works!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
context_window_fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
|
||||||
|
# fallbacks=[..], # [OPTIONAL]
|
||||||
|
# content_policy_fallbacks=[..], # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = router.completion(
|
||||||
|
model="claude-2",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
In your proxy config.yaml just add this line 👇
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
router_settings:
|
||||||
|
context_window_fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
```
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING on http://0.0.0.0:4000
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
**Regular Fallbacks**
|
||||||
|
|
||||||
|
Key change:
|
||||||
|
|
||||||
|
```python
|
||||||
|
fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "claude-2",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": Exception("this is a rate limit error"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "my-fallback-model",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "claude-2",
|
||||||
|
"api_key": "",
|
||||||
|
"mock_response": "This works!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
|
||||||
|
# context_window_fallbacks=[..], # [OPTIONAL]
|
||||||
|
# content_policy_fallbacks=[..], # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = router.completion(
|
||||||
|
model="claude-2",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
In your proxy config.yaml just add this line 👇
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
router_settings:
|
||||||
|
fallbacks=[{"claude-2": ["my-fallback-model"]}]
|
||||||
|
```
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING on http://0.0.0.0:4000
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Caching
|
### Caching
|
||||||
|
|
||||||
In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
|
In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
|
||||||
|
|
|
@ -23,9 +23,13 @@ https://api.together.xyz/playground/chat?model=togethercomputer%2Fllama-2-70b-ch
|
||||||
model_name = "together_ai/togethercomputer/llama-2-70b-chat"
|
model_name = "together_ai/togethercomputer/llama-2-70b-chat"
|
||||||
response = completion(model=model_name, messages=messages)
|
response = completion(model=model_name, messages=messages)
|
||||||
print(response)
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
{'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'role': 'assistant', 'content': "\n\nI'm not able to provide real-time weather information. However, I can suggest"}}], 'created': 1691629657.9288375, 'model': 'togethercomputer/llama-2-70b-chat', 'usage': {'prompt_tokens': 9, 'completion_tokens': 17, 'total_tokens': 26}}
|
{'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'role': 'assistant', 'content': "\n\nI'm not able to provide real-time weather information. However, I can suggest"}}], 'created': 1691629657.9288375, 'model': 'togethercomputer/llama-2-70b-chat', 'usage': {'prompt_tokens': 9, 'completion_tokens': 17, 'total_tokens': 26}}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
LiteLLM handles the prompt formatting for Together AI's Llama2 models as well, converting your message to the
|
LiteLLM handles the prompt formatting for Together AI's Llama2 models as well, converting your message to the
|
||||||
|
|
|
@ -138,8 +138,8 @@ const config = {
|
||||||
title: 'Docs',
|
title: 'Docs',
|
||||||
items: [
|
items: [
|
||||||
{
|
{
|
||||||
label: 'Tutorial',
|
label: 'Getting Started',
|
||||||
to: '/docs/index',
|
to: 'https://docs.litellm.ai/docs/',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
BIN
docs/my-website/img/alerting_metadata.png
Normal file
After Width: | Height: | Size: 207 KiB |
BIN
docs/my-website/img/custom_root_path.png
Normal file
After Width: | Height: | Size: 151 KiB |
BIN
docs/my-website/img/ms_teams_alerting.png
Normal file
After Width: | Height: | Size: 241 KiB |
BIN
docs/my-website/img/okta_callback_url.png
Normal file
After Width: | Height: | Size: 279 KiB |
BIN
docs/my-website/img/otel_parent.png
Normal file
After Width: | Height: | Size: 200 KiB |
BIN
docs/my-website/img/raw_request_log.png
Normal file
After Width: | Height: | Size: 168 KiB |
BIN
docs/my-website/img/ui_logout.png
Normal file
After Width: | Height: | Size: 27 KiB |
2308
docs/my-website/package-lock.json
generated
|
@ -23,8 +23,8 @@
|
||||||
"docusaurus": "^1.14.7",
|
"docusaurus": "^1.14.7",
|
||||||
"docusaurus-lunr-search": "^2.4.1",
|
"docusaurus-lunr-search": "^2.4.1",
|
||||||
"prism-react-renderer": "^1.3.5",
|
"prism-react-renderer": "^1.3.5",
|
||||||
"react": "^17.0.2",
|
"react": "^18.1.0",
|
||||||
"react-dom": "^17.0.2",
|
"react-dom": "^18.1.0",
|
||||||
"sharp": "^0.32.6",
|
"sharp": "^0.32.6",
|
||||||
"uuid": "^9.0.1"
|
"uuid": "^9.0.1"
|
||||||
},
|
},
|
||||||
|
|
|
@ -44,6 +44,7 @@ const sidebars = {
|
||||||
"proxy/self_serve",
|
"proxy/self_serve",
|
||||||
"proxy/users",
|
"proxy/users",
|
||||||
"proxy/customers",
|
"proxy/customers",
|
||||||
|
"proxy/team_budgets",
|
||||||
"proxy/billing",
|
"proxy/billing",
|
||||||
"proxy/user_keys",
|
"proxy/user_keys",
|
||||||
"proxy/virtual_keys",
|
"proxy/virtual_keys",
|
||||||
|
@ -54,6 +55,7 @@ const sidebars = {
|
||||||
items: ["proxy/logging", "proxy/streaming_logging"],
|
items: ["proxy/logging", "proxy/streaming_logging"],
|
||||||
},
|
},
|
||||||
"proxy/ui",
|
"proxy/ui",
|
||||||
|
"proxy/prometheus",
|
||||||
"proxy/email",
|
"proxy/email",
|
||||||
"proxy/multiple_admins",
|
"proxy/multiple_admins",
|
||||||
"proxy/team_based_routing",
|
"proxy/team_based_routing",
|
||||||
|
@ -70,7 +72,6 @@ const sidebars = {
|
||||||
"proxy/pii_masking",
|
"proxy/pii_masking",
|
||||||
"proxy/prompt_injection",
|
"proxy/prompt_injection",
|
||||||
"proxy/caching",
|
"proxy/caching",
|
||||||
"proxy/prometheus",
|
|
||||||
"proxy/call_hooks",
|
"proxy/call_hooks",
|
||||||
"proxy/rules",
|
"proxy/rules",
|
||||||
"proxy/cli",
|
"proxy/cli",
|
||||||
|
@ -133,10 +134,11 @@ const sidebars = {
|
||||||
"providers/vertex",
|
"providers/vertex",
|
||||||
"providers/palm",
|
"providers/palm",
|
||||||
"providers/gemini",
|
"providers/gemini",
|
||||||
"providers/mistral",
|
|
||||||
"providers/anthropic",
|
"providers/anthropic",
|
||||||
"providers/aws_sagemaker",
|
"providers/aws_sagemaker",
|
||||||
"providers/bedrock",
|
"providers/bedrock",
|
||||||
|
"providers/mistral",
|
||||||
|
"providers/codestral",
|
||||||
"providers/cohere",
|
"providers/cohere",
|
||||||
"providers/anyscale",
|
"providers/anyscale",
|
||||||
"providers/huggingface",
|
"providers/huggingface",
|
||||||
|
@ -170,10 +172,8 @@ const sidebars = {
|
||||||
"proxy/custom_pricing",
|
"proxy/custom_pricing",
|
||||||
"routing",
|
"routing",
|
||||||
"scheduler",
|
"scheduler",
|
||||||
"rules",
|
|
||||||
"set_keys",
|
"set_keys",
|
||||||
"budget_manager",
|
"budget_manager",
|
||||||
"contributing",
|
|
||||||
"secret",
|
"secret",
|
||||||
"completion/token_usage",
|
"completion/token_usage",
|
||||||
"load_test",
|
"load_test",
|
||||||
|
@ -181,10 +181,11 @@ const sidebars = {
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "Logging & Observability",
|
label: "Logging & Observability",
|
||||||
items: [
|
items: [
|
||||||
"debugging/local_debugging",
|
|
||||||
"observability/callbacks",
|
|
||||||
"observability/custom_callback",
|
|
||||||
"observability/langfuse_integration",
|
"observability/langfuse_integration",
|
||||||
|
"observability/logfire_integration",
|
||||||
|
"debugging/local_debugging",
|
||||||
|
"observability/raw_request_response",
|
||||||
|
"observability/custom_callback",
|
||||||
"observability/sentry",
|
"observability/sentry",
|
||||||
"observability/lago",
|
"observability/lago",
|
||||||
"observability/openmeter",
|
"observability/openmeter",
|
||||||
|
@ -222,14 +223,16 @@ const sidebars = {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "LangChain, LlamaIndex Integration",
|
label: "LangChain, LlamaIndex, Instructor Integration",
|
||||||
items: ["langchain/langchain"],
|
items: ["langchain/langchain", "tutorials/instructor"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "Extras",
|
label: "Extras",
|
||||||
items: [
|
items: [
|
||||||
"extras/contributing",
|
"extras/contributing",
|
||||||
|
"contributing",
|
||||||
|
"rules",
|
||||||
"proxy_server",
|
"proxy_server",
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
|
|
|
@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
|
||||||
response.choices[0], litellm.utils.Choices
|
response.choices[0], litellm.utils.Choices
|
||||||
):
|
):
|
||||||
for word in self.banned_keywords_list:
|
for word in self.banned_keywords_list:
|
||||||
self.test_violation(test_str=response.choices[0].message.content)
|
self.test_violation(test_str=response.choices[0].message.content or "")
|
||||||
|
|
||||||
async def async_post_call_streaming_hook(
|
async def async_post_call_streaming_hook(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -122,236 +122,6 @@ async def ui_get_spend_by_tags(
|
||||||
return {"spend_per_tag": ui_tags}
|
return {"spend_per_tag": ui_tags}
|
||||||
|
|
||||||
|
|
||||||
async def view_spend_logs_from_clickhouse(
|
|
||||||
api_key=None, user_id=None, request_id=None, start_date=None, end_date=None
|
|
||||||
):
|
|
||||||
verbose_logger.debug("Reading logs from Clickhouse")
|
|
||||||
import os
|
|
||||||
|
|
||||||
# if user has setup clickhouse
|
|
||||||
# TODO: Move this to be a helper function
|
|
||||||
# querying clickhouse for this data
|
|
||||||
import clickhouse_connect
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
port = os.getenv("CLICKHOUSE_PORT")
|
|
||||||
if port is not None and isinstance(port, str):
|
|
||||||
port = int(port)
|
|
||||||
|
|
||||||
client = clickhouse_connect.get_client(
|
|
||||||
host=os.getenv("CLICKHOUSE_HOST"),
|
|
||||||
port=port,
|
|
||||||
username=os.getenv("CLICKHOUSE_USERNAME", ""),
|
|
||||||
password=os.getenv("CLICKHOUSE_PASSWORD", ""),
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
start_date is not None
|
|
||||||
and isinstance(start_date, str)
|
|
||||||
and end_date is not None
|
|
||||||
and isinstance(end_date, str)
|
|
||||||
):
|
|
||||||
# Convert the date strings to datetime objects
|
|
||||||
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
|
|
||||||
|
|
||||||
# get top spend per day
|
|
||||||
response = client.query(
|
|
||||||
f"""
|
|
||||||
SELECT
|
|
||||||
toDate(startTime) AS day,
|
|
||||||
sum(spend) AS total_spend
|
|
||||||
FROM
|
|
||||||
spend_logs
|
|
||||||
WHERE
|
|
||||||
toDate(startTime) BETWEEN toDate('2024-02-01') AND toDate('2024-02-29')
|
|
||||||
GROUP BY
|
|
||||||
day
|
|
||||||
ORDER BY
|
|
||||||
total_spend
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
result_rows = list(response.result_rows)
|
|
||||||
for response in result_rows:
|
|
||||||
current_row = {}
|
|
||||||
current_row["users"] = {"example": 0.0}
|
|
||||||
current_row["models"] = {}
|
|
||||||
|
|
||||||
current_row["spend"] = float(response[1])
|
|
||||||
current_row["startTime"] = str(response[0])
|
|
||||||
|
|
||||||
# stubbed api_key
|
|
||||||
current_row[""] = 0.0 # type: ignore
|
|
||||||
results.append(current_row)
|
|
||||||
|
|
||||||
return results
|
|
||||||
else:
|
|
||||||
# check if spend logs exist, if it does then return last 10 logs, sorted in descending order of startTime
|
|
||||||
response = client.query(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
*
|
|
||||||
FROM
|
|
||||||
default.spend_logs
|
|
||||||
ORDER BY
|
|
||||||
startTime DESC
|
|
||||||
LIMIT
|
|
||||||
10
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# get size of spend logs
|
|
||||||
num_rows = client.query("SELECT count(*) FROM default.spend_logs")
|
|
||||||
num_rows = num_rows.result_rows[0][0]
|
|
||||||
|
|
||||||
# safely access num_rows.result_rows[0][0]
|
|
||||||
if num_rows is None:
|
|
||||||
num_rows = 0
|
|
||||||
|
|
||||||
raw_rows = list(response.result_rows)
|
|
||||||
response_data = {
|
|
||||||
"logs": raw_rows,
|
|
||||||
"log_count": num_rows,
|
|
||||||
}
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
|
|
||||||
def _create_clickhouse_material_views(client=None, table_names=[]):
|
|
||||||
# Create Materialized Views if they don't exist
|
|
||||||
# Materialized Views send new inserted rows to the aggregate tables
|
|
||||||
|
|
||||||
verbose_logger.debug("Clickhouse: Creating Materialized Views")
|
|
||||||
if "daily_aggregated_spend_per_model_mv" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_model_mv")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_model_mv
|
|
||||||
TO daily_aggregated_spend_per_model
|
|
||||||
AS
|
|
||||||
SELECT
|
|
||||||
toDate(startTime) as day,
|
|
||||||
sumState(spend) AS DailySpend,
|
|
||||||
model as model
|
|
||||||
FROM spend_logs
|
|
||||||
GROUP BY
|
|
||||||
day, model
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend_per_api_key_mv" not in table_names:
|
|
||||||
verbose_logger.debug(
|
|
||||||
"Clickhouse: Creating daily_aggregated_spend_per_api_key_mv"
|
|
||||||
)
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_api_key_mv
|
|
||||||
TO daily_aggregated_spend_per_api_key
|
|
||||||
AS
|
|
||||||
SELECT
|
|
||||||
toDate(startTime) as day,
|
|
||||||
sumState(spend) AS DailySpend,
|
|
||||||
api_key as api_key
|
|
||||||
FROM spend_logs
|
|
||||||
GROUP BY
|
|
||||||
day, api_key
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend_per_user_mv" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_user_mv")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_user_mv
|
|
||||||
TO daily_aggregated_spend_per_user
|
|
||||||
AS
|
|
||||||
SELECT
|
|
||||||
toDate(startTime) as day,
|
|
||||||
sumState(spend) AS DailySpend,
|
|
||||||
user as user
|
|
||||||
FROM spend_logs
|
|
||||||
GROUP BY
|
|
||||||
day, user
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend_mv" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_mv")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE MATERIALIZED VIEW daily_aggregated_spend_mv
|
|
||||||
TO daily_aggregated_spend
|
|
||||||
AS
|
|
||||||
SELECT
|
|
||||||
toDate(startTime) as day,
|
|
||||||
sumState(spend) AS DailySpend
|
|
||||||
FROM spend_logs
|
|
||||||
GROUP BY
|
|
||||||
day
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
|
|
||||||
# Basic Logging works without this - this is only used for low latency reporting apis
|
|
||||||
verbose_logger.debug("Clickhouse: Creating Aggregate Tables")
|
|
||||||
|
|
||||||
# Create Aggregeate Tables if they don't exist
|
|
||||||
if "daily_aggregated_spend_per_model" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_model")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE TABLE daily_aggregated_spend_per_model
|
|
||||||
(
|
|
||||||
`day` Date,
|
|
||||||
`DailySpend` AggregateFunction(sum, Float64),
|
|
||||||
`model` String
|
|
||||||
)
|
|
||||||
ENGINE = SummingMergeTree()
|
|
||||||
ORDER BY (day, model);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend_per_api_key" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_api_key")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE TABLE daily_aggregated_spend_per_api_key
|
|
||||||
(
|
|
||||||
`day` Date,
|
|
||||||
`DailySpend` AggregateFunction(sum, Float64),
|
|
||||||
`api_key` String
|
|
||||||
)
|
|
||||||
ENGINE = SummingMergeTree()
|
|
||||||
ORDER BY (day, api_key);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend_per_user" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_user")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE TABLE daily_aggregated_spend_per_user
|
|
||||||
(
|
|
||||||
`day` Date,
|
|
||||||
`DailySpend` AggregateFunction(sum, Float64),
|
|
||||||
`user` String
|
|
||||||
)
|
|
||||||
ENGINE = SummingMergeTree()
|
|
||||||
ORDER BY (day, user);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if "daily_aggregated_spend" not in table_names:
|
|
||||||
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend")
|
|
||||||
client.command(
|
|
||||||
"""
|
|
||||||
CREATE TABLE daily_aggregated_spend
|
|
||||||
(
|
|
||||||
`day` Date,
|
|
||||||
`DailySpend` AggregateFunction(sum, Float64),
|
|
||||||
)
|
|
||||||
ENGINE = SummingMergeTree()
|
|
||||||
ORDER BY (day);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def _forecast_daily_cost(data: list):
|
def _forecast_daily_cost(data: list):
|
||||||
import requests # type: ignore
|
import requests # type: ignore
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
|
@ -13,7 +13,10 @@ from litellm._logging import (
|
||||||
verbose_logger,
|
verbose_logger,
|
||||||
json_logs,
|
json_logs,
|
||||||
_turn_on_json,
|
_turn_on_json,
|
||||||
|
log_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
KeyManagementSystem,
|
KeyManagementSystem,
|
||||||
KeyManagementSettings,
|
KeyManagementSettings,
|
||||||
|
@ -34,7 +37,7 @@ input_callback: List[Union[str, Callable]] = []
|
||||||
success_callback: List[Union[str, Callable]] = []
|
success_callback: List[Union[str, Callable]] = []
|
||||||
failure_callback: List[Union[str, Callable]] = []
|
failure_callback: List[Union[str, Callable]] = []
|
||||||
service_callback: List[Union[str, Callable]] = []
|
service_callback: List[Union[str, Callable]] = []
|
||||||
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
|
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
|
||||||
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
||||||
_langfuse_default_tags: Optional[
|
_langfuse_default_tags: Optional[
|
||||||
List[
|
List[
|
||||||
|
@ -60,6 +63,7 @@ _async_failure_callback: List[Callable] = (
|
||||||
pre_call_rules: List[Callable] = []
|
pre_call_rules: List[Callable] = []
|
||||||
post_call_rules: List[Callable] = []
|
post_call_rules: List[Callable] = []
|
||||||
turn_off_message_logging: Optional[bool] = False
|
turn_off_message_logging: Optional[bool] = False
|
||||||
|
log_raw_request_response: bool = False
|
||||||
redact_messages_in_exceptions: Optional[bool] = False
|
redact_messages_in_exceptions: Optional[bool] = False
|
||||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||||
## end of callbacks #############
|
## end of callbacks #############
|
||||||
|
@ -72,7 +76,7 @@ token: Optional[str] = (
|
||||||
)
|
)
|
||||||
telemetry = True
|
telemetry = True
|
||||||
max_tokens = 256 # OpenAI Defaults
|
max_tokens = 256 # OpenAI Defaults
|
||||||
drop_params = False
|
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||||
modify_params = False
|
modify_params = False
|
||||||
retry = True
|
retry = True
|
||||||
### AUTH ###
|
### AUTH ###
|
||||||
|
@ -239,6 +243,7 @@ num_retries: Optional[int] = None # per model endpoint
|
||||||
default_fallbacks: Optional[List] = None
|
default_fallbacks: Optional[List] = None
|
||||||
fallbacks: Optional[List] = None
|
fallbacks: Optional[List] = None
|
||||||
context_window_fallbacks: Optional[List] = None
|
context_window_fallbacks: Optional[List] = None
|
||||||
|
content_policy_fallbacks: Optional[List] = None
|
||||||
allowed_fails: int = 0
|
allowed_fails: int = 0
|
||||||
num_retries_per_request: Optional[int] = (
|
num_retries_per_request: Optional[int] = (
|
||||||
None # for the request overall (incl. fallbacks + model retries)
|
None # for the request overall (incl. fallbacks + model retries)
|
||||||
|
@ -336,6 +341,7 @@ bedrock_models: List = []
|
||||||
deepinfra_models: List = []
|
deepinfra_models: List = []
|
||||||
perplexity_models: List = []
|
perplexity_models: List = []
|
||||||
watsonx_models: List = []
|
watsonx_models: List = []
|
||||||
|
gemini_models: List = []
|
||||||
for key, value in model_cost.items():
|
for key, value in model_cost.items():
|
||||||
if value.get("litellm_provider") == "openai":
|
if value.get("litellm_provider") == "openai":
|
||||||
open_ai_chat_completion_models.append(key)
|
open_ai_chat_completion_models.append(key)
|
||||||
|
@ -382,13 +388,16 @@ for key, value in model_cost.items():
|
||||||
perplexity_models.append(key)
|
perplexity_models.append(key)
|
||||||
elif value.get("litellm_provider") == "watsonx":
|
elif value.get("litellm_provider") == "watsonx":
|
||||||
watsonx_models.append(key)
|
watsonx_models.append(key)
|
||||||
|
elif value.get("litellm_provider") == "gemini":
|
||||||
|
gemini_models.append(key)
|
||||||
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
|
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
|
||||||
openai_compatible_endpoints: List = [
|
openai_compatible_endpoints: List = [
|
||||||
"api.perplexity.ai",
|
"api.perplexity.ai",
|
||||||
"api.endpoints.anyscale.com/v1",
|
"api.endpoints.anyscale.com/v1",
|
||||||
"api.deepinfra.com/v1/openai",
|
"api.deepinfra.com/v1/openai",
|
||||||
"api.mistral.ai/v1",
|
"api.mistral.ai/v1",
|
||||||
|
"codestral.mistral.ai/v1/chat/completions",
|
||||||
|
"codestral.mistral.ai/v1/fim/completions",
|
||||||
"api.groq.com/openai/v1",
|
"api.groq.com/openai/v1",
|
||||||
"api.deepseek.com/v1",
|
"api.deepseek.com/v1",
|
||||||
"api.together.xyz/v1",
|
"api.together.xyz/v1",
|
||||||
|
@ -399,12 +408,14 @@ openai_compatible_providers: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"codestral",
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"deepinfra",
|
"deepinfra",
|
||||||
"perplexity",
|
"perplexity",
|
||||||
"xinference",
|
"xinference",
|
||||||
"together_ai",
|
"together_ai",
|
||||||
"fireworks_ai",
|
"fireworks_ai",
|
||||||
|
"azure_ai",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -588,6 +599,7 @@ model_list = (
|
||||||
+ maritalk_models
|
+ maritalk_models
|
||||||
+ vertex_language_models
|
+ vertex_language_models
|
||||||
+ watsonx_models
|
+ watsonx_models
|
||||||
|
+ gemini_models
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_list: List = [
|
provider_list: List = [
|
||||||
|
@ -603,12 +615,14 @@ provider_list: List = [
|
||||||
"together_ai",
|
"together_ai",
|
||||||
"openrouter",
|
"openrouter",
|
||||||
"vertex_ai",
|
"vertex_ai",
|
||||||
|
"vertex_ai_beta",
|
||||||
"palm",
|
"palm",
|
||||||
"gemini",
|
"gemini",
|
||||||
"ai21",
|
"ai21",
|
||||||
"baseten",
|
"baseten",
|
||||||
"azure",
|
"azure",
|
||||||
"azure_text",
|
"azure_text",
|
||||||
|
"azure_ai",
|
||||||
"sagemaker",
|
"sagemaker",
|
||||||
"bedrock",
|
"bedrock",
|
||||||
"vllm",
|
"vllm",
|
||||||
|
@ -622,6 +636,8 @@ provider_list: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"codestral",
|
||||||
|
"text-completion-codestral",
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"maritalk",
|
"maritalk",
|
||||||
"voyage",
|
"voyage",
|
||||||
|
@ -658,6 +674,7 @@ models_by_provider: dict = {
|
||||||
"perplexity": perplexity_models,
|
"perplexity": perplexity_models,
|
||||||
"maritalk": maritalk_models,
|
"maritalk": maritalk_models,
|
||||||
"watsonx": watsonx_models,
|
"watsonx": watsonx_models,
|
||||||
|
"gemini": gemini_models,
|
||||||
}
|
}
|
||||||
|
|
||||||
# mapping for those models which have larger equivalents
|
# mapping for those models which have larger equivalents
|
||||||
|
@ -710,6 +727,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
||||||
|
|
||||||
from .timeout import timeout
|
from .timeout import timeout
|
||||||
from .cost_calculator import completion_cost
|
from .cost_calculator import completion_cost
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||||
from .utils import (
|
from .utils import (
|
||||||
client,
|
client,
|
||||||
exception_type,
|
exception_type,
|
||||||
|
@ -718,12 +736,11 @@ from .utils import (
|
||||||
token_counter,
|
token_counter,
|
||||||
create_pretrained_tokenizer,
|
create_pretrained_tokenizer,
|
||||||
create_tokenizer,
|
create_tokenizer,
|
||||||
cost_per_token,
|
|
||||||
supports_function_calling,
|
supports_function_calling,
|
||||||
supports_parallel_function_calling,
|
supports_parallel_function_calling,
|
||||||
supports_vision,
|
supports_vision,
|
||||||
|
supports_system_messages,
|
||||||
get_litellm_params,
|
get_litellm_params,
|
||||||
Logging,
|
|
||||||
acreate,
|
acreate,
|
||||||
get_model_list,
|
get_model_list,
|
||||||
get_max_tokens,
|
get_max_tokens,
|
||||||
|
@ -743,9 +760,10 @@ from .utils import (
|
||||||
get_first_chars_messages,
|
get_first_chars_messages,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
ImageObject,
|
|
||||||
get_provider_fields,
|
get_provider_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .types.utils import ImageObject
|
||||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||||
from .llms.anthropic import AnthropicConfig
|
from .llms.anthropic import AnthropicConfig
|
||||||
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
|
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
|
||||||
|
@ -762,7 +780,8 @@ from .llms.gemini import GeminiConfig
|
||||||
from .llms.nlp_cloud import NLPCloudConfig
|
from .llms.nlp_cloud import NLPCloudConfig
|
||||||
from .llms.aleph_alpha import AlephAlphaConfig
|
from .llms.aleph_alpha import AlephAlphaConfig
|
||||||
from .llms.petals import PetalsConfig
|
from .llms.petals import PetalsConfig
|
||||||
from .llms.vertex_ai import VertexAIConfig
|
from .llms.vertex_httpx import VertexGeminiConfig
|
||||||
|
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
||||||
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
||||||
from .llms.sagemaker import SagemakerConfig
|
from .llms.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
|
@ -784,8 +803,11 @@ from .llms.openai import (
|
||||||
OpenAIConfig,
|
OpenAIConfig,
|
||||||
OpenAITextCompletionConfig,
|
OpenAITextCompletionConfig,
|
||||||
MistralConfig,
|
MistralConfig,
|
||||||
|
MistralEmbeddingConfig,
|
||||||
DeepInfraConfig,
|
DeepInfraConfig,
|
||||||
|
AzureAIStudioConfig,
|
||||||
)
|
)
|
||||||
|
from .llms.text_completion_codestral import MistralTextCompletionConfig
|
||||||
from .llms.azure import (
|
from .llms.azure import (
|
||||||
AzureOpenAIConfig,
|
AzureOpenAIConfig,
|
||||||
AzureOpenAIError,
|
AzureOpenAIError,
|
||||||
|
@ -819,4 +841,4 @@ from .router import Router
|
||||||
from .assistants.main import *
|
from .assistants.main import *
|
||||||
from .batches.main import *
|
from .batches.main import *
|
||||||
from .scheduler import *
|
from .scheduler import *
|
||||||
from .cost_calculator import response_cost_calculator
|
from .cost_calculator import response_cost_calculator, cost_per_token
|
||||||
|
|
|
@ -1,12 +1,21 @@
|
||||||
import logging, os, json
|
import json
|
||||||
from logging import Formatter
|
import logging
|
||||||
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
from logging import Formatter
|
||||||
|
|
||||||
set_verbose = False
|
set_verbose = False
|
||||||
|
|
||||||
|
if set_verbose is True:
|
||||||
|
logging.warning(
|
||||||
|
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
|
||||||
|
)
|
||||||
json_logs = bool(os.getenv("JSON_LOGS", False))
|
json_logs = bool(os.getenv("JSON_LOGS", False))
|
||||||
# Create a handler for the logger (you may need to adapt this based on your needs)
|
# Create a handler for the logger (you may need to adapt this based on your needs)
|
||||||
|
log_level = os.getenv("LITELLM_LOG", "DEBUG")
|
||||||
|
numeric_level: str = getattr(logging, log_level.upper())
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setLevel(logging.DEBUG)
|
handler.setLevel(numeric_level)
|
||||||
|
|
||||||
|
|
||||||
class JsonFormatter(Formatter):
|
class JsonFormatter(Formatter):
|
||||||
|
@ -14,8 +23,12 @@ class JsonFormatter(Formatter):
|
||||||
super(JsonFormatter, self).__init__()
|
super(JsonFormatter, self).__init__()
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
json_record = {}
|
json_record = {
|
||||||
json_record["message"] = record.getMessage()
|
"message": record.getMessage(),
|
||||||
|
"level": record.levelname,
|
||||||
|
"timestamp": self.formatTime(record, self.datefmt),
|
||||||
|
}
|
||||||
|
|
||||||
return json.dumps(json_record)
|
return json.dumps(json_record)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1192,7 +1192,7 @@ class S3Cache(BaseCache):
|
||||||
return cached_response
|
return cached_response
|
||||||
except botocore.exceptions.ClientError as e:
|
except botocore.exceptions.ClientError as e:
|
||||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||||
verbose_logger.error(
|
verbose_logger.debug(
|
||||||
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -1,21 +1,252 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## File for 'response_cost' calculation in Logging
|
## File for 'response_cost' calculation in Logging
|
||||||
from typing import Optional, Union, Literal, List
|
import time
|
||||||
|
from typing import List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
import litellm._logging
|
import litellm._logging
|
||||||
|
from litellm import verbose_logger
|
||||||
|
from litellm.litellm_core_utils.llm_cost_calc.google import (
|
||||||
|
cost_per_token as google_cost_per_token,
|
||||||
|
)
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
ModelResponse,
|
CallTypes,
|
||||||
|
CostPerToken,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
TranscriptionResponse,
|
ModelResponse,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
CallTypes,
|
TranscriptionResponse,
|
||||||
cost_per_token,
|
|
||||||
print_verbose,
|
print_verbose,
|
||||||
CostPerToken,
|
|
||||||
token_counter,
|
token_counter,
|
||||||
)
|
)
|
||||||
import litellm
|
|
||||||
from litellm import verbose_logger
|
|
||||||
|
def _cost_per_token_custom_pricing_helper(
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=0,
|
||||||
|
response_time_ms=None,
|
||||||
|
### CUSTOM PRICING ###
|
||||||
|
custom_cost_per_token: Optional[CostPerToken] = None,
|
||||||
|
custom_cost_per_second: Optional[float] = None,
|
||||||
|
) -> Optional[Tuple[float, float]]:
|
||||||
|
"""Internal helper function for calculating cost, if custom pricing given"""
|
||||||
|
if custom_cost_per_token is None and custom_cost_per_second is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if custom_cost_per_token is not None:
|
||||||
|
input_cost = custom_cost_per_token["input_cost_per_token"] * prompt_tokens
|
||||||
|
output_cost = custom_cost_per_token["output_cost_per_token"] * completion_tokens
|
||||||
|
return input_cost, output_cost
|
||||||
|
elif custom_cost_per_second is not None:
|
||||||
|
output_cost = custom_cost_per_second * response_time_ms / 1000 # type: ignore
|
||||||
|
return 0, output_cost
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def cost_per_token(
|
||||||
|
model: str = "",
|
||||||
|
prompt_tokens: float = 0,
|
||||||
|
completion_tokens: float = 0,
|
||||||
|
response_time_ms=None,
|
||||||
|
custom_llm_provider: Optional[str] = None,
|
||||||
|
region_name=None,
|
||||||
|
### CUSTOM PRICING ###
|
||||||
|
custom_cost_per_token: Optional[CostPerToken] = None,
|
||||||
|
custom_cost_per_second: Optional[float] = None,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model (str): The name of the model to use. Default is ""
|
||||||
|
prompt_tokens (int): The number of tokens in the prompt.
|
||||||
|
completion_tokens (int): The number of tokens in the completion.
|
||||||
|
response_time (float): The amount of time, in milliseconds, it took the call to complete.
|
||||||
|
custom_llm_provider (str): The llm provider to whom the call was made (see init.py for full list)
|
||||||
|
custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
|
||||||
|
custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
|
||||||
|
"""
|
||||||
|
args = locals()
|
||||||
|
if model is None:
|
||||||
|
raise Exception("Invalid arg. Model cannot be none.")
|
||||||
|
## CUSTOM PRICING ##
|
||||||
|
response_cost = _cost_per_token_custom_pricing_helper(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
response_time_ms=response_time_ms,
|
||||||
|
custom_cost_per_second=custom_cost_per_second,
|
||||||
|
custom_cost_per_token=custom_cost_per_token,
|
||||||
|
)
|
||||||
|
if response_cost is not None:
|
||||||
|
return response_cost[0], response_cost[1]
|
||||||
|
|
||||||
|
# given
|
||||||
|
prompt_tokens_cost_usd_dollar: float = 0
|
||||||
|
completion_tokens_cost_usd_dollar: float = 0
|
||||||
|
model_cost_ref = litellm.model_cost
|
||||||
|
model_with_provider = model
|
||||||
|
if custom_llm_provider is not None:
|
||||||
|
model_with_provider = custom_llm_provider + "/" + model
|
||||||
|
if region_name is not None:
|
||||||
|
model_with_provider_and_region = (
|
||||||
|
f"{custom_llm_provider}/{region_name}/{model}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
model_with_provider_and_region in model_cost_ref
|
||||||
|
): # use region based pricing, if it's available
|
||||||
|
model_with_provider = model_with_provider_and_region
|
||||||
|
else:
|
||||||
|
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
|
||||||
|
model_without_prefix = model
|
||||||
|
model_parts = model.split("/")
|
||||||
|
if len(model_parts) > 1:
|
||||||
|
model_without_prefix = model_parts[1]
|
||||||
|
else:
|
||||||
|
model_without_prefix = model
|
||||||
|
"""
|
||||||
|
Code block that formats model to lookup in litellm.model_cost
|
||||||
|
Option1. model = "bedrock/ap-northeast-1/anthropic.claude-instant-v1". This is the most accurate since it is region based. Should always be option 1
|
||||||
|
Option2. model = "openai/gpt-4" - model = provider/model
|
||||||
|
Option3. model = "anthropic.claude-3" - model = model
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
model_with_provider in model_cost_ref
|
||||||
|
): # Option 2. use model with provider, model = "openai/gpt-4"
|
||||||
|
model = model_with_provider
|
||||||
|
elif model in model_cost_ref: # Option 1. use model passed, model="gpt-4"
|
||||||
|
model = model
|
||||||
|
elif (
|
||||||
|
model_without_prefix in model_cost_ref
|
||||||
|
): # Option 3. if user passed model="bedrock/anthropic.claude-3", use model="anthropic.claude-3"
|
||||||
|
model = model_without_prefix
|
||||||
|
|
||||||
|
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
|
||||||
|
print_verbose(f"Looking up model={model} in model_cost_map")
|
||||||
|
if custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
|
||||||
|
return google_cost_per_token(
|
||||||
|
model=model_without_prefix,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
elif model in model_cost_ref:
|
||||||
|
print_verbose(f"Success: model={model} in model_cost_map")
|
||||||
|
print_verbose(
|
||||||
|
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
model_cost_ref[model].get("input_cost_per_token", None) is not None
|
||||||
|
and model_cost_ref[model].get("output_cost_per_token", None) is not None
|
||||||
|
):
|
||||||
|
## COST PER TOKEN ##
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
model_cost_ref[model].get("output_cost_per_second", None) is not None
|
||||||
|
and response_time_ms is not None
|
||||||
|
):
|
||||||
|
print_verbose(
|
||||||
|
f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}"
|
||||||
|
)
|
||||||
|
## COST PER SECOND ##
|
||||||
|
prompt_tokens_cost_usd_dollar = 0
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["output_cost_per_second"]
|
||||||
|
* response_time_ms
|
||||||
|
/ 1000
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
model_cost_ref[model].get("input_cost_per_second", None) is not None
|
||||||
|
and response_time_ms is not None
|
||||||
|
):
|
||||||
|
print_verbose(
|
||||||
|
f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
|
||||||
|
)
|
||||||
|
## COST PER SECOND ##
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = 0.0
|
||||||
|
print_verbose(
|
||||||
|
f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
elif "ft:gpt-3.5-turbo" in model:
|
||||||
|
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
|
||||||
|
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"]
|
||||||
|
* completion_tokens
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
elif "ft:davinci-002" in model:
|
||||||
|
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
|
||||||
|
# fuzzy match ft:davinci-002:abcd-id-cool-litellm
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:davinci-002"]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:davinci-002"]["output_cost_per_token"]
|
||||||
|
* completion_tokens
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
elif "ft:babbage-002" in model:
|
||||||
|
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
|
||||||
|
# fuzzy match ft:babbage-002:abcd-id-cool-litellm
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:babbage-002"]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref["ft:babbage-002"]["output_cost_per_token"]
|
||||||
|
* completion_tokens
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
elif model in litellm.azure_llms:
|
||||||
|
verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
|
||||||
|
model = litellm.azure_llms[model]
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
|
||||||
|
)
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
elif model in litellm.azure_embedding_models:
|
||||||
|
verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
|
||||||
|
model = litellm.azure_embedding_models[model]
|
||||||
|
prompt_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||||
|
)
|
||||||
|
completion_tokens_cost_usd_dollar = (
|
||||||
|
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||||
|
)
|
||||||
|
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||||
|
else:
|
||||||
|
# if model is not in model_prices_and_context_window.json. Raise an exception-let users know
|
||||||
|
error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}. Register pricing for model - https://docs.litellm.ai/docs/proxy/custom_pricing\n"
|
||||||
|
raise litellm.exceptions.NotFoundError( # type: ignore
|
||||||
|
message=error_str,
|
||||||
|
model=model,
|
||||||
|
llm_provider="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Extract the number of billion parameters from the model name
|
# Extract the number of billion parameters from the model name
|
||||||
|
@ -337,8 +568,6 @@ def response_cost_calculator(
|
||||||
and custom_llm_provider is True
|
and custom_llm_provider is True
|
||||||
): # override defaults if custom pricing is set
|
): # override defaults if custom pricing is set
|
||||||
base_model = model
|
base_model = model
|
||||||
elif base_model is None:
|
|
||||||
base_model = model
|
|
||||||
# base_model defaults to None if not set on model_info
|
# base_model defaults to None if not set on model_info
|
||||||
response_cost = completion_cost(
|
response_cost = completion_cost(
|
||||||
completion_response=response_object,
|
completion_response=response_object,
|
||||||
|
|
|
@@ -26,7 +26,7 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 401
-        self.message = message
+        self.message = "litellm.AuthenticationError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -72,7 +72,7 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 404
-        self.message = message
+        self.message = "litellm.NotFoundError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info

@@ -117,7 +117,7 @@ class BadRequestError(openai.BadRequestError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.BadRequestError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info

@@ -162,7 +162,7 @@ class UnprocessableEntityError(openai.UnprocessableEntityError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 422
-        self.message = message
+        self.message = "litellm.UnprocessableEntityError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info

@@ -204,7 +204,7 @@ class Timeout(openai.APITimeoutError):  # type: ignore
             request=request
         )  # Call the base class constructor with the parameters it needs
         self.status_code = 408
-        self.message = message
+        self.message = "litellm.Timeout: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info

@@ -241,7 +241,7 @@ class PermissionDeniedError(openai.PermissionDeniedError):  # type:ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 403
-        self.message = message
+        self.message = "litellm.PermissionDeniedError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -280,7 +280,7 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 429
-        self.message = message
+        self.message = "litellm.RateLimitError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -324,19 +324,22 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.ContextWindowExceededError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=400, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
+            litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):

@@ -367,7 +370,7 @@ class RejectedRequestError(BadRequestError):  # type: ignore
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.RejectedRequestError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info

@@ -379,6 +382,7 @@ class RejectedRequestError(BadRequestError):  # type: ignore
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
             response=response,
+            litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):

@@ -405,19 +409,22 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.ContentPolicyViolationError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        self.response = response or httpx.Response(status_code=500, request=request)
         super().__init__(
             message=self.message,
             model=self.model,  # type: ignore
             llm_provider=self.llm_provider,  # type: ignore
-            response=response,
+            response=self.response,
+            litellm_debug_info=self.litellm_debug_info,
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):

@@ -449,7 +456,7 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 503
-        self.message = message
+        self.message = "litellm.ServiceUnavailableError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -498,7 +505,7 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 500
-        self.message = message
+        self.message = "litellm.InternalServerError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -549,7 +556,7 @@ class APIError(openai.APIError):  # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = status_code
-        self.message = message
+        self.message = "litellm.APIError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info

@@ -586,7 +593,7 @@ class APIConnectionError(openai.APIConnectionError):  # type: ignore
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
     ):
-        self.message = message
+        self.message = "litellm.APIConnectionError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.status_code = 500

@@ -623,7 +630,7 @@ class APIResponseValidationError(openai.APIResponseValidationError):  # type: ig
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
     ):
-        self.message = message
+        self.message = "litellm.APIResponseValidationError: {}".format(message)
        self.llm_provider = llm_provider
        self.model = model
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
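Taken together, these hunks prefix every exception's `.message` with its litellm class name and make the `response` argument optional on `ContextWindowExceededError` and `ContentPolicyViolationError`, synthesizing a placeholder `httpx.Response` when none is supplied. A minimal sketch of what that enables, constructing the exception directly rather than via a real provider call (illustrative only, assuming the class shapes shown above):

    import httpx
    import litellm

    try:
        raise litellm.ContextWindowExceededError(
            message="prompt is 9000 tokens, limit is 8192",
            model="gpt-4",
            llm_provider="openai",
            # no `response=` needed anymore; a placeholder httpx.Response is built internally
        )
    except litellm.ContextWindowExceededError as e:
        print(e.status_code)   # 400
        print(e.message)       # "litellm.ContextWindowExceededError: prompt is 9000 tokens, ..."
        assert isinstance(e.response, httpx.Response)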
@@ -226,14 +226,6 @@ def _start_clickhouse():
     response = client.query("DESCRIBE default.spend_logs")
     verbose_logger.debug(f"spend logs schema ={response.result_rows}")

-    # RUN Enterprise Clickhouse Setup
-    # TLDR: For Enterprise - we create views / aggregate tables for low latency reporting APIs
-    from litellm.proxy.enterprise.utils import _create_clickhouse_aggregate_tables
-    from litellm.proxy.enterprise.utils import _create_clickhouse_material_views
-
-    _create_clickhouse_aggregate_tables(client=client, table_names=table_names)
-    _create_clickhouse_material_views(client=client, table_names=table_names)
-

 class ClickhouseLogger:
     # Class variables or attributes
@@ -10,7 +10,7 @@ import traceback


 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
     # Class variables or attributes
-    def __init__(self):
+    def __init__(self) -> None:
         pass

     def log_pre_api_call(self, model, messages, kwargs):
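For context, `CustomLogger` is the base class that user callbacks subclass (see the docs URL in the hunk); the only change here is the explicit `-> None` return annotation. A minimal sketch of a subclass under that signature (illustrative; the class and hook names come from the linked docs, everything else is made up):

    import litellm
    from litellm.integrations.custom_logger import CustomLogger


    class MyLogger(CustomLogger):
        def __init__(self) -> None:  # matches the annotated signature in the hunk
            super().__init__()

        def log_pre_api_call(self, model, messages, kwargs):
            print(f"about to call {model} with {len(messages)} messages")


    litellm.callbacks = [MyLogger()]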
litellm/integrations/email_alerting.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+"""
+Functions for sending Email Alerts
+"""
+
+import os
+from typing import Optional, List
+from litellm.proxy._types import WebhookEvent
+import asyncio
+from litellm._logging import verbose_logger, verbose_proxy_logger
+
+# we use this for the email header, please send a test email if you change this. verify it looks good on email
+LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
+LITELLM_SUPPORT_CONTACT = "support@berri.ai"
+
+
+async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
+    verbose_logger.debug(
+        "Email Alerting: Getting all team members for team_id=%s", team_id
+    )
+    if team_id is None:
+        return []
+    from litellm.proxy.proxy_server import premium_user, prisma_client
+
+    if prisma_client is None:
+        raise Exception("Not connected to DB!")
+
+    team_row = await prisma_client.db.litellm_teamtable.find_unique(
+        where={
+            "team_id": team_id,
+        }
+    )
+
+    if team_row is None:
+        return []
+
+    _team_members = team_row.members_with_roles
+    verbose_logger.debug(
+        "Email Alerting: Got team members for team_id=%s Team Members: %s",
+        team_id,
+        _team_members,
+    )
+    _team_member_user_ids: List[str] = []
+    for member in _team_members:
+        if member and isinstance(member, dict) and member.get("user_id") is not None:
+            _team_member_user_ids.append(member.get("user_id"))
+
+    sql_query = """
+    SELECT user_email
+    FROM "LiteLLM_UserTable"
+    WHERE user_id = ANY($1::TEXT[]);
+    """
+
+    _result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids)
+
+    verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result)
+
+    if _result is None:
+        return []
+
+    emails = []
+    for user in _result:
+        if user and isinstance(user, dict) and user.get("user_email", None) is not None:
+            emails.append(user.get("user_email"))
+    return emails
+
+
+async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
+    """
+    Send an Email Alert to All Team Members when the Team Budget is crossed
+    Returns -> True if sent, False if not.
+    """
+    from litellm.proxy.utils import send_email
+
+    from litellm.proxy.proxy_server import premium_user, prisma_client
+
+    _team_id = webhook_event.team_id
+    team_alias = webhook_event.team_alias
+    verbose_logger.debug(
+        "Email Alerting: Sending Team Budget Alert for team=%s", team_alias
+    )
+
+    email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
+    email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
+
+    # await self._check_if_using_premium_email_feature(
+    #     premium_user, email_logo_url, email_support_contact
+    # )
+
+    if email_logo_url is None:
+        email_logo_url = LITELLM_LOGO_URL
+    if email_support_contact is None:
+        email_support_contact = LITELLM_SUPPORT_CONTACT
+    recipient_emails = await get_all_team_member_emails(_team_id)
+    recipient_emails_str: str = ",".join(recipient_emails)
+    verbose_logger.debug(
+        "Email Alerting: Sending team budget alert to %s", recipient_emails_str
+    )
+
+    event_name = webhook_event.event_message
+    max_budget = webhook_event.max_budget
+    email_html_content = "Alert from LiteLLM Server"
+
+    if recipient_emails_str is None:
+        verbose_proxy_logger.error(
+            "Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
+            recipient_emails_str,
+        )
+
+    email_html_content = f"""
+    <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>
+
+    Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>
+
+    Your Teams LLM API usage has crossed it's <b> budget of ${max_budget} </b>, current spend is <b>${webhook_event.spend}</b><br /> <br />
+
+    API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />
+
+    If you have any questions, please send an email to {email_support_contact} <br /> <br />
+
+    Best, <br />
+    The LiteLLM team <br />
+    """
+
+    email_event = {
+        "to": recipient_emails_str,
+        "subject": f"LiteLLM {event_name} for Team {team_alias}",
+        "html": email_html_content,
+    }
+
+    await send_email(
+        receiver_email=email_event["to"],
+        subject=email_event["subject"],
+        html=email_event["html"],
+    )
+
+    return False
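One detail worth noting in the new helper is the recipient query: the collected user IDs are bound as a single Postgres array parameter (`ANY($1::TEXT[])`) rather than interpolated into the SQL string. A standalone sketch of the same pattern, assuming an asyncpg connection (the proxy itself goes through Prisma's `query_raw`; DSN and table contents here are illustrative):

    import asyncio
    from typing import List

    import asyncpg


    async def fetch_emails(dsn: str, user_ids: List[str]) -> List[str]:
        conn = await asyncpg.connect(dsn)
        try:
            rows = await conn.fetch(
                'SELECT user_email FROM "LiteLLM_UserTable" WHERE user_id = ANY($1::TEXT[]);',
                user_ids,  # bound as one array parameter, so no SQL string building
            )
            return [r["user_email"] for r in rows if r["user_email"] is not None]
        finally:
            await conn.close()


    # asyncio.run(fetch_emails("postgresql://user:pass@localhost/litellm", ["user-1", "user-2"]))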
@@ -1,13 +1,19 @@
 # What is this?
 ## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639

-import dotenv, os, json
+import json
+import os
+import traceback
+import uuid
+from typing import Literal, Optional
+
+import dotenv
+import httpx
+
 import litellm
-import traceback, httpx
+from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-import uuid
-from typing import Optional, Literal


 def get_utc_datetime():

@@ -143,6 +149,7 @@ class LagoLogger(CustomLogger):

     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            verbose_logger.debug("ENTERS LAGO CALLBACK")
             _url = os.getenv("LAGO_API_BASE")
             assert _url is not None and isinstance(
                 _url, str
@@ -1,11 +1,13 @@
 #### What this does ####
 # On success, logs events to Langfuse
-import os
 import copy
+import os
 import traceback
+
 from packaging.version import Version
-from litellm._logging import verbose_logger
 import litellm
+from litellm._logging import verbose_logger


 class LangFuseLogger:

@@ -14,8 +16,8 @@ class LangFuseLogger:
         self, langfuse_public_key=None, langfuse_secret=None, flush_interval=1
     ):
         try:
-            from langfuse import Langfuse
             import langfuse
+            from langfuse import Langfuse
         except Exception as e:
             raise Exception(
                 f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"

@@ -167,7 +169,7 @@ class LangFuseLogger:
             or isinstance(response_obj, litellm.EmbeddingResponse)
         ):
             input = prompt
-            output = response_obj["data"]
+            output = None
         elif response_obj is not None and isinstance(
             response_obj, litellm.ModelResponse
         ):

@@ -251,7 +253,7 @@ class LangFuseLogger:
         input,
         response_obj,
     ):
-        from langfuse.model import CreateTrace, CreateGeneration
+        from langfuse.model import CreateGeneration, CreateTrace

         verbose_logger.warning(
             "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"

@@ -528,31 +530,14 @@ class LangFuseLogger:
                 "version": clean_metadata.pop("version", None),
             }

+            parent_observation_id = metadata.get("parent_observation_id", None)
+            if parent_observation_id is not None:
+                generation_params["parent_observation_id"] = parent_observation_id
+
             if supports_prompt:
-                user_prompt = clean_metadata.pop("prompt", None)
-                if user_prompt is None:
-                    pass
-                elif isinstance(user_prompt, dict):
-                    from langfuse.model import (
-                        TextPromptClient,
-                        ChatPromptClient,
-                        Prompt_Text,
-                        Prompt_Chat,
-                    )
-
-                    if user_prompt.get("type", "") == "chat":
-                        _prompt_chat = Prompt_Chat(**user_prompt)
-                        generation_params["prompt"] = ChatPromptClient(
-                            prompt=_prompt_chat
-                        )
-                    elif user_prompt.get("type", "") == "text":
-                        _prompt_text = Prompt_Text(**user_prompt)
-                        generation_params["prompt"] = TextPromptClient(
-                            prompt=_prompt_text
-                        )
-                else:
-                    generation_params["prompt"] = user_prompt
+                generation_params = _add_prompt_to_generation_params(
+                    generation_params=generation_params, clean_metadata=clean_metadata
+                )

             if output is not None and isinstance(output, str) and level == "ERROR":
                 generation_params["status_message"] = output

@@ -565,5 +550,58 @@ class LangFuseLogger:

             return generation_client.trace_id, generation_id
         except Exception as e:
-            verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
+            verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
             return None, None


+def _add_prompt_to_generation_params(
+    generation_params: dict, clean_metadata: dict
+) -> dict:
+    from langfuse.model import (
+        ChatPromptClient,
+        Prompt_Chat,
+        Prompt_Text,
+        TextPromptClient,
+    )
+
+    user_prompt = clean_metadata.pop("prompt", None)
+    if user_prompt is None:
+        pass
+    elif isinstance(user_prompt, dict):
+        if user_prompt.get("type", "") == "chat":
+            _prompt_chat = Prompt_Chat(**user_prompt)
+            generation_params["prompt"] = ChatPromptClient(prompt=_prompt_chat)
+        elif user_prompt.get("type", "") == "text":
+            _prompt_text = Prompt_Text(**user_prompt)
+            generation_params["prompt"] = TextPromptClient(prompt=_prompt_text)
+        elif "version" in user_prompt and "prompt" in user_prompt:
+            # prompts
+            if isinstance(user_prompt["prompt"], str):
+                _prompt_obj = Prompt_Text(
+                    name=user_prompt["name"],
+                    prompt=user_prompt["prompt"],
+                    version=user_prompt["version"],
+                    config=user_prompt.get("config", None),
+                )
+                generation_params["prompt"] = TextPromptClient(prompt=_prompt_obj)
+
+            elif isinstance(user_prompt["prompt"], list):
+                _prompt_obj = Prompt_Chat(
+                    name=user_prompt["name"],
+                    prompt=user_prompt["prompt"],
+                    version=user_prompt["version"],
+                    config=user_prompt.get("config", None),
+                )
+                generation_params["prompt"] = ChatPromptClient(prompt=_prompt_obj)
+            else:
+                verbose_logger.error(
+                    "[Non-blocking] Langfuse Logger: Invalid prompt format"
+                )
+        else:
+            verbose_logger.error(
+                "[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
+            )
+    else:
+        generation_params["prompt"] = user_prompt

+    return generation_params
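The Langfuse change factors the prompt-handling branches into the new module-level `_add_prompt_to_generation_params` helper, which now also accepts plain dicts carrying `name`/`prompt`/`version`/`config`. A minimal sketch of calling it directly (illustrative only; assumes the langfuse SDK v2+ is installed and that the helper lives in `litellm.integrations.langfuse`, inferred from the file layout):

    from litellm.integrations.langfuse import _add_prompt_to_generation_params

    generation_params = {"name": "litellm-completion"}
    clean_metadata = {
        "prompt": {  # dict form handled by the new "version"/"prompt" branch
            "name": "greeting-prompt",
            "prompt": "Say hello to {{user}}",
            "version": 1,
            "config": {},
        }
    }

    generation_params = _add_prompt_to_generation_params(
        generation_params=generation_params, clean_metadata=clean_metadata
    )
    # generation_params["prompt"] should now be a langfuse TextPromptClient wrapping the dict above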
@@ -105,7 +105,6 @@ class LunaryLogger:
         end_time=datetime.now(timezone.utc),
         error=None,
     ):
-        # Method definition
         try:
             print_verbose(f"Lunary Logging - Logging request for model {model}")

@@ -114,10 +113,9 @@ class LunaryLogger:
             metadata = litellm_params.get("metadata", {}) or {}

             if optional_params:
-                # merge into extra
                 extra = {**extra, **optional_params}

-            tags = litellm_params.pop("tags", None) or []
+            tags = metadata.get("tags", None)

             if extra:
                 extra.pop("extra_body", None)
@@ -1,22 +1,29 @@
 import os
 from dataclasses import dataclass
 from datetime import datetime
-import litellm
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Optional, Union
+
-from litellm.integrations.custom_logger import CustomLogger
+import litellm
 from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.services import ServiceLoggerPayload
-from typing import Union, Optional, TYPE_CHECKING, Any


 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

+    from litellm.proxy._types import (
+        ManagementEndpointLoggingPayload as _ManagementEndpointLoggingPayload,
+    )
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth

     Span = _Span
     UserAPIKeyAuth = _UserAPIKeyAuth
+    ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
 else:
     Span = Any
     UserAPIKeyAuth = Any
+    ManagementEndpointLoggingPayload = Any


 LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")

@@ -101,8 +108,9 @@ class OpenTelemetry(CustomLogger):
         start_time: Optional[datetime] = None,
         end_time: Optional[datetime] = None,
     ):
-        from opentelemetry import trace
         from datetime import datetime
+
+        from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

         _start_time_ns = start_time

@@ -139,8 +147,9 @@ class OpenTelemetry(CustomLogger):
         start_time: Optional[datetime] = None,
         end_time: Optional[datetime] = None,
     ):
-        from opentelemetry import trace
         from datetime import datetime
+
+        from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

         _start_time_ns = start_time

@@ -173,8 +182,8 @@ class OpenTelemetry(CustomLogger):
     async def async_post_call_failure_hook(
         self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
     ):
-        from opentelemetry.trace import Status, StatusCode
         from opentelemetry import trace
+        from opentelemetry.trace import Status, StatusCode

         parent_otel_span = user_api_key_dict.parent_otel_span
         if parent_otel_span is not None:

@@ -196,8 +205,8 @@ class OpenTelemetry(CustomLogger):
             parent_otel_span.end(end_time=self._to_ns(datetime.now()))

     def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
-        from opentelemetry.trace import Status, StatusCode
         from opentelemetry import trace
+        from opentelemetry.trace import Status, StatusCode

         verbose_logger.debug(
             "OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s",

@@ -247,9 +256,10 @@ class OpenTelemetry(CustomLogger):
         span.end(end_time=self._to_ns(end_time))

     def set_tools_attributes(self, span: Span, tools):
-        from opentelemetry.semconv.ai import SpanAttributes
         import json

+        from litellm.proxy._types import SpanAttributes
+
         if not tools:
             return

@@ -272,7 +282,7 @@ class OpenTelemetry(CustomLogger):
         pass

     def set_attributes(self, span: Span, kwargs, response_obj):
-        from opentelemetry.semconv.ai import SpanAttributes
+        from litellm.proxy._types import SpanAttributes

         optional_params = kwargs.get("optional_params", {})
         litellm_params = kwargs.get("litellm_params", {}) or {}

@@ -314,7 +324,7 @@ class OpenTelemetry(CustomLogger):
         )

         span.set_attribute(
-            SpanAttributes.LLM_IS_STREAMING, optional_params.get("stream", False)
+            SpanAttributes.LLM_IS_STREAMING, str(optional_params.get("stream", False))
         )

         if optional_params.get("tools"):

@@ -407,7 +417,7 @@ class OpenTelemetry(CustomLogger):
         )

     def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
-        from opentelemetry.semconv.ai import SpanAttributes
+        from litellm.proxy._types import SpanAttributes

         optional_params = kwargs.get("optional_params", {})
         litellm_params = kwargs.get("litellm_params", {}) or {}

@@ -433,7 +443,7 @@ class OpenTelemetry(CustomLogger):
         #############################################
         ########## LLM Response Attributes ##########
         #############################################
-        if _raw_response:
+        if _raw_response and isinstance(_raw_response, str):
             # cast sr -> dict
             import json

@@ -454,11 +464,28 @@ class OpenTelemetry(CustomLogger):
     def _get_span_name(self, kwargs):
         return LITELLM_REQUEST_SPAN_NAME

-    def _get_span_context(self, kwargs):
+    def get_traceparent_from_header(self, headers):
+        if headers is None:
+            return None
+        _traceparent = headers.get("traceparent", None)
+        if _traceparent is None:
+            return None
+
         from opentelemetry.trace.propagation.tracecontext import (
             TraceContextTextMapPropagator,
         )

+        verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
+        propagator = TraceContextTextMapPropagator()
+        _parent_context = propagator.extract(carrier={"traceparent": _traceparent})
+        verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
+        return _parent_context
+
+    def _get_span_context(self, kwargs):
         from opentelemetry import trace
+        from opentelemetry.trace.propagation.tracecontext import (
+            TraceContextTextMapPropagator,
+        )

         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}

@@ -482,17 +509,17 @@ class OpenTelemetry(CustomLogger):
         return TraceContextTextMapPropagator().extract(carrier=carrier), None

     def _get_span_processor(self):
-        from opentelemetry.sdk.trace.export import (
-            SpanExporter,
-            SimpleSpanProcessor,
-            BatchSpanProcessor,
-            ConsoleSpanExporter,
+        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
+            OTLPSpanExporter as OTLPSpanExporterGRPC,
         )
         from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
             OTLPSpanExporter as OTLPSpanExporterHTTP,
         )
-        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
-            OTLPSpanExporter as OTLPSpanExporterGRPC,
+        from opentelemetry.sdk.trace.export import (
+            BatchSpanProcessor,
+            ConsoleSpanExporter,
+            SimpleSpanProcessor,
+            SpanExporter,
         )

         verbose_logger.debug(

@@ -545,3 +572,93 @@ class OpenTelemetry(CustomLogger):
             self.OTEL_EXPORTER,
         )
         return BatchSpanProcessor(ConsoleSpanExporter())
+
+    async def async_management_endpoint_success_hook(
+        self,
+        logging_payload: ManagementEndpointLoggingPayload,
+        parent_otel_span: Optional[Span] = None,
+    ):
+        from datetime import datetime
+
+        from opentelemetry import trace
+        from opentelemetry.trace import Status, StatusCode
+
+        _start_time_ns = logging_payload.start_time
+        _end_time_ns = logging_payload.end_time
+
+        start_time = logging_payload.start_time
+        end_time = logging_payload.end_time
+
+        if isinstance(start_time, float):
+            _start_time_ns = int(int(start_time) * 1e9)
+        else:
+            _start_time_ns = self._to_ns(start_time)
+
+        if isinstance(end_time, float):
+            _end_time_ns = int(int(end_time) * 1e9)
+        else:
+            _end_time_ns = self._to_ns(end_time)
+
+        if parent_otel_span is not None:
+            _span_name = logging_payload.route
+            management_endpoint_span = self.tracer.start_span(
+                name=_span_name,
+                context=trace.set_span_in_context(parent_otel_span),
+                start_time=_start_time_ns,
+            )
+
+            _request_data = logging_payload.request_data
+            if _request_data is not None:
+                for key, value in _request_data.items():
+                    management_endpoint_span.set_attribute(f"request.{key}", value)
+
+            _response = logging_payload.response
+            if _response is not None:
+                for key, value in _response.items():
+                    management_endpoint_span.set_attribute(f"response.{key}", value)
+            management_endpoint_span.set_status(Status(StatusCode.OK))
+            management_endpoint_span.end(end_time=_end_time_ns)
+
+    async def async_management_endpoint_failure_hook(
+        self,
+        logging_payload: ManagementEndpointLoggingPayload,
+        parent_otel_span: Optional[Span] = None,
+    ):
+        from datetime import datetime
+
+        from opentelemetry import trace
+        from opentelemetry.trace import Status, StatusCode
+
+        _start_time_ns = logging_payload.start_time
+        _end_time_ns = logging_payload.end_time
+
+        start_time = logging_payload.start_time
+        end_time = logging_payload.end_time
+
+        if isinstance(start_time, float):
+            _start_time_ns = int(int(start_time) * 1e9)
+        else:
+            _start_time_ns = self._to_ns(start_time)
+
+        if isinstance(end_time, float):
+            _end_time_ns = int(int(end_time) * 1e9)
+        else:
+            _end_time_ns = self._to_ns(end_time)
+
+        if parent_otel_span is not None:
+            _span_name = logging_payload.route
+            management_endpoint_span = self.tracer.start_span(
+                name=_span_name,
+                context=trace.set_span_in_context(parent_otel_span),
+                start_time=_start_time_ns,
+            )
+
+            _request_data = logging_payload.request_data
+            if _request_data is not None:
+                for key, value in _request_data.items():
+                    management_endpoint_span.set_attribute(f"request.{key}", value)
+
+            _exception = logging_payload.exception
+            management_endpoint_span.set_attribute(f"exception", str(_exception))
+            management_endpoint_span.set_status(Status(StatusCode.ERROR))
+            management_endpoint_span.end(end_time=_end_time_ns)
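The new `get_traceparent_from_header` method extracts a parent context from an incoming W3C `traceparent` header via OpenTelemetry's `TraceContextTextMapPropagator`. A standalone sketch of that same mechanism outside litellm (assumes opentelemetry-api and opentelemetry-sdk are installed; the header value is a made-up but well-formed traceparent):

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer("demo")

    incoming_headers = {
        "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
    }

    # same call pattern as the new helper: extract() returns a Context carrying the remote span
    parent_ctx = TraceContextTextMapPropagator().extract(
        carrier={"traceparent": incoming_headers["traceparent"]}
    )

    with tracer.start_as_current_span("litellm-request", context=parent_ctx) as span:
        # this span is now parented to the upstream trace identified by `traceparent`
        span.set_attribute("demo.note", "parented via traceparent header")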
@@ -8,6 +8,7 @@ import traceback
 import datetime, subprocess, sys
 import litellm, uuid
 from litellm._logging import print_verbose, verbose_logger
+from typing import Optional, Union


 class PrometheusLogger:

@@ -17,33 +18,76 @@ class PrometheusLogger:
         **kwargs,
     ):
         try:
-            from prometheus_client import Counter
+            from prometheus_client import Counter, Gauge

             self.litellm_llm_api_failed_requests_metric = Counter(
                 name="litellm_llm_api_failed_requests_metric",
                 documentation="Total number of failed LLM API calls via litellm",
-                labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
+                labelnames=[
+                    "end_user",
+                    "hashed_api_key",
+                    "model",
+                    "team",
+                    "team_alias",
+                    "user",
+                ],
             )

             self.litellm_requests_metric = Counter(
                 name="litellm_requests_metric",
                 documentation="Total number of LLM calls to litellm",
-                labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
+                labelnames=[
+                    "end_user",
+                    "hashed_api_key",
+                    "model",
+                    "team",
+                    "team_alias",
+                    "user",
+                ],
             )

             # Counter for spend
             self.litellm_spend_metric = Counter(
                 "litellm_spend_metric",
                 "Total spend on LLM requests",
-                labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
+                labelnames=[
+                    "end_user",
+                    "hashed_api_key",
+                    "model",
+                    "team",
+                    "team_alias",
+                    "user",
+                ],
             )

             # Counter for total_output_tokens
             self.litellm_tokens_metric = Counter(
                 "litellm_total_tokens",
                 "Total number of input + output tokens from LLM requests",
-                labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
+                labelnames=[
+                    "end_user",
+                    "hashed_api_key",
+                    "model",
+                    "team",
+                    "team_alias",
+                    "user",
+                ],
             )

+            # Remaining Budget for Team
+            self.litellm_remaining_team_budget_metric = Gauge(
+                "litellm_remaining_team_budget_metric",
+                "Remaining budget for team",
+                labelnames=["team_id", "team_alias"],
+            )
+
+            # Remaining Budget for API Key
+            self.litellm_remaining_api_key_budget_metric = Gauge(
+                "litellm_remaining_api_key_budget_metric",
+                "Remaining budget for api key",
+                labelnames=["hashed_api_key", "api_key_alias"],
+            )
+
         except Exception as e:
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e

@@ -51,7 +95,9 @@ class PrometheusLogger:
     async def _async_log_event(
         self, kwargs, response_obj, start_time, end_time, print_verbose, user_id
     ):
-        self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
+        self.log_event(
+            kwargs, response_obj, start_time, end_time, user_id, print_verbose
+        )

     def log_event(
         self, kwargs, response_obj, start_time, end_time, user_id, print_verbose

@@ -72,9 +118,36 @@ class PrometheusLogger:
                 "user_api_key_user_id", None
             )
             user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
+            user_api_key_alias = litellm_params.get("metadata", {}).get(
+                "user_api_key_alias", None
+            )
             user_api_team = litellm_params.get("metadata", {}).get(
                 "user_api_key_team_id", None
             )
+            user_api_team_alias = litellm_params.get("metadata", {}).get(
+                "user_api_key_team_alias", None
+            )
+
+            _team_spend = litellm_params.get("metadata", {}).get(
+                "user_api_key_team_spend", None
+            )
+            _team_max_budget = litellm_params.get("metadata", {}).get(
+                "user_api_key_team_max_budget", None
+            )
+            _remaining_team_budget = safe_get_remaining_budget(
+                max_budget=_team_max_budget, spend=_team_spend
+            )
+
+            _api_key_spend = litellm_params.get("metadata", {}).get(
+                "user_api_key_spend", None
+            )
+            _api_key_max_budget = litellm_params.get("metadata", {}).get(
+                "user_api_key_max_budget", None
+            )
+            _remaining_api_key_budget = safe_get_remaining_budget(
+                max_budget=_api_key_max_budget, spend=_api_key_spend
+            )
+
             if response_obj is not None:
                 tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
             else:

@@ -94,19 +167,47 @@ class PrometheusLogger:
                 user_api_key = hash_token(user_api_key)

             self.litellm_requests_metric.labels(
-                end_user_id, user_api_key, model, user_api_team, user_id
+                end_user_id,
+                user_api_key,
+                model,
+                user_api_team,
+                user_api_team_alias,
+                user_id,
             ).inc()
             self.litellm_spend_metric.labels(
-                end_user_id, user_api_key, model, user_api_team, user_id
+                end_user_id,
+                user_api_key,
+                model,
+                user_api_team,
+                user_api_team_alias,
+                user_id,
             ).inc(response_cost)
             self.litellm_tokens_metric.labels(
-                end_user_id, user_api_key, model, user_api_team, user_id
+                end_user_id,
+                user_api_key,
+                model,
+                user_api_team,
+                user_api_team_alias,
+                user_id,
             ).inc(tokens_used)

+            self.litellm_remaining_team_budget_metric.labels(
+                user_api_team, user_api_team_alias
+            ).set(_remaining_team_budget)
+
+            self.litellm_remaining_api_key_budget_metric.labels(
+                user_api_key, user_api_key_alias
+            ).set(_remaining_api_key_budget)
+
             ### FAILURE INCREMENT ###
             if "exception" in kwargs:
                 self.litellm_llm_api_failed_requests_metric.labels(
-                    end_user_id, user_api_key, model, user_api_team, user_id
+                    end_user_id,
+                    user_api_key,
+                    model,
+                    user_api_team,
+                    user_api_team_alias,
+                    user_id,
                 ).inc()
         except Exception as e:
             verbose_logger.error(

@@ -114,3 +215,15 @@ class PrometheusLogger:
             )
             verbose_logger.debug(traceback.format_exc())
             pass
+
+
+def safe_get_remaining_budget(
+    max_budget: Optional[float], spend: Optional[float]
+) -> float:
+    if max_budget is None:
+        return float("inf")
+
+    if spend is None:
+        return max_budget
+
+    return max_budget - spend
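The Prometheus hunks add a `team_alias` label to the existing counters and introduce two Gauges that track remaining team and API-key budgets via the new `safe_get_remaining_budget` helper. A standalone sketch of that Gauge usage (assumes prometheus_client is installed; the metric name and label values here are illustrative, and the helper is re-defined locally to mirror the one added above):

    from typing import Optional

    from prometheus_client import Gauge


    def safe_get_remaining_budget(
        max_budget: Optional[float], spend: Optional[float]
    ) -> float:
        # mirrors the helper above: unlimited budget -> +inf, no spend yet -> full budget
        if max_budget is None:
            return float("inf")
        if spend is None:
            return max_budget
        return max_budget - spend


    remaining_team_budget = Gauge(
        "demo_remaining_team_budget_metric",
        "Remaining budget for team",
        labelnames=["team_id", "team_alias"],
    )

    # exported as demo_remaining_team_budget_metric{team_id="team-123",team_alias="prod-team"} 57.5
    remaining_team_budget.labels("team-123", "prod-team").set(
        safe_get_remaining_budget(max_budget=100.0, spend=42.5)
    )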
@@ -330,6 +330,7 @@ class SlackAlerting(CustomLogger):
             messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
         slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
+        alerting_metadata: dict = {}
         if time_difference_float > self.alerting_threshold:
             # add deployment latencies to alert
             if (

@@ -337,7 +338,7 @@ class SlackAlerting(CustomLogger):
                 and "litellm_params" in kwargs
                 and "metadata" in kwargs["litellm_params"]
             ):
-                _metadata = kwargs["litellm_params"]["metadata"]
+                _metadata: dict = kwargs["litellm_params"]["metadata"]
                 request_info = litellm.utils._add_key_name_and_team_to_alert(
                     request_info=request_info, metadata=_metadata
                 )

@@ -349,10 +350,14 @@ class SlackAlerting(CustomLogger):
                 request_info += (
                     f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
                 )
+
+            if "alerting_metadata" in _metadata:
+                alerting_metadata = _metadata["alerting_metadata"]
             await self.send_alert(
                 message=slow_message + request_info,
                 level="Low",
                 alert_type="llm_too_slow",
+                alerting_metadata=alerting_metadata,
             )

     async def async_update_daily_reports(

@@ -540,7 +545,12 @@ class SlackAlerting(CustomLogger):
         message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"

         # send alert
-        await self.send_alert(message=message, level="Low", alert_type="daily_reports")
+        await self.send_alert(
+            message=message,
+            level="Low",
+            alert_type="daily_reports",
+            alerting_metadata={},
+        )

         return True

@@ -582,6 +592,7 @@ class SlackAlerting(CustomLogger):
         await asyncio.sleep(
             self.alerting_threshold
         )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+        alerting_metadata: dict = {}
         if (
             request_data is not None
             and request_data.get("litellm_status", "") != "success"

@@ -606,7 +617,7 @@ class SlackAlerting(CustomLogger):
             ):
                 # In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
                 # in that case we fallback to the api base set in the request metadata
-                _metadata = request_data["metadata"]
+                _metadata: dict = request_data["metadata"]
                 _api_base = _metadata.get("api_base", "")

                 request_info = litellm.utils._add_key_name_and_team_to_alert(

@@ -615,6 +626,9 @@ class SlackAlerting(CustomLogger):

                 if _api_base is None:
                     _api_base = ""
+
+                if "alerting_metadata" in _metadata:
+                    alerting_metadata = _metadata["alerting_metadata"]
                 request_info += f"\nAPI Base: `{_api_base}`"
             # only alert hanging responses if they have not been marked as success
             alerting_message = (

@@ -640,6 +654,7 @@ class SlackAlerting(CustomLogger):
                 message=alerting_message + request_info,
                 level="Medium",
                 alert_type="llm_requests_hanging",
+                alerting_metadata=alerting_metadata,
             )

     async def failed_tracking_alert(self, error_message: str):

@@ -650,7 +665,10 @@ class SlackAlerting(CustomLogger):
         result = await _cache.async_get_cache(key=_cache_key)
         if result is None:
             await self.send_alert(
-                message=message, level="High", alert_type="budget_alerts"
+                message=message,
+                level="High",
+                alert_type="budget_alerts",
+                alerting_metadata={},
             )
             await _cache.async_set_cache(
                 key=_cache_key,

@@ -680,7 +698,7 @@ class SlackAlerting(CustomLogger):
             return
         if "budget_alerts" not in self.alert_types:
             return
-        _id: str = "default_id"  # used for caching
+        _id: Optional[str] = "default_id"  # used for caching
         user_info_json = user_info.model_dump(exclude_none=True)
         for k, v in user_info_json.items():
             user_info_str = "\n{}: {}\n".format(k, v)

@@ -751,6 +769,7 @@ class SlackAlerting(CustomLogger):
                 level="High",
                 alert_type="budget_alerts",
                 user_info=webhook_event,
+                alerting_metadata={},
             )
             await _cache.async_set_cache(
                 key=_cache_key,

@@ -769,7 +788,13 @@ class SlackAlerting(CustomLogger):
         response_cost: Optional[float],
         max_budget: Optional[float],
     ):
-        if end_user_id is not None and token is not None and response_cost is not None:
+        if (
+            self.alerting is not None
+            and "webhook" in self.alerting
+            and end_user_id is not None
+            and token is not None
+            and response_cost is not None
+        ):
             # log customer spend
             event = WebhookEvent(
                 spend=response_cost,

@@ -941,7 +966,10 @@ class SlackAlerting(CustomLogger):
                     )
                     # send minor alert
                     await self.send_alert(
-                        message=msg, level="Medium", alert_type="outage_alerts"
+                        message=msg,
+                        level="Medium",
+                        alert_type="outage_alerts",
+                        alerting_metadata={},
                     )
                     # set to true
                     outage_value["minor_alert_sent"] = True

@@ -963,7 +991,12 @@ class SlackAlerting(CustomLogger):
                     )

                     # send minor alert
-                    await self.send_alert(message=msg, level="High", alert_type="outage_alerts")
+                    await self.send_alert(
+                        message=msg,
+                        level="High",
+                        alert_type="outage_alerts",
+                        alerting_metadata={},
+                    )
                     # set to true
                     outage_value["major_alert_sent"] = True

@@ -1062,7 +1095,10 @@ class SlackAlerting(CustomLogger):
                     )
                     # send minor alert
                     await self.send_alert(
-                        message=msg, level="Medium", alert_type="outage_alerts"
+                        message=msg,
+                        level="Medium",
+                        alert_type="outage_alerts",
+                        alerting_metadata={},
                    )
                    # set to true
                    outage_value["minor_alert_sent"] = True

@@ -1081,7 +1117,10 @@ class SlackAlerting(CustomLogger):
                    )
                    # send minor alert
                    await self.send_alert(
-                        message=msg, level="High", alert_type="outage_alerts"
+                        message=msg,
+                        level="High",
+                        alert_type="outage_alerts",
+                        alerting_metadata={},
                    )
                    # set to true
                    outage_value["major_alert_sent"] = True

@@ -1143,7 +1182,10 @@ Model Info:
 """

         alert_val = self.send_alert(
-            message=message, level="Low", alert_type="new_model_added"
+            message=message,
+            level="Low",
+            alert_type="new_model_added",
+            alerting_metadata={},
         )

         if alert_val is not None and asyncio.iscoroutine(alert_val):

@@ -1159,6 +1201,9 @@ Model Info:
         Currently only implemented for budget alerts

         Returns -> True if sent, False if not.
+
+        Raises Exception
+            - if WEBHOOK_URL is not set
         """

         webhook_url = os.getenv("WEBHOOK_URL", None)

@@ -1297,7 +1342,9 @@ Model Info:
             verbose_proxy_logger.error("Error sending email alert %s", str(e))
             return False

-    async def send_email_alert_using_smtp(self, webhook_event: WebhookEvent) -> bool:
+    async def send_email_alert_using_smtp(
+        self, webhook_event: WebhookEvent, alert_type: str
+    ) -> bool:
         """
         Sends structured Email alert to an SMTP server

@@ -1306,7 +1353,6 @@ Model Info:
         Returns -> True if sent, False if not.
         """
         from litellm.proxy.utils import send_email
-
         from litellm.proxy.proxy_server import premium_user, prisma_client

         email_logo_url = os.getenv(

@@ -1360,6 +1406,10 @@ Model Info:
             subject=email_event["subject"],
             html=email_event["html"],
         )
+        if webhook_event.event_group == "team":
+            from litellm.integrations.email_alerting import send_team_budget_alert
+
+            await send_team_budget_alert(webhook_event=webhook_event)

         return False

@@ -1368,6 +1418,7 @@ Model Info:
         message: str,
         level: Literal["Low", "Medium", "High"],
         alert_type: Literal[AlertType],
+        alerting_metadata: dict,
         user_info: Optional[WebhookEvent] = None,
         **kwargs,
):
|
):
|
||||||
|
@ -1401,7 +1452,9 @@ Model Info:
|
||||||
and user_info is not None
|
and user_info is not None
|
||||||
):
|
):
|
||||||
# only send budget alerts over Email
|
# only send budget alerts over Email
|
||||||
await self.send_email_alert_using_smtp(webhook_event=user_info)
|
await self.send_email_alert_using_smtp(
|
||||||
|
webhook_event=user_info, alert_type=alert_type
|
||||||
|
)
|
||||||
|
|
||||||
if "slack" not in self.alerting:
|
if "slack" not in self.alerting:
|
||||||
return
|
return
|
||||||
|
@ -1425,6 +1478,9 @@ Model Info:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
||||||
|
if alerting_metadata:
|
||||||
|
for key, value in alerting_metadata.items():
|
||||||
|
formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
|
||||||
if _proxy_base_url is not None:
|
if _proxy_base_url is not None:
|
||||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||||
|
|
||||||
|
@ -1440,7 +1496,7 @@ Model Info:
|
||||||
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
||||||
|
|
||||||
if slack_webhook_url is None:
|
if slack_webhook_url is None:
|
||||||
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
|
raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
|
||||||
payload = {"text": formatted_message}
|
payload = {"text": formatted_message}
|
||||||
headers = {"Content-type": "application/json"}
|
headers = {"Content-type": "application/json"}
|
||||||
|
|
||||||
|
@ -1453,7 +1509,7 @@ Model Info:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Error sending slack alert. Error=", response.text
|
"Error sending slack alert. Error={}".format(response.text)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
@ -1622,6 +1678,7 @@ Model Info:
|
||||||
message=_weekly_spend_message,
|
message=_weekly_spend_message,
|
||||||
level="Low",
|
level="Low",
|
||||||
alert_type="spend_reports",
|
alert_type="spend_reports",
|
||||||
|
alerting_metadata={},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||||
|
@ -1673,6 +1730,7 @@ Model Info:
|
||||||
message=_spend_message,
|
message=_spend_message,
|
||||||
level="Low",
|
level="Low",
|
||||||
alert_type="spend_reports",
|
alert_type="spend_reports",
|
||||||
|
alerting_metadata={},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||||
|
|
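Illustration (outside the diff): a minimal sketch of how the new alerting_metadata dict is folded into the Slack message body, mirroring the formatting loop added above; format_alert is a hypothetical stand-in for the relevant part of SlackAlerting.send_alert.

def format_alert(message: str, alerting_metadata: dict, **kwargs) -> str:
    # kwargs behave as before; alerting_metadata is the new, explicitly passed dict
    formatted_message = message
    for key, value in kwargs.items():
        formatted_message += f"\n\n{key}: `{value}`\n\n"
    if alerting_metadata:
        for key, value in alerting_metadata.items():
            formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
    return formatted_message

print(format_alert("Budget crossed", {"team": "search", "env": "prod"}))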
litellm/litellm_core_utils/core_helpers.py (new file, +41 lines)

# What is this?
## Helper utilities for the model response objects


def map_finish_reason(
    finish_reason: str,
):  # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null'
    # anthropic mapping
    if finish_reason == "stop_sequence":
        return "stop"
    # cohere mapping - https://docs.cohere.com/reference/generate
    elif finish_reason == "COMPLETE":
        return "stop"
    elif finish_reason == "MAX_TOKENS":  # cohere + vertex ai
        return "length"
    elif finish_reason == "ERROR_TOXIC":
        return "content_filter"
    elif (
        finish_reason == "ERROR"
    ):  # openai currently doesn't support an 'error' finish reason
        return "stop"
    # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
    elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
        return "stop"
    elif (
        finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"
    ):  # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
        return "stop"
    elif finish_reason == "SAFETY":  # vertex ai
        return "content_filter"
    elif finish_reason == "STOP":  # vertex ai
        return "stop"
    elif finish_reason == "end_turn" or finish_reason == "stop_sequence":  # anthropic
        return "stop"
    elif finish_reason == "max_tokens":  # anthropic
        return "length"
    elif finish_reason == "tool_use":  # anthropic
        return "tool_calls"
    elif finish_reason == "content_filtered":
        return "content_filter"
    return finish_reason
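Usage sketch for the new helper (assumes this litellm version is installed; the import path and return values come from the file above):

from litellm.litellm_core_utils.core_helpers import map_finish_reason

assert map_finish_reason("stop_sequence") == "stop"    # anthropic
assert map_finish_reason("MAX_TOKENS") == "length"     # cohere / vertex ai
assert map_finish_reason("tool_use") == "tool_calls"   # anthropic tool call
assert map_finish_reason("custom_reason") == "custom_reason"  # unknown values pass through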
litellm/litellm_core_utils/litellm_logging.py (new file, +1825 lines; diff not expanded)
litellm/litellm_core_utils/llm_cost_calc/google.py (new file, +82 lines)

# What is this?
## Cost calculation for Google AI Studio / Vertex AI models
from typing import Literal, Tuple

import litellm

"""
Gemini pricing covers:
- token
- image
- audio
- video
"""

models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]


def _is_above_128k(tokens: float) -> bool:
    if tokens > 128000:
        return True
    return False


def cost_per_token(
    model: str,
    custom_llm_provider: str,
    prompt_tokens: float,
    completion_tokens: float,
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, either "vertex_ai-*" or "gemini"
        - prompt_tokens: float, the number of input tokens
        - completion_tokens: float, the number of output tokens

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        Exception if model requires >128k pricing, but model cost not mapped
    """
    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## CALCULATE INPUT COST
    if (
        _is_above_128k(tokens=prompt_tokens)
        and model not in models_without_dynamic_pricing
    ):
        assert (
            model_info["input_cost_per_token_above_128k_tokens"] is not None
        ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
            model, model_info
        )
        prompt_cost = (
            prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
        )
    else:
        prompt_cost = prompt_tokens * model_info["input_cost_per_token"]

    ## CALCULATE OUTPUT COST
    if (
        _is_above_128k(tokens=completion_tokens)
        and model not in models_without_dynamic_pricing
    ):
        assert (
            model_info["output_cost_per_token_above_128k_tokens"] is not None
        ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
            model, model_info
        )
        completion_cost = (
            completion_tokens * model_info["output_cost_per_token_above_128k_tokens"]
        )
    else:
        completion_cost = completion_tokens * model_info["output_cost_per_token"]

    return prompt_cost, completion_cost
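Usage sketch; the module path and signature are taken from the new file above, while the model name is an assumption (it must exist in litellm's model cost map with above-128k pricing for this call to succeed):

from litellm.litellm_core_utils.llm_cost_calc.google import cost_per_token

prompt_cost, completion_cost = cost_per_token(
    model="gemini-1.5-pro",          # assumed to be in the cost map
    custom_llm_provider="gemini",
    prompt_tokens=200_000,           # > 128k, so the *_above_128k_tokens rate applies
    completion_tokens=1_000,
)
print(prompt_cost, completion_cost)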
litellm/litellm_core_utils/llm_request_utils.py (new file, +28 lines)

from typing import Dict, Optional


def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
    """
    Ensure that the extra_body sent in the request is safe, otherwise users will see this error

    "Object of type TextPromptClient is not JSON serializable"

    Relevant Issue: https://github.com/BerriAI/litellm/issues/4140
    """
    if extra_body is None:
        return None

    if not isinstance(extra_body, dict):
        return extra_body

    if "metadata" in extra_body and isinstance(extra_body["metadata"], dict):
        if "prompt" in extra_body["metadata"]:
            _prompt = extra_body["metadata"].get("prompt")

            # users can send Langfuse TextPromptClient objects, so we need to convert them to dicts
            # Langfuse TextPromptClients have .__dict__ attribute
            if _prompt is not None and hasattr(_prompt, "__dict__"):
                extra_body["metadata"]["prompt"] = _prompt.__dict__

    return extra_body
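Behaviour sketch: any object with a __dict__ stored under metadata["prompt"] is converted to a plain dict so the request body stays JSON serializable. FakePromptClient is a stand-in for the Langfuse TextPromptClient mentioned above.

from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe


class FakePromptClient:
    def __init__(self):
        self.name = "my-prompt"
        self.version = 3


extra_body = {"metadata": {"prompt": FakePromptClient()}}
safe = _ensure_extra_body_is_safe(extra_body=extra_body)
assert safe["metadata"]["prompt"] == {"name": "my-prompt", "version": 3}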
litellm/litellm_core_utils/redact_messages.py (new file, +71 lines)

# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import copy
from typing import TYPE_CHECKING, Any
import litellm

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as _LiteLLMLoggingObject,
    )

    LiteLLMLoggingObject = _LiteLLMLoggingObject
else:
    LiteLLMLoggingObject = Any


def redact_message_input_output_from_logging(
    litellm_logging_obj: LiteLLMLoggingObject, result
):
    """
    Removes messages, prompts, input, response from logging. This modifies the data in-place
    only redacts when litellm.turn_off_message_logging == True
    """
    # check if user opted out of logging message/response to callbacks
    if litellm.turn_off_message_logging is not True:
        return result

    # remove messages, prompts, input, response from logging
    litellm_logging_obj.model_call_details["messages"] = [
        {"role": "user", "content": "redacted-by-litellm"}
    ]
    litellm_logging_obj.model_call_details["prompt"] = ""
    litellm_logging_obj.model_call_details["input"] = ""

    # response cleaning
    # ChatCompletion Responses
    if (
        litellm_logging_obj.stream is True
        and "complete_streaming_response" in litellm_logging_obj.model_call_details
    ):
        _streaming_response = litellm_logging_obj.model_call_details[
            "complete_streaming_response"
        ]
        for choice in _streaming_response.choices:
            if isinstance(choice, litellm.Choices):
                choice.message.content = "redacted-by-litellm"
            elif isinstance(choice, litellm.utils.StreamingChoices):
                choice.delta.content = "redacted-by-litellm"
    else:
        if result is not None:
            if isinstance(result, litellm.ModelResponse):
                # only deep copy litellm.ModelResponse
                _result = copy.deepcopy(result)
                if hasattr(_result, "choices") and _result.choices is not None:
                    for choice in _result.choices:
                        if isinstance(choice, litellm.Choices):
                            choice.message.content = "redacted-by-litellm"
                        elif isinstance(choice, litellm.utils.StreamingChoices):
                            choice.delta.content = "redacted-by-litellm"

                return _result

    # by default return result
    return result
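Sketch of the redaction path for a non-streaming call; the logging object here is a duck-typed stand-in exposing only the attributes the helper reads (stream, model_call_details).

import litellm
from litellm.litellm_core_utils.redact_messages import (
    redact_message_input_output_from_logging,
)


class FakeLogging:
    stream = False
    model_call_details = {"messages": [{"role": "user", "content": "secret"}]}


litellm.turn_off_message_logging = True
redact_message_input_output_from_logging(FakeLogging(), result=None)
assert FakeLogging.model_call_details["messages"] == [
    {"role": "user", "content": "redacted-by-litellm"}
]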
@@ -5,10 +5,16 @@ import requests, copy  # type: ignore
import time
from functools import partial
from typing import Callable, Optional, List, Union
- from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
+ import litellm.litellm_core_utils
+ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+ from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
- from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+ from litellm.llms.custom_httpx.http_handler import (
+     AsyncHTTPHandler,
+     _get_async_httpx_client,
+     _get_httpx_client,
+ )
from .base import BaseLLM
import httpx  # type: ignore
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
@@ -171,7 +177,7 @@ async def make_call(
    logging_obj,
):
    if client is None:
-       client = AsyncHTTPHandler()  # Create a new client if none provided
+       client = _get_async_httpx_client()  # Create a new client if none provided

    response = await client.post(api_base, headers=headers, data=data, stream=True)

@@ -201,7 +207,7 @@ class AnthropicChatCompletion(BaseLLM):
        response: Union[requests.Response, httpx.Response],
        model_response: ModelResponse,
        stream: bool,
-       logging_obj: litellm.utils.Logging,
+       logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
@@ -316,7 +322,7 @@ class AnthropicChatCompletion(BaseLLM):
        response: Union[requests.Response, httpx.Response],
        model_response: ModelResponse,
        stream: bool,
-       logging_obj: litellm.utils.Logging,
+       logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
@@ -463,9 +469,7 @@ class AnthropicChatCompletion(BaseLLM):
        logger_fn=None,
        headers={},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
-       async_handler = AsyncHTTPHandler(
-           timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-       )
+       async_handler = _get_async_httpx_client()
        response = await async_handler.post(api_base, headers=headers, json=data)
        if stream and _is_function_call:
            return self.process_streaming_response(
@@ -1,41 +1,58 @@
- from typing import Optional, Union, Any, Literal, Coroutine, Iterable
- from typing_extensions import overload
- import types, requests
- from .base import BaseLLM
- from litellm.utils import (
-     ModelResponse,
-     Choices,
-     Message,
-     CustomStreamWrapper,
-     convert_to_model_response_object,
-     TranscriptionResponse,
-     get_secret,
-     UnsupportedParamsError,
- )
- from typing import Callable, Optional, BinaryIO, List
- from litellm import OpenAIConfig
- import litellm, json
- import httpx  # type: ignore
- from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
- from openai import AzureOpenAI, AsyncAzureOpenAI
- import uuid
+ import asyncio
+ import json
import os
+ import types
+ import uuid
+ from typing import (
+     Any,
+     BinaryIO,
+     Callable,
+     Coroutine,
+     Iterable,
+     List,
+     Literal,
+     Optional,
+     Union,
+ )
+
+ import httpx  # type: ignore
+ import requests
+ from openai import AsyncAzureOpenAI, AzureOpenAI
+ from typing_extensions import overload
+
+ import litellm
+ from litellm import OpenAIConfig
+ from litellm.caching import DualCache
+ from litellm.utils import (
+     Choices,
+     CustomStreamWrapper,
+     Message,
+     ModelResponse,
+     TranscriptionResponse,
+     UnsupportedParamsError,
+     convert_to_model_response_object,
+     get_secret,
+ )
+
from ..types.llms.openai import (
-     AsyncCursorPage,
-     AssistantToolParam,
-     SyncCursorPage,
    Assistant,
-     MessageData,
-     OpenAIMessage,
-     OpenAICreateThreadParamsMessage,
-     Thread,
-     AssistantToolParam,
-     Run,
    AssistantEventHandler,
+     AssistantStreamManager,
+     AssistantToolParam,
    AsyncAssistantEventHandler,
    AsyncAssistantStreamManager,
-     AssistantStreamManager,
+     AsyncCursorPage,
+     MessageData,
+     OpenAICreateThreadParamsMessage,
+     OpenAIMessage,
+     Run,
+     SyncCursorPage,
+     Thread,
)
+ from .base import BaseLLM
+ from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
+
+ azure_ad_cache = DualCache()


class AzureOpenAIError(Exception):
@@ -309,9 +326,12 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):

def get_azure_ad_token_from_oidc(azure_ad_token: str):
    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
-   azure_tenant = os.getenv("AZURE_TENANT_ID", None)
+   azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
+   azure_authority_host = os.getenv(
+       "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
+   )

-   if azure_client_id is None or azure_tenant is None:
+   if azure_client_id is None or azure_tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
@@ -325,8 +345,21 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
            message="OIDC token could not be retrieved from secret manager.",
        )

+   azure_ad_token_cache_key = json.dumps(
+       {
+           "azure_client_id": azure_client_id,
+           "azure_tenant_id": azure_tenant_id,
+           "azure_authority_host": azure_authority_host,
+           "oidc_token": oidc_token,
+       }
+   )
+
+   azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
+   if azure_ad_token_access_token is not None:
+       return azure_ad_token_access_token
+
    req_token = httpx.post(
-       f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
+       f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
@@ -342,12 +375,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
            message=req_token.text,
        )

-   possible_azure_ad_token = req_token.json().get("access_token", None)
+   azure_ad_token_json = req_token.json()
+   azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
+   azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)

-   if possible_azure_ad_token is None:
-       raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
+   if azure_ad_token_access_token is None:
+       raise AzureOpenAIError(
+           status_code=422, message="Azure AD Token access_token not returned"
+       )

-   return possible_azure_ad_token
+   if azure_ad_token_expires_in is None:
+       raise AzureOpenAIError(
+           status_code=422, message="Azure AD Token expires_in not returned"
+       )
+
+   azure_ad_cache.set_cache(
+       key=azure_ad_token_cache_key,
+       value=azure_ad_token_access_token,
+       ttl=azure_ad_token_expires_in,
+   )
+
+   return azure_ad_token_access_token


class AzureChatCompletion(BaseLLM):
@@ -619,6 +667,8 @@ class AzureChatCompletion(BaseLLM):
        except AzureOpenAIError as e:
            exception_mapping_worked = True
            raise e
+       except asyncio.CancelledError as e:
+           raise AzureOpenAIError(status_code=500, message=str(e))
        except Exception as e:
            if hasattr(e, "status_code"):
                raise e
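Generic sketch of the caching pattern introduced above: the Azure AD token exchange result is cached under a key built from the request inputs and expires with the token's own expires_in. fetch_token is a hypothetical stand-in for the httpx POST to the token endpoint.

import json
import time

_token_cache: dict = {}


def get_cached_azure_ad_token(client_id: str, tenant_id: str, oidc_token: str, fetch_token):
    cache_key = json.dumps(
        {"azure_client_id": client_id, "azure_tenant_id": tenant_id, "oidc_token": oidc_token}
    )
    cached = _token_cache.get(cache_key)
    if cached is not None and cached["expires_at"] > time.time():
        return cached["access_token"]

    # fetch_token returns a dict like {"access_token": ..., "expires_in": ...}
    token_json = fetch_token(client_id, tenant_id, oidc_token)
    _token_cache[cache_key] = {
        "access_token": token_json["access_token"],
        "expires_at": time.time() + token_json["expires_in"],
    }
    return token_json["access_token"]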
@@ -2,7 +2,7 @@
import litellm
import httpx, requests
from typing import Optional, Union
- from litellm.utils import Logging
+ from litellm.litellm_core_utils.litellm_logging import Logging


class BaseLLM:
@@ -27,6 +27,25 @@ class BaseLLM:
        """
        return model_response

+   def process_text_completion_response(
+       self,
+       model: str,
+       response: Union[requests.Response, httpx.Response],
+       model_response: litellm.utils.TextCompletionResponse,
+       stream: bool,
+       logging_obj: Logging,
+       optional_params: dict,
+       api_key: str,
+       data: Union[dict, str],
+       messages: list,
+       print_verbose,
+       encoding,
+   ) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
+       """
+       Helper function to process the response across sync + async completion calls
+       """
+       return model_response
+
    def create_client_session(self):
        if litellm.client_session:
            _client_session = litellm.client_session
@@ -1,25 +1,27 @@
- import json, copy, types
+ import copy
+ import json
import os
+ import time
+ import types
+ import uuid
from enum import Enum
- import time, uuid
- from typing import Callable, Optional, Any, Union, List
+ from typing import Any, Callable, List, Optional, Union
+
+ import httpx
+
import litellm
- from litellm.utils import (
-     ModelResponse,
-     get_secret,
-     Usage,
-     ImageResponse,
-     map_finish_reason,
- )
+ from litellm.litellm_core_utils.core_helpers import map_finish_reason
+ from litellm.types.utils import ImageResponse, ModelResponse, Usage
+ from litellm.utils import get_secret
from .prompt_templates.factory import (
-     prompt_factory,
-     custom_prompt,
    construct_tool_use_system_prompt,
+     contains_tag,
+     custom_prompt,
    extract_between_tags,
    parse_xml_params,
-     contains_tag,
+     prompt_factory,
)
- import httpx


class BedrockError(Exception):
@@ -633,7 +635,11 @@ def init_bedrock_client(
    config = boto3.session.Config()

    ### CHECK STS ###
-   if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
+   if (
+       aws_web_identity_token is not None
+       and aws_role_name is not None
+       and aws_session_name is not None
+   ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
@@ -642,9 +648,7 @@ def init_bedrock_client(
                status_code=401,
            )

-       sts_client = boto3.client(
-           "sts"
-       )
+       sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
@@ -726,38 +730,31 @@ def init_bedrock_client(

def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
    # handle anthropic prompts and amazon titan prompts
-   if provider == "anthropic" or provider == "amazon":
+   chat_template_provider = ["anthropic", "amazon", "mistral", "meta"]
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details["initial_prompt_value"],
            final_prompt_value=model_prompt_details["final_prompt_value"],
            messages=messages,
        )
    else:
+       if provider in chat_template_provider:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
-   elif provider == "mistral":
-       prompt = prompt_factory(
-           model=model, messages=messages, custom_llm_provider="bedrock"
-       )
-   elif provider == "meta":
-       prompt = prompt_factory(
-           model=model, messages=messages, custom_llm_provider="bedrock"
-       )
-   else:
-       prompt = ""
-       for message in messages:
-           if "role" in message:
-               if message["role"] == "user":
-                   prompt += f"{message['content']}"
-               else:
-                   prompt += f"{message['content']}"
-           else:
-               prompt += f"{message['content']}"
+       else:
+           prompt = ""
+           for message in messages:
+               if "role" in message:
+                   if message["role"] == "user":
+                       prompt += f"{message['content']}"
+                   else:
+                       prompt += f"{message['content']}"
+               else:
+                   prompt += f"{message['content']}"
    return prompt
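Sketch of the fallback path above: providers outside the chat_template_provider list get the raw message contents concatenated instead of a provider-specific template. fallback_prompt is a simplified stand-in.

def fallback_prompt(messages: list) -> str:
    prompt = ""
    for message in messages:
        # user and assistant contents are concatenated the same way
        prompt += f"{message['content']}"
    return prompt


print(fallback_prompt([{"role": "user", "content": "Hello "}, {"role": "assistant", "content": "Hi"}]))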
@@ -22,13 +22,12 @@ from typing import (
from litellm.utils import (
    ModelResponse,
    Usage,
-   map_finish_reason,
    CustomStreamWrapper,
-   Message,
-   Choices,
    get_secret,
-   Logging,
)
+ from litellm.litellm_core_utils.core_helpers import map_finish_reason
+ from litellm.litellm_core_utils.litellm_logging import Logging
+ from litellm.types.utils import Message, Choices
import litellm, uuid
from .prompt_templates.factory import (
    prompt_factory,
@@ -41,7 +40,12 @@ from .prompt_templates.factory import (
    _bedrock_converse_messages_pt,
    _bedrock_tools_pt,
)
- from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+ from litellm.llms.custom_httpx.http_handler import (
+     AsyncHTTPHandler,
+     HTTPHandler,
+     _get_async_httpx_client,
+     _get_httpx_client,
+ )
from .base import BaseLLM
import httpx  # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
@@ -51,7 +55,11 @@ from litellm.types.llms.openai import (
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
+   ChatCompletionDeltaChunk,
)
+ from litellm.caching import DualCache
+
+ iam_cache = DualCache()


class AmazonCohereChatConfig:
@@ -164,7 +172,7 @@ async def make_call(
    logging_obj,
):
    if client is None:
-       client = AsyncHTTPHandler()  # Create a new client if none provided
+       client = _get_async_httpx_client()  # Create a new client if none provided

    response = await client.post(api_base, headers=headers, data=data, stream=True)

@@ -195,7 +203,7 @@ def make_sync_call(
    logging_obj,
):
    if client is None:
-       client = HTTPHandler()  # Create a new client if none provided
+       client = _get_httpx_client()  # Create a new client if none provided

    response = client.post(api_base, headers=headers, data=data, stream=True)

@@ -329,33 +337,60 @@ class BedrockLLM(BaseLLM):
            and aws_role_name is not None
            and aws_session_name is not None
        ):
-           oidc_token = get_secret(aws_web_identity_token)
-
-           if oidc_token is None:
-               raise BedrockError(
-                   message="OIDC token could not be retrieved from secret manager.",
-                   status_code=401,
-               )
-
-           sts_client = boto3.client("sts")
-
-           # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
-           # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
-           sts_response = sts_client.assume_role_with_web_identity(
-               RoleArn=aws_role_name,
-               RoleSessionName=aws_session_name,
-               WebIdentityToken=oidc_token,
-               DurationSeconds=3600,
-           )
-
-           session = boto3.Session(
-               aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
-               aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
-               aws_session_token=sts_response["Credentials"]["SessionToken"],
-               region_name=aws_region_name,
-           )
-
-           return session.get_credentials()
+           iam_creds_cache_key = json.dumps(
+               {
+                   "aws_web_identity_token": aws_web_identity_token,
+                   "aws_role_name": aws_role_name,
+                   "aws_session_name": aws_session_name,
+                   "aws_region_name": aws_region_name,
+               }
+           )
+
+           iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
+           if iam_creds_dict is None:
+               oidc_token = get_secret(aws_web_identity_token)
+
+               if oidc_token is None:
+                   raise BedrockError(
+                       message="OIDC token could not be retrieved from secret manager.",
+                       status_code=401,
+                   )
+
+               sts_client = boto3.client(
+                   "sts",
+                   region_name=aws_region_name,
+                   endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com",
+               )
+
+               # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+               # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+               sts_response = sts_client.assume_role_with_web_identity(
+                   RoleArn=aws_role_name,
+                   RoleSessionName=aws_session_name,
+                   WebIdentityToken=oidc_token,
+                   DurationSeconds=3600,
+               )
+
+               iam_creds_dict = {
+                   "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
+                   "aws_secret_access_key": sts_response["Credentials"][
+                       "SecretAccessKey"
+                   ],
+                   "aws_session_token": sts_response["Credentials"]["SessionToken"],
+                   "region_name": aws_region_name,
+               }
+
+               iam_cache.set_cache(
+                   key=iam_creds_cache_key,
+                   value=json.dumps(iam_creds_dict),
+                   ttl=3600 - 60,
+               )
+
+           session = boto3.Session(**iam_creds_dict)
+
+           iam_creds = session.get_credentials()
+
+           return iam_creds
        elif aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
@@ -958,7 +993,7 @@ class BedrockLLM(BaseLLM):
            if isinstance(timeout, float) or isinstance(timeout, int):
                timeout = httpx.Timeout(timeout)
            _params["timeout"] = timeout
-           self.client = HTTPHandler(**_params)  # type: ignore
+           self.client = _get_httpx_client(_params)  # type: ignore
        else:
            self.client = client
        if (stream is not None and stream == True) and provider != "ai21":
@@ -1040,7 +1075,7 @@ class BedrockLLM(BaseLLM):
            if isinstance(timeout, float) or isinstance(timeout, int):
                timeout = httpx.Timeout(timeout)
            _params["timeout"] = timeout
-           client = AsyncHTTPHandler(**_params)  # type: ignore
+           client = _get_async_httpx_client(_params)  # type: ignore
        else:
            client = client  # type: ignore

@@ -1420,33 +1455,60 @@ class BedrockConverseLLM(BaseLLM):
            and aws_role_name is not None
            and aws_session_name is not None
        ):
            (same web-identity credential-caching change as in BedrockLLM above: the old get_secret/assume_role_with_web_identity/boto3.Session block is replaced by the iam_creds_cache_key / iam_cache lookup, regional STS client, cached iam_creds_dict with ttl=3600 - 60, and `return iam_creds`)
        elif aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
@@ -1542,7 +1604,7 @@ class BedrockConverseLLM(BaseLLM):
            if isinstance(timeout, float) or isinstance(timeout, int):
                timeout = httpx.Timeout(timeout)
            _params["timeout"] = timeout
-           client = AsyncHTTPHandler(**_params)  # type: ignore
+           client = _get_async_httpx_client(_params)  # type: ignore
        else:
            client = client  # type: ignore

@@ -1814,7 +1876,7 @@ class BedrockConverseLLM(BaseLLM):
            if isinstance(timeout, float) or isinstance(timeout, int):
                timeout = httpx.Timeout(timeout)
            _params["timeout"] = timeout
-           client = HTTPHandler(**_params)  # type: ignore
+           client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        try:
@@ -1859,29 +1921,59 @@ class AWSEventStreamDecoder:
        self.parser = EventStreamJSONParser()

    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
-       text = ""
-       tool_str = ""
-       is_finished = False
-       finish_reason = ""
-       usage: Optional[ConverseTokenUsageBlock] = None
-       if "delta" in chunk_data:
-           delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
-           if "text" in delta_obj:
-               text = delta_obj["text"]
-           elif "toolUse" in delta_obj:
-               tool_str = delta_obj["toolUse"]["input"]
-       elif "stopReason" in chunk_data:
-           finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
-       elif "usage" in chunk_data:
-           usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
-       response = GenericStreamingChunk(
-           text=text,
-           tool_str=tool_str,
-           is_finished=is_finished,
-           finish_reason=finish_reason,
-           usage=usage,
-       )
-       return response
+       try:
+           text = ""
+           tool_use: Optional[ChatCompletionToolCallChunk] = None
+           is_finished = False
+           finish_reason = ""
+           usage: Optional[ConverseTokenUsageBlock] = None
+
+           index = int(chunk_data.get("contentBlockIndex", 0))
+           if "start" in chunk_data:
+               start_obj = ContentBlockStartEvent(**chunk_data["start"])
+               if (
+                   start_obj is not None
+                   and "toolUse" in start_obj
+                   and start_obj["toolUse"] is not None
+               ):
+                   tool_use = {
+                       "id": start_obj["toolUse"]["toolUseId"],
+                       "type": "function",
+                       "function": {
+                           "name": start_obj["toolUse"]["name"],
+                           "arguments": "",
+                       },
+                   }
+           elif "delta" in chunk_data:
+               delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+               if "text" in delta_obj:
+                   text = delta_obj["text"]
+               elif "toolUse" in delta_obj:
+                   tool_use = {
+                       "id": None,
+                       "type": "function",
+                       "function": {
+                           "name": None,
+                           "arguments": delta_obj["toolUse"]["input"],
+                       },
+                   }
+           elif "stopReason" in chunk_data:
+               finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+               is_finished = True
+           elif "usage" in chunk_data:
+               usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+
+           response = GenericStreamingChunk(
+               text=text,
+               tool_use=tool_use,
+               is_finished=is_finished,
+               finish_reason=finish_reason,
+               usage=usage,
+               index=index,
+           )
+           return response
+       except Exception as e:
+           raise Exception("Received streaming error - {}".format(str(e)))

    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
        text = ""
@@ -1890,12 +1982,16 @@ class AWSEventStreamDecoder:
        if "outputText" in chunk_data:
            text = chunk_data["outputText"]
        # ai21 mapping
-       if "ai21" in self.model:  # fake ai21 streaming
+       elif "ai21" in self.model:  # fake ai21 streaming
            text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
            is_finished = True
            finish_reason = "stop"
        ######## bedrock.anthropic mappings ###############
-       elif "delta" in chunk_data:
+       elif (
+           "contentBlockIndex" in chunk_data
+           or "stopReason" in chunk_data
+           or "metrics" in chunk_data
+       ):
            return self.converse_chunk_parser(chunk_data=chunk_data)
        ######## bedrock.mistral mappings ###############
        elif "outputs" in chunk_data:
@@ -1905,7 +2001,7 @@ class AWSEventStreamDecoder:
        ):
            text = chunk_data["outputs"][0]["text"]
            stop_reason = chunk_data.get("stop_reason", None)
-           if stop_reason != None:
+           if stop_reason is not None:
                is_finished = True
                finish_reason = stop_reason
        ######## bedrock.cohere mappings ###############
@@ -1926,8 +2022,9 @@ class AWSEventStreamDecoder:
            text=text,
            is_finished=is_finished,
            finish_reason=finish_reason,
-           tool_str="",
            usage=None,
+           index=0,
+           tool_use=None,
        )

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
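Sketch of how the rewritten converse_chunk_parser surfaces tool calls: a contentBlockStart chunk opens the call (id + name, empty arguments) and later delta.toolUse chunks append argument fragments. The dicts below only mimic the Bedrock Converse stream payloads; parse is a simplified stand-in.

start_chunk = {
    "contentBlockIndex": 1,
    "start": {"toolUse": {"toolUseId": "tool_abc", "name": "get_weather"}},
}
delta_chunk = {
    "contentBlockIndex": 1,
    "delta": {"toolUse": {"input": '{"city": "Paris"}'}},
}


def parse(chunk_data: dict) -> dict:
    tool_use = None
    if "start" in chunk_data and "toolUse" in chunk_data["start"]:
        block = chunk_data["start"]["toolUse"]
        tool_use = {"id": block["toolUseId"], "type": "function",
                    "function": {"name": block["name"], "arguments": ""}}
    elif "delta" in chunk_data and "toolUse" in chunk_data["delta"]:
        tool_use = {"id": None, "type": "function",
                    "function": {"name": None, "arguments": chunk_data["delta"]["toolUse"]["input"]}}
    return {"tool_use": tool_use, "index": int(chunk_data.get("contentBlockIndex", 0))}


print(parse(start_chunk))
print(parse(delta_chunk))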
@@ -139,6 +139,7 @@ def process_response(

def convert_model_to_url(model: str, api_base: str):
    user_id, app_id, model_id = model.split(".")
+   model_id = model_id.lower()
    return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"


@@ -171,19 +172,55 @@ async def async_completion(

    async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
    response = await async_handler.post(
-       api_base, headers=headers, data=json.dumps(data)
+       url=model, headers=headers, data=json.dumps(data)
    )

-   return process_response(
-       model=model,
-       prompt=prompt,
-       response=response,
-       model_response=model_response,
-       api_key=api_key,
-       data=data,
-       encoding=encoding,
-       logging_obj=logging_obj,
-   )
+   logging_obj.post_call(
+       input=prompt,
+       api_key=api_key,
+       original_response=response.text,
+       additional_args={"complete_input_dict": data},
+   )
+   ## RESPONSE OBJECT
+   try:
+       completion_response = response.json()
+   except Exception:
+       raise ClarifaiError(
+           message=response.text, status_code=response.status_code, url=model
+       )
+   # print(completion_response)
+   try:
+       choices_list = []
+       for idx, item in enumerate(completion_response["outputs"]):
+           if len(item["data"]["text"]["raw"]) > 0:
+               message_obj = Message(content=item["data"]["text"]["raw"])
+           else:
+               message_obj = Message(content=None)
+           choice_obj = Choices(
+               finish_reason="stop",
+               index=idx + 1,  # check
+               message=message_obj,
+           )
+           choices_list.append(choice_obj)
+       model_response["choices"] = choices_list
+
+   except Exception as e:
+       raise ClarifaiError(
+           message=traceback.format_exc(), status_code=response.status_code, url=model
+       )
+
+   # Calculate Usage
+   prompt_tokens = len(encoding.encode(prompt))
+   completion_tokens = len(
+       encoding.encode(model_response["choices"][0]["message"].get("content"))
+   )
+   model_response["model"] = model
+   model_response["usage"] = Usage(
+       prompt_tokens=prompt_tokens,
+       completion_tokens=completion_tokens,
+       total_tokens=prompt_tokens + completion_tokens,
+   )
+   return model_response


def completion(
@@ -241,7 +278,7 @@ def completion(
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
-           "api_base": api_base,
+           "api_base": model,
        },
    )
    if acompletion == True:
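Sketch of the new Clarifai response handling: each entry in outputs becomes a Choices object. The payload below only mimics the outputs shape parsed above; Choices and Message come from litellm.utils.

from litellm.utils import Choices, Message

completion_response = {
    "outputs": [
        {"data": {"text": {"raw": "Hello from Clarifai"}}},
        {"data": {"text": {"raw": ""}}},
    ]
}

choices_list = []
for idx, item in enumerate(completion_response["outputs"]):
    raw = item["data"]["text"]["raw"]
    message_obj = Message(content=raw) if len(raw) > 0 else Message(content=None)
    choices_list.append(Choices(finish_reason="stop", index=idx + 1, message=message_obj))

print([c.message.content for c in choices_list])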
@@ -12,6 +12,15 @@ class AsyncHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         concurrent_limit=1000,
     ):
+        self.timeout = timeout
+        self.client = self.create_client(
+            timeout=timeout, concurrent_limit=concurrent_limit
+        )
+
+    def create_client(
+        self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
+    ) -> httpx.AsyncClient:
+
         async_proxy_mounts = None
         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
         http_proxy = os.getenv("HTTP_PROXY", None)
@@ -39,7 +48,8 @@ class AsyncHTTPHandler:
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT
         # Create a client with a connection pool
-        self.client = httpx.AsyncClient(
+
+        return httpx.AsyncClient(
             timeout=timeout,
             limits=httpx.Limits(
                 max_connections=concurrent_limit,
@@ -83,11 +93,48 @@ class AsyncHTTPHandler:
             response = await self.client.send(req, stream=stream)
             response.raise_for_status()
             return response
+        except httpx.RemoteProtocolError:
+            # Retry the request with a new session if there is a connection error
+            new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
+            try:
+                return await self.single_connection_post_request(
+                    url=url,
+                    client=new_client,
+                    data=data,
+                    json=json,
+                    params=params,
+                    headers=headers,
+                    stream=stream,
+                )
+            finally:
+                await new_client.aclose()
         except httpx.HTTPStatusError as e:
             raise e
         except Exception as e:
             raise e
 
+    async def single_connection_post_request(
+        self,
+        url: str,
+        client: httpx.AsyncClient,
+        data: Optional[Union[dict, str]] = None,  # type: ignore
+        json: Optional[dict] = None,
+        params: Optional[dict] = None,
+        headers: Optional[dict] = None,
+        stream: bool = False,
+    ):
+        """
+        Making POST request for a single connection client.
+
+        Used for retrying connection client errors.
+        """
+        req = client.build_request(
+            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
+        )
+        response = await client.send(req, stream=stream)
+        response.raise_for_status()
+        return response
+
     def __del__(self) -> None:
         try:
             asyncio.get_running_loop().create_task(self.close())
@@ -172,3 +219,60 @@ class HTTPHandler:
             self.close()
         except Exception:
             pass
+
+
+def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+    """
+    Retrieves the async HTTP client from the cache
+    If not present, creates a new client
+
+    Caches the new client and returns it.
+    """
+    _params_key_name = ""
+    if params is not None:
+        for key, value in params.items():
+            try:
+                _params_key_name += f"{key}_{value}"
+            except Exception:
+                pass
+
+    _cache_key_name = "async_httpx_client" + _params_key_name
+    if _cache_key_name in litellm.in_memory_llm_clients_cache:
+        return litellm.in_memory_llm_clients_cache[_cache_key_name]
+
+    if params is not None:
+        _new_client = AsyncHTTPHandler(**params)
+    else:
+        _new_client = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+    litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+    return _new_client
+
+
+def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
+    """
+    Retrieves the HTTP client from the cache
+    If not present, creates a new client
+
+    Caches the new client and returns it.
+    """
+    _params_key_name = ""
+    if params is not None:
+        for key, value in params.items():
+            try:
+                _params_key_name += f"{key}_{value}"
+            except Exception:
+                pass
+
+    _cache_key_name = "httpx_client" + _params_key_name
+    if _cache_key_name in litellm.in_memory_llm_clients_cache:
+        return litellm.in_memory_llm_clients_cache[_cache_key_name]
+
+    if params is not None:
+        _new_client = HTTPHandler(**params)
+    else:
+        _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+
+    litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+    return _new_client
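Not part of the diff: a minimal usage sketch of the cached-client helpers added above, assuming the litellm.llms.custom_httpx.http_handler module path referenced elsewhere in this changeset.

# Minimal sketch: repeated lookups with the same params should return the same
# cached AsyncHTTPHandler instance (the cache key is built from the stringified params).
import httpx
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client

params = {"timeout": httpx.Timeout(timeout=30.0, connect=5.0)}
client_a = _get_async_httpx_client(params=params)
client_b = _get_async_httpx_client(params=params)
assert client_a is client_b  # same params -> same cached client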
@@ -10,10 +10,10 @@ from typing import Callable, Optional, List, Union, Tuple, Literal
 from litellm.utils import (
     ModelResponse,
     Usage,
-    map_finish_reason,
     CustomStreamWrapper,
     EmbeddingResponse,
 )
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@@ -289,7 +289,7 @@ class DatabricksChatCompletion(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: ModelResponse,
         stream: bool,
-        logging_obj: litellm.utils.Logging,
+        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],
@@ -1,14 +1,22 @@
-import types
-import traceback
+####################################
+######### DEPRECATED FILE ##########
+####################################
+# logic moved to `vertex_httpx.py` #
+
 import copy
 import time
+import traceback
+import types
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Choices, Message, Usage
-import litellm
+
 import httpx
-from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
 from packaging.version import Version
 
+import litellm
 from litellm import verbose_logger
+from litellm.utils import Choices, Message, ModelResponse, Usage
+
+from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
+
 
 class GeminiError(Exception):
@@ -186,8 +194,8 @@ def completion(
         if _system_instruction and len(system_prompt) > 0:
             _params["system_instruction"] = system_prompt
         _model = genai.GenerativeModel(**_params)
-        if stream == True:
-            if acompletion == True:
+        if stream is True:
+            if acompletion is True:
 
                 async def async_streaming():
                     try:
@@ -1,33 +1,41 @@
+import hashlib
+import json
+import time
+import traceback
+import types
 from typing import (
-    Optional,
-    Union,
     Any,
     BinaryIO,
-    Literal,
+    Callable,
+    Coroutine,
     Iterable,
+    Literal,
+    Optional,
+    Union,
 )
-import hashlib
-from typing_extensions import override, overload
-from pydantic import BaseModel
-import types, time, json, traceback
+
 import httpx
-from .base import BaseLLM
-from litellm.utils import (
-    ModelResponse,
-    Choices,
-    Message,
-    CustomStreamWrapper,
-    convert_to_model_response_object,
-    Usage,
-    TranscriptionResponse,
-    TextCompletionResponse,
-)
-from typing import Callable, Optional, Coroutine
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from openai import OpenAI, AsyncOpenAI
-from ..types.llms.openai import *
 import openai
+from openai import AsyncOpenAI, OpenAI
+from pydantic import BaseModel
+from typing_extensions import overload, override
+
+import litellm
+from litellm.types.utils import ProviderField
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    ModelResponse,
+    TextCompletionResponse,
+    TranscriptionResponse,
+    Usage,
+    convert_to_model_response_object,
+)
+
+from ..types.llms.openai import *
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class OpenAIError(Exception):
@@ -164,6 +172,68 @@ class MistralConfig:
         return optional_params
 
 
+class MistralEmbeddingConfig:
+    """
+    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "encoding_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "encoding_format":
+                optional_params["encoding_format"] = value
+        return optional_params
+
+
+class AzureAIStudioConfig:
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return it's required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Azure AI Studio API Key.",
+                field_value="zEJ...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Azure AI Studio API Base.",
+                field_value="https://Mistral-serverless.",
+            ),
+        ]
+
+
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
@@ -243,8 +313,12 @@ class DeepInfraConfig:
         ]
 
     def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
         supported_openai_params = self.get_supported_openai_params()
         for param, value in non_default_params.items():
             if (
@@ -253,8 +327,23 @@ class DeepInfraConfig:
                 and model == "mistralai/Mistral-7B-Instruct-v0.1"
             ):  # this model does no support temperature == 0
                 value = 0.0001  # close to 0
+            if param == "tool_choice":
+                if (
+                    value != "auto" and value != "none"
+                ):  # https://deepinfra.com/docs/advanced/function_calling
+                    ## UNSUPPORTED TOOL CHOICE VALUE
+                    if litellm.drop_params is True or drop_params is True:
+                        value = None
+                    else:
+                        raise litellm.utils.UnsupportedParamsError(
+                            message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+                                value
+                            ),
+                            status_code=400,
+                        )
             if param in supported_openai_params:
-                optional_params[param] = value
+                if value is not None:
+                    optional_params[param] = value
         return optional_params
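Not part of the diff: a sketch of how the new drop_params argument to DeepInfraConfig.map_openai_params might behave, assuming DeepInfraConfig stays exported at the litellm package top level.

# Minimal sketch; inputs are assumed for illustration.
import litellm

mapped = litellm.DeepInfraConfig().map_openai_params(
    non_default_params={"tool_choice": "required", "temperature": 0.2},
    optional_params={},
    model="mistralai/Mistral-7B-Instruct-v0.1",
    drop_params=True,  # unsupported tool_choice value is dropped instead of raising
)
# tool_choice is set to None and skipped; temperature passes through unchanged
print(mapped)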
@@ -12,11 +12,11 @@ from typing import Callable, Optional, List, Literal, Union
 from litellm.utils import (
     ModelResponse,
     Usage,
-    map_finish_reason,
     CustomStreamWrapper,
     Message,
     Choices,
 )
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@@ -198,7 +198,7 @@ class PredibaseChatCompletion(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: ModelResponse,
         stream: bool,
-        logging_obj: litellm.utils.Logging,
+        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],
@@ -1,24 +1,30 @@
+import json
+import re
+import traceback
+import uuid
+import xml.etree.ElementTree as ET
 from enum import Enum
-import requests, traceback
-import json, re, xml.etree.ElementTree as ET
-from jinja2 import Template, exceptions, meta, BaseLoader
-from jinja2.sandbox import ImmutableSandboxedEnvironment
 from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
+
+import requests
+from jinja2 import BaseLoader, Template, exceptions, meta
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+
 import litellm
 import litellm.types
-from litellm.types.completion import (
-    ChatCompletionUserMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionMessageParam,
-    ChatCompletionFunctionMessageParam,
-    ChatCompletionMessageToolCallParam,
-    ChatCompletionToolMessageParam,
-)
 import litellm.types.llms
-from litellm.types.llms.anthropic import *
-import uuid
-from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
 import litellm.types.llms.vertex_ai
+from litellm.types.completion import (
+    ChatCompletionFunctionMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from litellm.types.llms.anthropic import *
+from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
+from litellm.types.utils import GenericImageParsingChunk
 
 
 def default_pt(messages):
@@ -622,9 +628,10 @@ def construct_tool_use_system_prompt(
 
 
 def convert_url_to_base64(url):
-    import requests
     import base64
+
+    import requests
 
     for _ in range(3):
         try:
             response = requests.get(url)
@@ -654,7 +661,7 @@ def convert_url_to_base64(url):
         raise Exception(f"Error: Unable to fetch image from URL. url={url}")
 
 
-def convert_to_anthropic_image_obj(openai_image_url: str):
+def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
     """
     Input:
     "image_url": "data:image/jpeg;base64,{base64_image}",
@@ -675,11 +682,11 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
         # Infer image format from the URL
         image_format = openai_image_url.split("data:image/")[1].split(";base64,")[0]
 
-        return {
-            "type": "base64",
-            "media_type": f"image/{image_format}",
-            "data": base64_data,
-        }
+        return GenericImageParsingChunk(
+            type="base64",
+            media_type=f"image/{image_format}",
+            data=base64_data,
+        )
     except Exception as e:
         if "Error: Unable to fetch image from URL" in str(e):
             raise e
@@ -1606,19 +1613,23 @@ def azure_text_pt(messages: list):
 
 ###### AMAZON BEDROCK #######
 
+from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
+from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
+from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
+from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
 from litellm.types.llms.bedrock import (
-    ToolResultContentBlock as BedrockToolResultContentBlock,
-    ToolResultBlock as BedrockToolResultBlock,
-    ToolConfigBlock as BedrockToolConfigBlock,
-    ToolUseBlock as BedrockToolUseBlock,
-    ImageSourceBlock as BedrockImageSourceBlock,
-    ImageBlock as BedrockImageBlock,
-    ContentBlock as BedrockContentBlock,
-    ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
-    ToolSpecBlock as BedrockToolSpecBlock,
-    ToolBlock as BedrockToolBlock,
     ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
 )
+from litellm.types.llms.bedrock import ToolConfigBlock as BedrockToolConfigBlock
+from litellm.types.llms.bedrock import (
+    ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
+)
+from litellm.types.llms.bedrock import ToolResultBlock as BedrockToolResultBlock
+from litellm.types.llms.bedrock import (
+    ToolResultContentBlock as BedrockToolResultContentBlock,
+)
+from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
+from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
 
 
 def get_image_details(image_url) -> Tuple[str, str]:
@@ -1655,7 +1666,8 @@ def get_image_details(image_url) -> Tuple[str, str]:
 def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
     if "base64" in image_url:
         # Case 1: Images with base64 encoding
-        import base64, re
+        import base64
+        import re
 
         # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
         image_metadata, img_without_base_64 = image_url.split(",")
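Not part of the diff: a short sketch of the new GenericImageParsingChunk return shape, using an arbitrary base64 data URL as input.

# Minimal sketch; the data URL below is an arbitrary example.
from litellm.llms.prompt_templates.factory import convert_to_anthropic_image_obj

chunk = convert_to_anthropic_image_obj("data:image/jpeg;base64,/9j/4AAQSkZJRg==")
# Callers in this changeset read the chunk dict-style (chunk["data"], chunk["media_type"])
assert chunk["type"] == "base64"
assert chunk["media_type"] == "image/jpeg"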
532 litellm/llms/text_completion_codestral.py Normal file
@@ -0,0 +1,532 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/

from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
    TextCompletionResponse,
    Usage,
    CustomStreamWrapper,
    Message,
    Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.types.llms.databricks import GenericStreamingChunk
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TextCompletionCodestralError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class MistralTextCompletionConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
    """

    suffix: Optional[str] = None
    temperature: Optional[int] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = None
    random_seed: Optional[int] = None
    stop: Optional[str] = None

    def __init__(
        self,
        suffix: Optional[str] = None,
        temperature: Optional[int] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
        random_seed: Optional[int] = None,
        stop: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "suffix",
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "seed",
            "stop",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "suffix":
                optional_params["suffix"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "stream" and value == True:
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "seed":
                optional_params["random_seed"] = value
            if param == "min_tokens":
                optional_params["min_tokens"] = value

        return optional_params

    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
        text = ""
        is_finished = False
        finish_reason = None
        logprobs = None

        chunk_data = chunk_data.replace("data:", "")
        chunk_data = chunk_data.strip()
        if len(chunk_data) == 0 or chunk_data == "[DONE]":
            return {
                "text": "",
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        chunk_data_dict = json.loads(chunk_data)
        original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
        _choices = chunk_data_dict.get("choices", []) or []
        _choice = _choices[0]
        text = _choice.get("delta", {}).get("content", "")

        if _choice.get("finish_reason") is not None:
            is_finished = True
            finish_reason = _choice.get("finish_reason")
            logprobs = _choice.get("logprobs")

        return GenericStreamingChunk(
            text=text,
            original_chunk=original_chunk,
            is_finished=is_finished,
            finish_reason=finish_reason,
            logprobs=logprobs,
        )


class CodestralTextCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_text_completion_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"codestral api: raw model_response: {response.text}")
        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except:
            raise TextCompletionCodestralError(message=response.text, status_code=422)

        _original_choices = completion_response.get("choices", [])
        _choices: List[litellm.utils.TextChoices] = []
        for choice in _original_choices:
            # This is what 1 choice looks like from codestral API
            # {
            #     "index": 0,
            #     "message": {
            #         "role": "assistant",
            #         "content": "\n assert is_odd(1)\n assert",
            #         "tool_calls": null
            #     },
            #     "finish_reason": "length",
            #     "logprobs": null
            # }
            _finish_reason = None
            _index = 0
            _text = None
            _logprobs = None

            _choice_message = choice.get("message", {})
            _choice = litellm.utils.TextChoices(
                finish_reason=choice.get("finish_reason"),
                index=choice.get("index"),
                text=_choice_message.get("content"),
                logprobs=choice.get("logprobs"),
            )

            _choices.append(_choice)

        _response = litellm.TextCompletionResponse(
            id=completion_response.get("id"),
            choices=_choices,
            created=completion_response.get("created"),
            model=completion_response.get("model"),
            usage=completion_response.get("usage"),
            stream=False,
            object=completion_response.get("object"),
        )
        return _response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)

        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.MistralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "prompt": prompt,
            **optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = requests.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="codestral",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = requests.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:

        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise TextCompletionCodestralError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            raise TextCompletionCodestralError(
                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="text-completion-codestral",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass
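Not part of the diff: a sketch of how the new FIM endpoint could be called through litellm, assuming the "text-completion-codestral" provider prefix registered above, a hypothetical model name, and CODESTRAL_API_KEY set in the environment.

# Minimal sketch; model name and prompt are assumptions for illustration.
import litellm

response = litellm.text_completion(
    model="text-completion-codestral/codestral-latest",
    prompt="def is_odd(n):\n    return",
    suffix="# end of helper",
    max_tokens=32,
)
print(response.choices[0].text)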
199998 litellm/llms/tokenizers/fb374d419588a4632f3f557e76b4b70aebbca790 Normal file
@@ -4,7 +4,6 @@ from enum import Enum
 import requests, copy  # type: ignore
 import time
 from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@@ -1,18 +1,31 @@
-import os, types
+import inspect
 import json
-from enum import Enum
-import requests  # type: ignore
+import os
 import time
-from typing import Callable, Optional, Union, List, Literal, Any
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
-from litellm.types.llms.vertex_ai import *
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Literal, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+from pydantic import BaseModel
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.prompt_templates.factory import (
-    convert_to_gemini_tool_call_result,
+    convert_to_anthropic_image_obj,
     convert_to_gemini_tool_call_invoke,
+    convert_to_gemini_tool_call_result,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
 
 class VertexAIError(Exception):
@@ -267,28 +280,6 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
         raise Exception(f"An exception occurs with this image - {str(e)}")
 
 
-def _load_image_from_url(image_url: str):
-    """
-    Loads an image from a URL.
-
-    Args:
-        image_url (str): The URL of the image.
-
-    Returns:
-        Image: The loaded image.
-    """
-    from vertexai.preview.generative_models import (
-        GenerativeModel,
-        Part,
-        GenerationConfig,
-        Image,
-    )
-
-    image_bytes = _get_image_bytes_from_url(image_url)
-
-    return Image.from_bytes(data=image_bytes)
-
-
 def _convert_gemini_role(role: str) -> Literal["user", "model"]:
     if role == "user":
         return "user"
@@ -301,43 +292,24 @@ def _process_gemini_image(image_url: str) -> PartType:
         # GCS URIs
         if "gs://" in image_url:
             # Figure out file type
             extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
             extension = extension_with_dot[1:]  # Ex: "png"
 
             file_type = get_file_type_from_extension(extension)
 
             # Validate the file type is supported by Gemini
             if not is_gemini_1_5_accepted_file_type(file_type):
                 raise Exception(f"File type not supported by gemini - {file_type}")
 
             mime_type = get_file_mime_type_for_file_type(file_type)
             file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
 
             return PartType(file_data=file_data)
 
         # Direct links
-        elif "https:/" in image_url:
-            image = _load_image_from_url(image_url)
-            _blob = BlobType(data=image.data, mime_type=image._mime_type)
-            return PartType(inline_data=_blob)
-
-        # Base64 encoding
-        elif "base64" in image_url:
-            import base64, re
-
-            # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
-            image_metadata, img_without_base_64 = image_url.split(",")
-
-            # read mime_type from img_without_base_64=data:image/jpeg;base64
-            # Extract MIME type using regular expression
-            mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
-
-            if mime_type_match:
-                mime_type = mime_type_match.group(1)
-            else:
-                mime_type = "image/jpeg"
-            decoded_img = base64.b64decode(img_without_base_64)
-            _blob = BlobType(data=decoded_img, mime_type=mime_type)
+        elif "https:/" in image_url or "base64" in image_url:
+            image = convert_to_anthropic_image_obj(image_url)
+            _blob = BlobType(data=image["data"], mime_type=image["media_type"])
             return PartType(inline_data=_blob)
         raise Exception("Invalid image received - {}".format(image_url))
     except Exception as e:
@@ -473,23 +445,25 @@ def completion(
             message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
         )
     try:
+        import google.auth  # type: ignore
+        import proto  # type: ignore
+        from google.cloud import aiplatform  # type: ignore
+        from google.cloud.aiplatform_v1beta1.types import (
+            content as gapic_content_types,  # type: ignore
+        )
+        from google.protobuf import json_format  # type: ignore
+        from google.protobuf.struct_pb2 import Value  # type: ignore
+        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
+        from vertexai.preview.generative_models import (
+            GenerationConfig,
+            GenerativeModel,
+            Part,
+        )
         from vertexai.preview.language_models import (
             ChatModel,
             CodeChatModel,
             InputOutputTextPair,
         )
-        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
-        from vertexai.preview.generative_models import (
-            GenerativeModel,
-            Part,
-            GenerationConfig,
-        )
-        from google.cloud import aiplatform  # type: ignore
-        from google.protobuf import json_format  # type: ignore
-        from google.protobuf.struct_pb2 import Value  # type: ignore
-        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
-        import google.auth  # type: ignore
-        import proto  # type: ignore
 
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
@@ -611,7 +585,7 @@ def completion(
     llm_model = None
 
     # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
-    if acompletion == True:
+    if acompletion is True:
         data = {
             "llm_model": llm_model,
             "mode": mode,
@@ -643,7 +617,7 @@ def completion(
             tools = optional_params.pop("tools", None)
             content = _gemini_convert_messages_with_history(messages=messages)
             stream = optional_params.pop("stream", False)
-            if stream == True:
+            if stream is True:
                 request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
                 logging_obj.pre_call(
                     input=prompt,
@@ -1293,6 +1267,95 @@ async def async_streaming(
     return streamwrapper
 
 
+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "dimensions",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 def embedding(
     model: str,
     input: Union[list, str],
@@ -1316,8 +1379,8 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
 
-    from vertexai.language_models import TextEmbeddingModel
     import google.auth  # type: ignore
+    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
 
     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
     try:
@@ -1347,6 +1410,16 @@ def embedding(
     if isinstance(input, str):
         input = [input]
 
+    if optional_params is not None and isinstance(optional_params, dict):
+        if optional_params.get("task_type") or optional_params.get("title"):
+            # if user passed task_type or title, cast to TextEmbeddingInput
+            _task_type = optional_params.pop("task_type", None)
+            _title = optional_params.pop("title", None)
+            input = [
+                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+                for x in input
+            ]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
@@ -1363,7 +1436,8 @@ def embedding(
             encoding=encoding,
         )
 
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1375,7 +1449,7 @@ def embedding(
     )
 
     try:
-        embeddings = llm_model.get_embeddings(input)
+        embeddings = llm_model.get_embeddings(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
 
@@ -1383,6 +1457,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1391,14 +1466,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
 
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
@@ -1420,7 +1491,8 @@ async def async_embedding(
     """
     Async embedding implementation
     """
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1432,7 +1504,7 @@ async def async_embedding(
     )
 
     try:
-        embeddings = await client.get_embeddings_async(input)
+        embeddings = await client.get_embeddings_async(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
 
@@ -1440,6 +1512,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1448,18 +1521,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
        )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
 
     return model_response
@@ -6,7 +6,8 @@ from enum import Enum
import requests, copy  # type: ignore
import time, uuid
from typing import Callable, Optional, List
-from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .prompt_templates.factory import (
@@ -1,16 +1,325 @@
-import os, types
-import json
-from enum import Enum
-import requests  # type: ignore
-import time
-from typing import Callable, Optional, Union, List, Any, Tuple
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
+# What is this?
+## httpx client for vertex ai calls
+## Initial implementation - covers gemini + image gen calls
+import inspect
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import ijson
+import requests  # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm import verbose_logger
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import convert_url_to_base64
+from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
+from litellm.types.llms.openai import (
+    ChatCompletionResponseMessage,
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
+)
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    FunctionCallingConfig,
+    FunctionDeclaration,
+    GenerateContentResponseBody,
+    GenerationConfig,
+    PartType,
+    RequestBody,
+    SafetSettingsConfig,
+    SystemInstructions,
+    ToolConfig,
+    Tools,
+)
+from litellm.types.utils import GenericStreamingChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
 from .base import BaseLLM

+class VertexGeminiConfig:
+    """
+    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+
+    The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
+
+    - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
+
+    - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
+
+    - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    temperature: Optional[float] = None
+    max_output_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        temperature: Optional[float] = None,
+        max_output_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+            "response_format",
+            "n",
+            "stop",
+        ]
+
+    def map_tool_choice_values(
+        self, model: str, tool_choice: Union[str, dict]
+    ) -> Optional[ToolConfig]:
+        if tool_choice == "none":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
+        elif tool_choice == "required":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
+        elif tool_choice == "auto":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
+        elif isinstance(tool_choice, dict):
+            # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            name = tool_choice.get("function", {}).get("name", "")
+            return ToolConfig(
+                functionCallingConfig=FunctionCallingConfig(
+                    mode="ANY", allowed_function_names=[name]
+                )
+            )
+        else:
+            raise litellm.utils.UnsupportedParamsError(
+                message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
+                    tool_choice
+                ),
+                status_code=400,
+            )
+
+    def map_openai_params(
+        self,
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+    ):
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if (
+                param == "stream" and value is True
+            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
+                optional_params["stream"] = value
+            if param == "n":
+                optional_params["candidate_count"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    optional_params["stop_sequences"] = [value]
+                elif isinstance(value, list):
+                    optional_params["stop_sequences"] = value
+            if param == "max_tokens":
+                optional_params["max_output_tokens"] = value
+            if param == "response_format" and value["type"] == "json_object":  # type: ignore
+                optional_params["response_mime_type"] = "application/json"
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "tools" and isinstance(value, list):
+                gtool_func_declarations = []
+                for tool in value:
+                    gtool_func_declaration = FunctionDeclaration(
+                        name=tool["function"]["name"],
+                        description=tool["function"].get("description", ""),
+                        parameters=tool["function"].get("parameters", {}),
+                    )
+                    gtool_func_declarations.append(gtool_func_declaration)
+                optional_params["tools"] = [
+                    Tools(function_declarations=gtool_func_declarations)
+                ]
+            if param == "tool_choice" and (
+                isinstance(value, str) or isinstance(value, dict)
+            ):
+                _tool_choice_value = self.map_tool_choice_values(
+                    model=model, tool_choice=value  # type: ignore
+                )
+                if _tool_choice_value is not None:
+                    optional_params["tool_choice"] = _tool_choice_value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+        """
+        return [
+            "europe-central2",
+            "europe-north1",
+            "europe-southwest1",
+            "europe-west1",
+            "europe-west2",
+            "europe-west3",
+            "europe-west4",
+            "europe-west6",
+            "europe-west8",
+            "europe-west9",
+        ]
+
+    def get_flagged_finish_reasons(self) -> Dict[str, str]:
+        """
+        Return Dictionary of finish reasons which indicate response was flagged
+
+        and what it means
+        """
+        return {
+            "SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.",
+            "RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
+            "BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.",
+            "PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.",
+            "SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.",
+        }
+
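Aside (illustrative, not part of this diff): a small sketch of how the config above maps OpenAI-style params onto Gemini keys; the model name and values are placeholders, and VertexGeminiConfig is assumed to be importable from the module added here.

# assumes VertexGeminiConfig (defined above) is in scope
config = VertexGeminiConfig()
optional_params: dict = {}
config.map_openai_params(
    model="gemini-1.5-pro",  # illustrative
    non_default_params={"max_tokens": 256, "stop": ["\n\n"], "tool_choice": "auto"},
    optional_params=optional_params,
)
# optional_params now carries Gemini-style keys:
#   max_output_tokens=256, stop_sequences=["\n\n"],
#   and a ToolConfig whose functionCallingConfig mode is "AUTO"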
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = AsyncHTTPHandler()  # Create a new client if none provided
+
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise VertexAIError(status_code=response.status_code, message=response.text)
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.aiter_bytes(), sync_stream=False
+    )
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise VertexAIError(status_code=response.status_code, message=response.read())
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
@@ -33,16 +342,159 @@ class VertexLLM(BaseLLM):
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

-    def load_auth(self) -> Tuple[Any, str]:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
-        import google.auth as google_auth
-
-        credentials, project_id = google_auth.default(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"],
-        )
-
-        credentials.refresh(Request())
+    def _process_response(
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
+        except Exception as e:
+            raise VertexAIError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
+
+        ## CHECK IF RESPONSE FLAGGED
+        if len(completion_response["candidates"]) > 0:
+            content_policy_violations = (
+                VertexGeminiConfig().get_flagged_finish_reasons()
+            )
+            if (
+                "finishReason" in completion_response["candidates"][0]
+                and completion_response["candidates"][0]["finishReason"]
+                in content_policy_violations.keys()
+            ):
+                ## CONTENT POLICY VIOLATION ERROR
+                raise VertexAIError(
+                    status_code=400,
+                    message="The response was blocked. Reason={}. Raw Response={}".format(
+                        content_policy_violations[
+                            completion_response["candidates"][0]["finishReason"]
+                        ],
+                        completion_response,
+                    ),
+                )
+
+        model_response.choices = []  # type: ignore
+
+        ## GET MODEL ##
+        model_response.model = model
+
+        try:
+            ## GET TEXT ##
+            chat_completion_message: ChatCompletionResponseMessage = {
+                "role": "assistant"
+            }
+            content_str = ""
+            tools: List[ChatCompletionToolCallChunk] = []
+            for idx, candidate in enumerate(completion_response["candidates"]):
+                if "content" not in candidate:
+                    continue
+
+                if "text" in candidate["content"]["parts"][0]:
+                    content_str = candidate["content"]["parts"][0]["text"]
+
+                if "functionCall" in candidate["content"]["parts"][0]:
+                    _function_chunk = ChatCompletionToolCallFunctionChunk(
+                        name=candidate["content"]["parts"][0]["functionCall"]["name"],
+                        arguments=json.dumps(
+                            candidate["content"]["parts"][0]["functionCall"]["args"]
+                        ),
+                    )
+                    _tool_response_chunk = ChatCompletionToolCallChunk(
+                        id=f"call_{str(uuid.uuid4())}",
+                        type="function",
+                        function=_function_chunk,
+                    )
+                    tools.append(_tool_response_chunk)
+
+                chat_completion_message["content"] = content_str
+                chat_completion_message["tool_calls"] = tools
+
+                choice = litellm.Choices(
+                    finish_reason=candidate.get("finishReason", "stop"),
+                    index=candidate.get("index", idx),
+                    message=chat_completion_message,  # type: ignore
+                    logprobs=None,
+                    enhancements=None,
+                )
+
+                model_response.choices.append(choice)
+
+            ## GET USAGE ##
+            usage = litellm.Usage(
+                prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
+                completion_tokens=completion_response["usageMetadata"][
+                    "candidatesTokenCount"
+                ],
+                total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
+            )
+
+            setattr(model_response, "usage", usage)
+        except Exception as e:
+            raise VertexAIError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    completion_response, str(e)
+                ),
+                status_code=422,
+            )
+
+        return model_response
+
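Aside (illustrative, not part of this diff): the Gemini response shape _process_response expects, reduced to the keys it actually reads; values are made up.

example_generate_content_response = {
    "candidates": [
        {
            "content": {
                "role": "model",
                "parts": [
                    {"text": "It is sunny."},
                    # or, for tool calls:
                    # {"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}},
                ],
            },
            "finishReason": "STOP",
            "index": 0,
        }
    ],
    "usageMetadata": {
        "promptTokenCount": 12,
        "candidatesTokenCount": 5,
        "totalTokenCount": 17,
    },
}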
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
+        import google.auth as google_auth
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
+
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
+
+            json_obj = json.loads(credentials)
+
+            creds = google.oauth2.service_account.Credentials.from_service_account_info(
+                json_obj,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+
+        creds.refresh(Request())
+
        if not project_id:
            raise ValueError("Could not resolve project_id")
@@ -52,38 +504,364 @@ class VertexLLM(BaseLLM):
                f"Expected project_id to be a str but got {type(project_id)}"
            )

-        return credentials, project_id
+        return creds, project_id

    def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )

        credentials.refresh(Request())

-    def _prepare_request(self, request: httpx.Request) -> None:
-        access_token = self._ensure_access_token()
-
-        if request.headers.get("Authorization"):
-            # already authenticated, nothing for us to do
-            return
-
-        request.headers["Authorization"] = f"Bearer {access_token}"
-
-    def _ensure_access_token(self) -> str:
-        if self.access_token is not None:
-            return self.access_token
-
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None and self.project_id is not None:
+            return self.access_token, self.project_id
+
        if not self._credentials:
-            self._credentials, project_id = self.load_auth()
+            self._credentials, project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
            if not self.project_id:
                self.project_id = project_id
        else:
            self.refresh_auth(self._credentials)

-        if not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        assert isinstance(self._credentials.token, str)
-        return self._credentials.token
+            if not self.project_id:
+                self.project_id = self._credentials.project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        return self._credentials.token, self.project_id
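Aside (illustrative, not part of this diff): what the new credentials/project_id plumbing allows at the call site — a service-account key passed as a JSON string instead of relying on google.auth.default(); the path, project, region and model are placeholders.

import json
import litellm

with open("/path/to/service_account.json") as f:  # placeholder path
    vertex_credentials = json.dumps(json.load(f))

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-flash",  # routed through this httpx client
    messages=[{"role": "user", "content": "hi"}],
    vertex_credentials=vertex_credentials,
    vertex_project="my-gcp-project",   # optional; otherwise read from the key file
    vertex_location="europe-west4",
)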
+    def _get_token_and_url(
+        self,
+        model: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            _gemini_model_name = "models/{}".format(model)
+            auth_header = None
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+        else:
+            auth_header, vertex_project = self._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+            ### SET RUNTIME ENDPOINT ###
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+        return auth_header, url
+
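Aside (illustrative, not part of this diff): the two URL shapes produced by _get_token_and_url above, with placeholder model, key, project and region values.

# Google AI Studio (custom_llm_provider == "gemini"), streaming endpoint:
google_ai_studio_url = (
    "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash"
    ":streamGenerateContent?key=<GEMINI_API_KEY>"
)
# Vertex AI, non-streaming endpoint:
vertex_ai_url = (
    "https://europe-west4-aiplatform.googleapis.com/v1/projects/my-gcp-project"
    "/locations/europe-west4/publishers/google/models/gemini-1.5-flash:generateContent"
)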
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        streaming_response = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="vertex_ai_beta",
+            logging_obj=logging_obj,
+        )
+        return streaming_response
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        return self._process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
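Aside (illustrative, not part of this diff): the wrapper returned by async_streaming defers the HTTP request to make_call until iteration, so callers consume it like any other litellm stream; the model name is a placeholder.

import asyncio
import litellm

async def main():
    stream = await litellm.acompletion(
        model="gemini/gemini-1.5-flash",  # placeholder
        messages=[{"role": "user", "content": "write a haiku"}],
        stream=True,
    )
    async for chunk in stream:  # make_call fires lazily on first iteration
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())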
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        acompletion: bool,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        gemini_api_key: Optional[str],
+        litellm_params=None,
+        logger_fn=None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        ## TRANSFORMATION ##
+        try:
+            supports_system_message = litellm.supports_system_messages(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    str(e)
+                )
+            )
+            supports_system_message = False
+        # Separate system prompt from rest of message
+        system_prompt_indices = []
+        system_content_blocks: List[PartType] = []
+        if supports_system_message is True:
+            for idx, message in enumerate(messages):
+                if message["role"] == "system":
+                    _system_content_block = PartType(text=message["content"])
+                    system_content_blocks.append(_system_content_block)
+                    system_prompt_indices.append(idx)
+            if len(system_prompt_indices) > 0:
+                for idx in reversed(system_prompt_indices):
+                    messages.pop(idx)
+        content = _gemini_convert_messages_with_history(messages=messages)
+        tools: Optional[Tools] = optional_params.pop("tools", None)
+        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+            "safety_settings", None
+        )  # type: ignore
+        generation_config: Optional[GenerationConfig] = GenerationConfig(
+            **optional_params
+        )
+        data = RequestBody(contents=content)
+        if len(system_content_blocks) > 0:
+            system_instructions = SystemInstructions(parts=system_content_blocks)
+            data["system_instruction"] = system_instructions
+        if tools is not None:
+            data["tools"] = tools
+        if tool_choice is not None:
+            data["toolConfig"] = tool_choice
+        if safety_settings is not None:
+            data["safetySettings"] = safety_settings
+        if generation_config is not None:
+            data["generationConfig"] = generation_config
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+        }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,  # type: ignore
+                api_base=url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=stream,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=headers,
+                timeout=timeout,
+                client=client,  # type: ignore
+            )
+
+        ## SYNC STREAMING CALL ##
+        if stream is not None and stream is True:
+            streaming_response = CustomStreamWrapper(
+                completion_stream=None,
+                make_call=partial(
+                    make_sync_call,
+                    client=None,
+                    api_base=url,
+                    headers=headers,  # type: ignore
+                    data=json.dumps(data),
+                    model=model,
+                    messages=messages,
+                    logging_obj=logging_obj,
+                ),
+                model=model,
+                custom_llm_provider="vertex_ai_beta",
+                logging_obj=logging_obj,
+            )
+
+            return streaming_response
+        ## COMPLETION CALL ##
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
+
+        try:
+            response = client.post(url=url, headers=headers, json=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        return self._process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,  # type: ignore
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
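Aside (illustrative, not part of this diff): a sketch of the JSON body the transformation above assembles for a simple tool call; keys mirror the RequestBody fields set in completion(), values are made up.

example_request_body = {
    "system_instruction": {"parts": [{"text": "You are concise."}]},
    "contents": [
        {"role": "user", "parts": [{"text": "What's the weather in Paris?"}]}
    ],
    "tools": [
        {
            "function_declarations": [
                {
                    "name": "get_weather",
                    "description": "Look up current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                }
            ]
        }
    ],
    "toolConfig": {"functionCallingConfig": {"mode": "AUTO"}},
    "generationConfig": {"temperature": 0.2, "max_output_tokens": 256},
}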
    def image_generation(
        self,
@@ -163,7 +941,7 @@ class VertexLLM(BaseLLM):
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
-        auth_header = self._ensure_access_token()
+        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params
@@ -222,3 +1000,110 @@ class VertexLLM(BaseLLM):
        model_response.data = _response_data

        return model_response

+class ModelResponseIterator:
+    def __init__(self, streaming_response, sync_stream: bool):
+        self.streaming_response = streaming_response
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            processed_chunk = GenerateContentResponseBody(**chunk)  # type: ignore
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+
+            gemini_chunk = processed_chunk["candidates"][0]
+
+            if (
+                "content" in gemini_chunk
+                and "text" in gemini_chunk["content"]["parts"][0]
+            ):
+                text = gemini_chunk["content"]["parts"][0]["text"]
+
+            if "finishReason" in gemini_chunk:
+                finish_reason = map_finish_reason(
+                    finish_reason=gemini_chunk["finishReason"]
+                )
+                is_finished = True
+
+            if "usageMetadata" in processed_chunk:
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
+                    completion_tokens=processed_chunk["usageMetadata"][
+                        "candidatesTokenCount"
+                    ],
+                    total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
+                )
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=0,
+            )
+            return returned_chunk
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e}")
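Aside (illustrative, not part of this diff): a standalone sketch of the ijson push-parsing pattern ModelResponseIterator relies on — streamGenerateContent returns one JSON array delivered in arbitrary byte chunks, and each completed array element lands in the sendable_list.

import ijson

events = ijson.sendable_list()
coro = ijson.items_coro(events, "item")  # emit each element of the top-level array

for chunk in (b'[{"a": 1},', b'{"a": 2}', b"]"):  # bytes as they might arrive
    coro.send(chunk)
    while events:
        print(events.pop(0))  # -> {'a': 1}, then {'a': 2}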