Merge branch 'main' into litellm_aws_kms_fixes

This commit is contained in:
Krish Dholakia 2024-06-19 09:30:54 -07:00 committed by GitHub
commit 3a3b3667ee
213 changed files with 221464 additions and 13619 deletions


@ -65,6 +65,7 @@ jobs:
pip install "pydantic==2.7.1" pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1" pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0" pip install "Pillow==10.3.0"
pip install "ijson==3.2.3"
- save_cache: - save_cache:
paths: paths:
- ./venv - ./venv
@ -126,6 +127,7 @@ jobs:
pip install jinja2 pip install jinja2
pip install tokenizers pip install tokenizers
pip install openai pip install openai
pip install ijson
- run: - run:
name: Run tests name: Run tests
command: | command: |
@ -180,6 +182,7 @@ jobs:
pip install numpydoc pip install numpydoc
pip install prisma pip install prisma
pip install fastapi pip install fastapi
pip install ijson
pip install "httpx==0.24.1" pip install "httpx==0.24.1"
pip install "gunicorn==21.2.0" pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1" pip install "anyio==3.7.1"
@ -202,6 +205,7 @@ jobs:
-e REDIS_PORT=$REDIS_PORT \ -e REDIS_PORT=$REDIS_PORT \
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \ -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \ -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
-e MISTRAL_API_KEY=$MISTRAL_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AWS_REGION_NAME=$AWS_REGION_NAME \

.github/dependabot.yaml (new file)

@ -0,0 +1,10 @@
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "daily"
    groups:
      github-actions:
        patterns:
          - "*"


@ -25,6 +25,11 @@ jobs:
if: github.repository == 'BerriAI/litellm' if: github.repository == 'BerriAI/litellm'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
-
name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.commit_hash }}
- -
name: Set up QEMU name: Set up QEMU
uses: docker/setup-qemu-action@v3 uses: docker/setup-qemu-action@v3
@ -41,12 +46,14 @@ jobs:
name: Build and push name: Build and push
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: .
push: true push: true
tags: litellm/litellm:${{ github.event.inputs.tag || 'latest' }} tags: litellm/litellm:${{ github.event.inputs.tag || 'latest' }}
- -
name: Build and push litellm-database image name: Build and push litellm-database image
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: .
push: true push: true
file: Dockerfile.database file: Dockerfile.database
tags: litellm/litellm-database:${{ github.event.inputs.tag || 'latest' }} tags: litellm/litellm-database:${{ github.event.inputs.tag || 'latest' }}
@ -54,6 +61,7 @@ jobs:
name: Build and push litellm-spend-logs image name: Build and push litellm-spend-logs image
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: .
push: true push: true
file: ./litellm-js/spend-logs/Dockerfile file: ./litellm-js/spend-logs/Dockerfile
tags: litellm/litellm-spend_logs:${{ github.event.inputs.tag || 'latest' }} tags: litellm/litellm-spend_logs:${{ github.event.inputs.tag || 'latest' }}
@ -68,6 +76,8 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.commit_hash }}
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
@ -92,7 +102,7 @@ jobs:
- name: Build and push Docker image - name: Build and push Docker image
uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8 uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8
with: with:
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}} context: .
push: true push: true
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest' tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
@ -106,6 +116,8 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.commit_hash }}
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
@ -128,7 +140,7 @@ jobs:
- name: Build and push Database Docker image - name: Build and push Database Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with: with:
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}} context: .
file: Dockerfile.database file: Dockerfile.database
push: true push: true
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }} tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
@ -143,6 +155,8 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.commit_hash }}
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
@ -165,7 +179,7 @@ jobs:
- name: Build and push Database Docker image - name: Build and push Database Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with: with:
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}} context: .
file: ./litellm-js/spend-logs/Dockerfile file: ./litellm-js/spend-logs/Dockerfile
push: true push: true
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }} tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
@ -176,6 +190,8 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.commit_hash }}
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1


@ -1,4 +1,19 @@
repos: repos:
- repo: local
  hooks:
    - id: mypy
      name: mypy
      entry: python3 -m mypy --ignore-missing-imports
      language: system
      types: [python]
      files: ^litellm/
    - id: isort
      name: isort
      entry: isort
      language: system
      types: [python]
      files: litellm/.*\.py
      exclude: ^litellm/__init__.py$
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 24.2.0 rev: 24.2.0
hooks: hooks:
@ -16,11 +31,10 @@ repos:
name: Check if files match name: Check if files match
entry: python3 ci_cd/check_files_match.py entry: python3 ci_cd/check_files_match.py
language: system language: system
- repo: local # - id: check-file-length
hooks: # name: Check file length
- id: mypy # entry: python check_file_length.py
name: mypy # args: ["10000"] # set your desired maximum number of lines
entry: python3 -m mypy --ignore-missing-imports # language: python
language: system # files: litellm/.*\.py
types: [python] # exclude: ^litellm/tests/
files: ^litellm/

check_file_length.py (new file)

@ -0,0 +1,28 @@
import sys


def check_file_length(max_lines, filenames):
    bad_files = []
    for filename in filenames:
        with open(filename, "r") as file:
            lines = file.readlines()
            if len(lines) > max_lines:
                bad_files.append((filename, len(lines)))
    return bad_files


if __name__ == "__main__":
    max_lines = int(sys.argv[1])
    filenames = sys.argv[2:]
    bad_files = check_file_length(max_lines, filenames)
    if bad_files:
        bad_files.sort(
            key=lambda x: x[1], reverse=True
        )  # Sort files by length in descending order
        for filename, length in bad_files:
            print(f"{filename}: {length} lines")
        sys.exit(1)
    else:
        sys.exit(0)
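For reference, the commented-out pre-commit hook earlier in this diff invokes this script as `python check_file_length.py 10000 <files>`. A minimal programmatic sketch of the same check (the file paths below are placeholders):

```python
# Equivalent to the CLI invocation used by the pre-commit hook; paths are placeholders.
from check_file_length import check_file_length

too_long = check_file_length(10000, ["litellm/main.py", "litellm/utils.py"])
for filename, length in too_long:
    print(f"{filename}: {length} lines")
```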


@ -162,7 +162,7 @@ def completion(
- `function`: *object* - Required. - `function`: *object* - Required.
- `tool_choice`: *string or object (optional)* - Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function. - `tool_choice`: *string or object (optional)* - Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that function.
- `none` is the default when no functions are present. `auto` is the default if functions are present. - `none` is the default when no functions are present. `auto` is the default if functions are present.
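As an illustration, a minimal sketch of forcing a specific function call via `tool_choice` (the model name and tool definition below are placeholders, not part of the original docs):

```python
from litellm import completion

# hypothetical tool definition - "my_function" mirrors the name used in the example above
tools = [
    {
        "type": "function",
        "function": {
            "name": "my_function",
            "description": "Illustrative tool used to demonstrate forced tool choice",
            "parameters": {"type": "object", "properties": {}},
        },
    }
]

response = completion(
    model="gpt-3.5-turbo",  # any tool-calling capable model
    messages=[{"role": "user", "content": "Please call my_function"}],
    tools=tools,
    tool_choice={"type": "function", "function": {"name": "my_function"}},  # force this specific tool
)
```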


@ -1,90 +0,0 @@
import Image from '@theme/IdealImage';
import QueryParamReader from '../../src/components/queryParamReader.js'
# [Beta] Monitor Logs in Production
:::note
This is in beta. Expect frequent updates, as we improve based on your feedback.
:::
LiteLLM provides an integration to let you monitor logs in production.
👉 Jump to our sample LiteLLM Dashboard: https://admin.litellm.ai/
<Image img={require('../../img/alt_dashboard.png')} alt="Dashboard" />
## Debug your first logs
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_OpenAI.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
### 1. Get your LiteLLM Token
Go to [admin.litellm.ai](https://admin.litellm.ai/) and copy the code snippet with your unique token
<Image img={require('../../img/hosted_debugger_usage_page.png')} alt="Usage" />
### 2. Set up your environment
**Add it to your .env**
```python
import os
os.env["LITELLM_TOKEN"] = "e24c4c06-d027-4c30-9e78-18bc3a50aebb" # replace with your unique token
```
**Turn on LiteLLM Client**
```python
import litellm
litellm.client = True
```
### 3. Make a normal `completion()` call
```python
import litellm
from litellm import completion
import os
# set env variables
os.environ["LITELLM_TOKEN"] = "e24c4c06-d027-4c30-9e78-18bc3a50aebb" # replace with your unique token
os.environ["OPENAI_API_KEY"] = "openai key"
litellm.use_client = True # enable logging dashboard
messages = [{ "content": "Hello, how are you?","role": "user"}]
# openai call
response = completion(model="gpt-3.5-turbo", messages=messages)
```
Your `completion()` call will print a link to your session dashboard (https://admin.litellm.ai/<your_unique_token>)
In the above case it would be: [`admin.litellm.ai/e24c4c06-d027-4c30-9e78-18bc3a50aebb`](https://admin.litellm.ai/e24c4c06-d027-4c30-9e78-18bc3a50aebb)
Click on your personal dashboard link. Here's how you can find it 👇
<Image img={require('../../img/dash_output.png')} alt="Dashboard" />
[👋 Tell us if you need better privacy controls](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version?month=2023-08)
### 3. Review request log
Oh! Looks like our request was made successfully. Let's click on it and see exactly what got sent to the LLM provider.
Ah! So we can see that this request was made to a **Baseten** (see litellm_params > custom_llm_provider) for a model with ID - **7qQNLDB** (see model). The message sent was - `"Hey, how's it going?"` and the response received was - `"As an AI language model, I don't have feelings or emotions, but I can assist you with your queries. How can I assist you today?"`
<Image img={require('../../img/dashboard_log.png')} alt="Dashboard Log Row" />
:::info
🎉 Congratulations! You've successfully debugged your first log!
:::


@ -2,6 +2,15 @@ import Image from '@theme/IdealImage';
# Athina # Athina
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Athina](https://athina.ai/) is an evaluation framework and production monitoring platform for your LLM-powered app. Athina is designed to enhance the performance and reliability of AI applications through real-time monitoring, granular analytics, and plug-and-play evaluations. [Athina](https://athina.ai/) is an evaluation framework and production monitoring platform for your LLM-powered app. Athina is designed to enhance the performance and reliability of AI applications through real-time monitoring, granular analytics, and plug-and-play evaluations.
<Image img={require('../../img/athina_dashboard.png')} /> <Image img={require('../../img/athina_dashboard.png')} />
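As a rough sketch, this integration follows litellm's usual callback pattern (seen with the langfuse and logfire examples elsewhere in this PR); the callback name `"athina"` and the `ATHINA_API_KEY` env var are assumptions here:

```python
import os
import litellm
from litellm import completion

os.environ["ATHINA_API_KEY"] = "your-athina-api-key"  # assumed env var name
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

litellm.success_callback = ["athina"]  # log successful calls to Athina

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```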


@ -1,5 +1,14 @@
# Greenscale - Track LLM Spend and Responsible Usage # Greenscale - Track LLM Spend and Responsible Usage
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII). [Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).
## Getting Started ## Getting Started


@ -1,4 +1,13 @@
# Helicone Tutorial # Helicone Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Helicone](https://helicone.ai/) is an open source observability platform that proxies your OpenAI traffic and provides you key insights into your spend, latency and usage. [Helicone](https://helicone.ai/) is an open source observability platform that proxies your OpenAI traffic and provides you key insights into your spend, latency and usage.
## Use Helicone to log requests across all LLM Providers (OpenAI, Azure, Anthropic, Cohere, Replicate, PaLM) ## Use Helicone to log requests across all LLM Providers (OpenAI, Azure, Anthropic, Cohere, Replicate, PaLM)
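A minimal sketch of the callback-based setup, assuming the callback name `"helicone"` and the `HELICONE_API_KEY` env var follow litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["HELICONE_API_KEY"] = "your-helicone-api-key"  # assumed env var name
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

litellm.success_callback = ["helicone"]  # log successful calls to Helicone

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```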


@ -1,6 +1,6 @@
import Image from '@theme/IdealImage'; import Image from '@theme/IdealImage';
# Langfuse - Logging LLM Input/Output # 🔥 Langfuse - Logging LLM Input/Output
LangFuse is open Source Observability & Analytics for LLM Apps LangFuse is open Source Observability & Analytics for LLM Apps
Detailed production traces and a granular view on quality, cost and latency Detailed production traces and a granular view on quality, cost and latency
@ -122,10 +122,12 @@ response = completion(
metadata={ metadata={
"generation_name": "ishaan-test-generation", # set langfuse Generation Name "generation_name": "ishaan-test-generation", # set langfuse Generation Name
"generation_id": "gen-id22", # set langfuse Generation ID "generation_id": "gen-id22", # set langfuse Generation ID
"parent_observation_id": "obs-id9" # set langfuse Parent Observation ID
"version": "test-generation-version" # set langfuse Generation Version "version": "test-generation-version" # set langfuse Generation Version
"trace_user_id": "user-id2", # set langfuse Trace User ID "trace_user_id": "user-id2", # set langfuse Trace User ID
"session_id": "session-1", # set langfuse Session ID "session_id": "session-1", # set langfuse Session ID
"tags": ["tag1", "tag2"], # set langfuse Tags "tags": ["tag1", "tag2"], # set langfuse Tags
"trace_name": "new-trace-name" # set langfuse Trace Name
"trace_id": "trace-id22", # set langfuse Trace ID "trace_id": "trace-id22", # set langfuse Trace ID
"trace_metadata": {"key": "value"}, # set langfuse Trace Metadata "trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
"trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version) "trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)
@ -147,9 +149,10 @@ print(response)
You can also pass `metadata` as part of the request header with a `langfuse_*` prefix: You can also pass `metadata` as part of the request header with a `langfuse_*` prefix:
```shell ```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \ curl --location --request POST 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \ --header 'Content-Type: application/json' \
--header 'langfuse_trace_id: trace-id22' \ --header 'Authorization: Bearer sk-1234' \
--header 'langfuse_trace_id: trace-id2' \
--header 'langfuse_trace_user_id: user-id2' \ --header 'langfuse_trace_user_id: user-id2' \
--header 'langfuse_trace_metadata: {"key":"value"}' \ --header 'langfuse_trace_metadata: {"key":"value"}' \
--data '{ --data '{
@ -192,7 +195,8 @@ The following parameters can be updated on a continuation of a trace by passing
* `generation_id` - Identifier for the generation, auto-generated by default * `generation_id` - Identifier for the generation, auto-generated by default
* `generation_name` - Identifier for the generation, auto-generated by default * `generation_name` - Identifier for the generation, auto-generated by default
* `prompt` - Langfuse prompt object used for the generation, defaults to None * `parent_observation_id` - Identifier for the parent observation, defaults to `None`
* `prompt` - Langfuse prompt object used for the generation, defaults to `None`
Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation. Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation.
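For example, a sketch passing one custom key alongside the spec'd fields (assumes the langfuse callback is configured as shown earlier on this page; the custom key name is arbitrary):

```python
from litellm import completion

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
    metadata={
        "generation_name": "test-generation",  # spec'd field -> Langfuse Generation Name
        "my_custom_field": "my_custom_value",  # not in the spec -> stored as generation metadata
    },
)
```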


@ -1,6 +1,16 @@
import Image from '@theme/IdealImage'; import Image from '@theme/IdealImage';
# Langsmith - Logging LLM Input/Output # Langsmith - Logging LLM Input/Output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
An all-in-one developer platform for every step of the application lifecycle An all-in-one developer platform for every step of the application lifecycle
https://smith.langchain.com/ https://smith.langchain.com/
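A minimal sketch of the callback-based setup, assuming the callback name `"langsmith"` and the `LANGSMITH_API_KEY` env var follow litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["LANGSMITH_API_KEY"] = "your-langsmith-api-key"  # assumed env var name
# assumes OPENAI_API_KEY (or another provider key) is already set

litellm.success_callback = ["langsmith"]  # log successful calls to Langsmith

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```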


@ -1,6 +1,6 @@
import Image from '@theme/IdealImage'; import Image from '@theme/IdealImage';
# Logfire - Logging LLM Input/Output # 🔥 Logfire - Logging LLM Input/Output
Logfire is open Source Observability & Analytics for LLM Apps Logfire is open Source Observability & Analytics for LLM Apps
Detailed production traces and a granular view on quality, cost and latency Detailed production traces and a granular view on quality, cost and latency
@ -14,10 +14,14 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites ## Pre-Requisites
Ensure you have run `pip install logfire` for this integration Ensure you have installed the following packages to use this integration
```shell ```shell
pip install logfire litellm pip install litellm
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
``` ```
## Quick Start ## Quick Start
@ -25,8 +29,7 @@ pip install logfire litellm
Get your Logfire token from [Logfire](https://logfire.pydantic.dev/) Get your Logfire token from [Logfire](https://logfire.pydantic.dev/)
```python ```python
litellm.success_callback = ["logfire"] litellm.callbacks = ["logfire"]
litellm.failure_callback = ["logfire"] # logs errors to logfire
``` ```
```python ```python


@ -1,5 +1,13 @@
# Lunary - Logging and tracing LLM input/output # Lunary - Logging and tracing LLM input/output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Lunary](https://lunary.ai/) is an open-source AI developer platform providing observability, prompt management, and evaluation tools for AI developers. [Lunary](https://lunary.ai/) is an open-source AI developer platform providing observability, prompt management, and evaluation tools for AI developers.
<video controls width='900' > <video controls width='900' >


@ -1,5 +1,16 @@
import Image from '@theme/IdealImage';
# Promptlayer Tutorial # Promptlayer Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
Promptlayer is a platform for prompt engineers. Log OpenAI requests. Search usage history. Track performance. Visually manage prompt templates. Promptlayer is a platform for prompt engineers. Log OpenAI requests. Search usage history. Track performance. Visually manage prompt templates.
<Image img={require('../../img/promptlayer.png')} /> <Image img={require('../../img/promptlayer.png')} />
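A minimal sketch of the callback-based setup, assuming the callback name `"promptlayer"` and the `PROMPTLAYER_API_KEY` env var follow litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["PROMPTLAYER_API_KEY"] = "your-promptlayer-api-key"  # assumed env var name
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

litellm.success_callback = ["promptlayer"]  # log requests to Promptlayer

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```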


@ -0,0 +1,46 @@
import Image from '@theme/IdealImage';
# Raw Request/Response Logging
See the raw request/response sent by LiteLLM in your logging provider (OTEL/Langfuse/etc.).
**on SDK**
```python
# pip install langfuse
import litellm
import os
# log raw request/response
litellm.log_raw_request_response = True
# from https://cloud.langfuse.com/
os.environ["LANGFUSE_PUBLIC_KEY"] = ""
os.environ["LANGFUSE_SECRET_KEY"] = ""
# Optional, defaults to https://cloud.langfuse.com
os.environ["LANGFUSE_HOST"] # optional
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse
litellm.success_callback = ["langfuse"]
# openai call
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
]
)
```
**on Proxy**
```yaml
litellm_settings:
log_raw_request_response: True
```
**Expected Log**
<Image img={require('../../img/raw_request_log.png')}/>


@ -1,5 +1,14 @@
import Image from '@theme/IdealImage'; import Image from '@theme/IdealImage';
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
# Sentry - Log LLM Exceptions # Sentry - Log LLM Exceptions
[Sentry](https://sentry.io/) provides error monitoring for production. LiteLLM can add breadcrumbs and send exceptions to Sentry with this integration [Sentry](https://sentry.io/) provides error monitoring for production. LiteLLM can add breadcrumbs and send exceptions to Sentry with this integration
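As a sketch, exceptions are reported through litellm's failure callback; the callback name `"sentry"` and the `SENTRY_DSN` env var are assumptions following litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["SENTRY_DSN"] = "your-sentry-dsn"  # assumed env var name
litellm.failure_callback = ["sentry"]  # send exceptions and breadcrumbs to Sentry

# exceptions raised by this call would be reported to Sentry
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```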


@ -1,4 +1,12 @@
# Supabase Tutorial # Supabase Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Supabase](https://supabase.com/) is an open source Firebase alternative. [Supabase](https://supabase.com/) is an open source Firebase alternative.
Start your project with a Postgres database, Authentication, instant APIs, Edge Functions, Realtime subscriptions, Storage, and Vector embeddings. Start your project with a Postgres database, Authentication, instant APIs, Edge Functions, Realtime subscriptions, Storage, and Vector embeddings.
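A minimal sketch of the callback-based setup, assuming the callback name `"supabase"` and the `SUPABASE_URL` / `SUPABASE_KEY` env vars follow litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["SUPABASE_URL"] = "your-supabase-url"  # assumed env var names
os.environ["SUPABASE_KEY"] = "your-supabase-service-key"

litellm.success_callback = ["supabase"]  # log calls to your Supabase table

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```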


@ -1,6 +1,16 @@
import Image from '@theme/IdealImage'; import Image from '@theme/IdealImage';
# Weights & Biases - Logging LLM Input/Output # Weights & Biases - Logging LLM Input/Output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
Weights & Biases helps AI developers build better models faster https://wandb.ai Weights & Biases helps AI developers build better models faster https://wandb.ai
<Image img={require('../../img/wandb.png')} /> <Image img={require('../../img/wandb.png')} />
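A minimal sketch of the callback-based setup, assuming the callback name `"wandb"` and the `WANDB_API_KEY` env var follow litellm's usual pattern:

```python
import os
import litellm
from litellm import completion

os.environ["WANDB_API_KEY"] = "your-wandb-api-key"  # assumed env var name
litellm.success_callback = ["wandb"]  # stream LLM calls to Weights & Biases

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
```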


@ -11,7 +11,7 @@ LiteLLM supports
:::info :::info
Anthropic API fails requests when `max_tokens` are not passed. Due to this litellm passes `max_tokens=4096` when no `max_tokens` are passed Anthropic API fails requests when `max_tokens` are not passed. Due to this litellm passes `max_tokens=4096` when no `max_tokens` are passed.
::: :::
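To control this yourself, pass `max_tokens` explicitly. A minimal sketch (model name reused from elsewhere on this page):

```python
from litellm import completion

# pass max_tokens explicitly to override the 4096 default litellm applies for Anthropic
response = completion(
    model="anthropic/claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=1024,
)
```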
@ -229,17 +229,6 @@ assert isinstance(
``` ```
### Setting `anthropic-beta` Header in Requests
Pass the `extra_headers` param to litellm. All headers will be forwarded to the Anthropic API
```python
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
)
```
### Forcing Anthropic Tool Use ### Forcing Anthropic Tool Use


@ -68,6 +68,7 @@ response = litellm.completion(
| Model Name | Function Call | | Model Name | Function Call |
|------------------|----------------------------------------| |------------------|----------------------------------------|
| gpt-4o | `completion('azure/<your deployment name>', messages)` |
| gpt-4 | `completion('azure/<your deployment name>', messages)` | | gpt-4 | `completion('azure/<your deployment name>', messages)` |
| gpt-4-0314 | `completion('azure/<your deployment name>', messages)` | | gpt-4-0314 | `completion('azure/<your deployment name>', messages)` |
| gpt-4-0613 | `completion('azure/<your deployment name>', messages)` | | gpt-4-0613 | `completion('azure/<your deployment name>', messages)` |
@ -85,7 +86,8 @@ response = litellm.completion(
## Azure OpenAI Vision Models ## Azure OpenAI Vision Models
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4-vision | `response = completion(model="azure/<your deployment name>", messages=messages)` | | gpt-4-vision | `completion(model="azure/<your deployment name>", messages=messages)` |
| gpt-4o | `completion('azure/<your deployment name>', messages)` |
#### Usage #### Usage
```python ```python


@ -3,49 +3,151 @@ import TabItem from '@theme/TabItem';
# Azure AI Studio # Azure AI Studio
**Ensure the following:** LiteLLM supports all models on Azure AI Studio
1. The API Base passed ends in the `/v1/` prefix
example:
```python
api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
```
2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** Need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
## Usage ## Usage
<Tabs> <Tabs>
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
### ENV VAR
```python ```python
import litellm import os
response = litellm.completion( os.environ["AZURE_AI_API_KEY"] = ""
model="azure/command-r-plus", os.environ["AZURE_AI_API_BASE"] = ""
api_base="<your-deployment-base>/v1/" ```
api_key="eskk******"
messages=[{"role": "user", "content": "What is the meaning of life?"}], ### Example Call
```python
from litellm import completion
import os
## set ENV variables
os.environ["AZURE_API_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
# azure ai command-r-plus call
response = completion(
model="azure_ai/command-r-plus",
messages = [{ "content": "Hello, how are you?","role": "user"}]
) )
``` ```
</TabItem> </TabItem>
<TabItem value="proxy" label="PROXY"> <TabItem value="proxy" label="PROXY">
## Sample Usage - LiteLLM Proxy
1. Add models to your config.yaml 1. Add models to your config.yaml
```yaml ```yaml
model_list: model_list:
- model_name: mistral
litellm_params:
model: azure/mistral-large-latest
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
api_key: JGbKodRcTp****
- model_name: command-r-plus - model_name: command-r-plus
litellm_params: litellm_params:
model: azure/command-r-plus model: azure_ai/command-r-plus
api_key: os.environ/AZURE_COHERE_API_KEY api_key: os.environ/AZURE_AI_API_KEY
api_base: os.environ/AZURE_COHERE_API_BASE api_base: os.environ/AZURE_AI_API_BASE
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="command-r-plus",
messages = [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "command-r-plus",
"messages": [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["AZURE_AI_API_KEY"] = "azure ai api key"
os.environ["AZURE_AI_API_BASE"] = "azure ai api base"
# command r plus call
response = completion(
model="azure_ai/command-r-plus",
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5
)
```
**proxy**
```yaml
model_list:
- model_name: command-r-plus
litellm_params:
model: azure_ai/command-r-plus
api_key: os.environ/AZURE_AI_API_KEY
api_base: os.environ/AZURE_AI_API_BASE
max_tokens: 20
temperature: 0.5
``` ```
@ -103,9 +205,6 @@ response = litellm.completion(
</Tabs> </Tabs>
</TabItem>
</Tabs>
## Function Calling ## Function Calling
<Tabs> <Tabs>
@ -115,8 +214,8 @@ response = litellm.completion(
from litellm import completion from litellm import completion
# set env # set env
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key" os.environ["AZURE_AI_API_KEY"] = "your-api-key"
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base" os.environ["AZURE_AI_API_BASE"] = "your-api-base"
tools = [ tools = [
{ {
@ -141,9 +240,7 @@ tools = [
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion( response = completion(
model="azure/mistral-large-latest", model="azure_ai/mistral-large-latest",
api_base=os.getenv("AZURE_MISTRAL_API_BASE")
api_key=os.getenv("AZURE_MISTRAL_API_KEY")
messages=messages, messages=messages,
tools=tools, tools=tools,
tool_choice="auto", tool_choice="auto",
@ -206,10 +303,12 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Supported Models ## Supported Models
LiteLLM supports **ALL** Azure AI models. Here are a few examples:
| Model Name | Function Call | | Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` | | Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| Cohere ommand-r | `completion(model="azure/command-r", messages)` | | Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` | | mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |


@ -144,16 +144,135 @@ print(response)
</TabItem> </TabItem>
</Tabs> </Tabs>
## Set temperature, top p, etc.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=[{ "content": "Hello, how are you?","role": "user"}],
temperature=0.7,
top_p=1
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
**Set on yaml**
```yaml
model_list:
- model_name: bedrock-claude-v1
litellm_params:
model: bedrock/anthropic.claude-instant-v1
temperature: <your-temp>
top_p: <your-top-p>
```
**Set on request**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
temperature=0.7,
top_p=1
)
print(response)
```
</TabItem>
</Tabs>
## Pass provider-specific params
If you pass a non-openai param to litellm, we'll assume it's provider-specific and send it as a kwarg in the request body. [See more](../completion/input.md#provider-specific-params)
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=[{ "content": "Hello, how are you?","role": "user"}],
top_k=1 # 👈 PROVIDER-SPECIFIC PARAM
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
**Set on yaml**
```yaml
model_list:
- model_name: bedrock-claude-v1
litellm_params:
model: bedrock/anthropic.claude-instant-v1
top_k: 1 # 👈 PROVIDER-SPECIFIC PARAM
```
**Set on request**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="bedrock-claude-v1", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
temperature=0.7,
extra_body={
"top_k": 1 # 👈 PROVIDER-SPECIFIC PARAM
}
)
print(response)
```
</TabItem>
</Tabs>
## Usage - Function Calling ## Usage - Function Calling
:::info LiteLLM uses Bedrock's Converse API for making tool calls
Claude returns it's output as an XML Tree. [Here is how we translate it](https://github.com/BerriAI/litellm/blob/49642a5b00a53b1babc1a753426a8afcac85dbbe/litellm/llms/prompt_templates/factory.py#L734).
You can see the raw response via `response._hidden_params["original_response"]`.
Claude hallucinates, e.g. returning the list param `value` as `<value>\n<item>apple</item>\n<item>banana</item>\n</value>` or `<value>\n<list>\n<item>apple</item>\n<item>banana</item>\n</list>\n</value>`.
:::
```python ```python
from litellm import completion from litellm import completion
@ -361,47 +480,6 @@ response = completion(
) )
``` ```
### Passing an external BedrockRuntime.Client as a parameter - Completion()
Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
Create a client from session credentials:
```python
import boto3
from litellm import completion
bedrock = boto3.client(
service_name="bedrock-runtime",
region_name="us-east-1",
aws_access_key_id="",
aws_secret_access_key="",
aws_session_token="",
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
```
Create a client from AWS profile in `~/.aws/config`:
```python
import boto3
from litellm import completion
dev_session = boto3.Session(profile_name="dev-profile")
bedrock = dev_session.client(
service_name="bedrock-runtime",
region_name="us-east-1",
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
```
### SSO Login (AWS Profile) ### SSO Login (AWS Profile)
- Set `AWS_PROFILE` environment variable - Set `AWS_PROFILE` environment variable
- Make bedrock completion call - Make bedrock completion call
@ -464,6 +542,56 @@ response = completion(
) )
``` ```
### Passing an external BedrockRuntime.Client as a parameter - Completion()
:::warning
This is a deprecated flow. Boto3 is not async. And boto3.client does not let us make the http call through httpx. Pass in your aws params through the method above 👆. [See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues)
:::
Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
Create a client from session credentials:
```python
import boto3
from litellm import completion
bedrock = boto3.client(
service_name="bedrock-runtime",
region_name="us-east-1",
aws_access_key_id="",
aws_secret_access_key="",
aws_session_token="",
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
```
Create a client from AWS profile in `~/.aws/config`:
```python
import boto3
from litellm import completion
dev_session = boto3.Session(profile_name="dev-profile")
bedrock = dev_session.client(
service_name="bedrock-runtime",
region_name="us-east-1",
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
```
## Provisioned throughput models ## Provisioned throughput models
To use provisioned throughput Bedrock models pass To use provisioned throughput Bedrock models pass
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models) - `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)


@ -1,10 +1,13 @@
# 🆕 Clarifai # Clarifai
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai. Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
:::warning
Streaming is not yet supported when using Clarifai with LiteLLM. Tracking support here: https://github.com/BerriAI/litellm/issues/4162
:::
## Pre-Requisites ## Pre-Requisites
`pip install clarifai`
`pip install litellm` `pip install litellm`
## Required Environment Variables ## Required Environment Variables
@ -12,6 +15,7 @@ To obtain your Clarifai Personal access token follow this [link](https://docs.cl
```python ```python
os.environ["CLARIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT os.environ["CLARIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
``` ```
## Usage ## Usage
@ -68,7 +72,7 @@ Example Usage - Note: liteLLM supports all models deployed on Clarifai
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`| | clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` | | clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
## Mistal LLMs ## Mistral LLMs
| Model Name | Function Call | | Model Name | Function Call |
|---------------------------------------------|------------------------------------------------------------------------| |---------------------------------------------|------------------------------------------------------------------------|
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` | | clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |


@ -0,0 +1,255 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Codestral API [Mistral AI]
Codestral is available in select code-completion plugins but can also be queried directly. See the documentation for more details.
## API Key
```python
# env variable
os.environ['CODESTRAL_API_KEY']
```
## FIM / Completions
:::info
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createFIMCompletion
:::
<Tabs>
<TabItem value="no-streaming" label="No Streaming">
#### Sample Usage
```python
import os
import litellm
os.environ['CODESTRAL_API_KEY']
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True", # optional
temperature=0, # optional
top_p=1, # optional
max_tokens=10, # optional
min_tokens=10, # optional
seed=10, # optional
stop=["return"], # optional
)
```
#### Expected Response
```json
{
"id": "b41e0df599f94bc1a46ea9fcdbc2aabe",
"object": "text_completion",
"created": 1589478378,
"model": "codestral-latest",
"choices": [
{
"text": "\n assert is_odd(1)\n assert",
"index": 0,
"logprobs": null,
"finish_reason": "length"
}
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
}
```
</TabItem>
<TabItem value="stream" label="Streaming">
#### Sample Usage - Streaming
```python
import os
import litellm
os.environ['CODESTRAL_API_KEY']
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True", # optional
temperature=0, # optional
top_p=1, # optional
stream=True,
seed=10, # optional
stop=["return"], # optional
)
async for chunk in response:
print(chunk)
```
#### Expected Response
```json
{
"id": "726025d3e2d645d09d475bb0d29e3640",
"object": "text_completion",
"created": 1718659669,
"choices": [
{
"text": "This",
"index": 0,
"logprobs": null,
"finish_reason": null
}
],
"model": "codestral-2405",
}
```
</TabItem>
</Tabs>
### Supported Models
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).
| Model Name | Function Call |
|----------------|--------------------------------------------------------------|
| Codestral Latest | `completion(model="text-completion-codestral/codestral-latest", messages)` |
| Codestral 2405 | `completion(model="text-completion-codestral/codestral-2405", messages)`|
## Chat Completions
:::info
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createChatCompletion
:::
<Tabs>
<TabItem value="no-streaming" label="No Streaming">
#### Sample Usage
```python
import os
import litellm
os.environ['CODESTRAL_API_KEY']
response = await litellm.acompletion(
model="codestral/codestral-latest",
messages=[
{
"role": "user",
"content": "Hey, how's it going?",
}
],
temperature=0.0, # optional
top_p=1, # optional
max_tokens=10, # optional
safe_prompt=False, # optional
seed=12, # optional
)
```
#### Expected Response
```json
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "codestral/codestral-latest",
"system_fingerprint": None,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?",
},
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
```
</TabItem>
<TabItem value="stream" label="Streaming">
#### Sample Usage - Streaming
```python
import os
import litellm
os.environ['CODESTRAL_API_KEY']
response = await litellm.acompletion(
model="codestral/codestral-latest",
messages=[
{
"role": "user",
"content": "Hey, how's it going?",
}
],
stream=True, # optional
temperature=0.0, # optional
top_p=1, # optional
max_tokens=10, # optional
safe_prompt=False, # optional
seed=12, # optional
)
async for chunk in response:
print(chunk)
```
#### Expected Response
```json
{
"id":"chatcmpl-123",
"object":"chat.completion.chunk",
"created":1694268190,
"model": "codestral/codestral-latest",
"system_fingerprint": None,
"choices":[
{
"index":0,
"delta":{"role":"assistant","content":"gm"},
"logprobs":null,
" finish_reason":null
}
]
}
```
</TabItem>
</Tabs>
### Supported Models
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).
| Model Name | Function Call |
|----------------|--------------------------------------------------------------|
| Codestral Latest | `completion(model="codestral/codestral-latest", messages)` |
| Codestral 2405 | `completion(model="codestral/codestral-2405", messages)`|


@ -125,11 +125,12 @@ See all litellm.completion supported params [here](../completion/input.md#transl
from litellm import completion from litellm import completion
import os import os
## set ENV variables ## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key" os.environ["DATABRICKS_API_KEY"] = "databricks key"
os.environ["DATABRICKS_API_BASE"] = "databricks api base"
# predibae llama-3 call # databricks dbrx call
response = completion( response = completion(
model="predibase/llama3-8b-instruct", model="databricks/databricks-dbrx-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}], messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20, max_tokens=20,
temperature=0.5 temperature=0.5


@ -1,6 +1,13 @@
# DeepInfra # DeepInfra
https://deepinfra.com/ https://deepinfra.com/
:::tip
**We support ALL DeepInfra models, just set `model=deepinfra/<any-model-on-deepinfra>` as a prefix when sending litellm requests**
:::
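For example, a minimal sketch of the prefix pattern (the model name is one of the chat models listed further down; `DEEPINFRA_API_KEY` matches the API Key section below):

```python
import os
from litellm import completion

os.environ["DEEPINFRA_API_KEY"] = "your-deepinfra-api-key"

response = completion(
    model="deepinfra/meta-llama/Llama-2-70b-chat-hf",  # deepinfra/<any-model-on-deepinfra>
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```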
## API Key ## API Key
```python ```python
# env variable # env variable
@ -38,13 +45,11 @@ for chunk in response:
## Chat Models ## Chat Models
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
| meta-llama/Meta-Llama-3-8B-Instruct | `completion(model="deepinfra/meta-llama/Meta-Llama-3-8B-Instruct", messages)` |
| meta-llama/Meta-Llama-3-70B-Instruct | `completion(model="deepinfra/meta-llama/Meta-Llama-3-70B-Instruct", messages)` |
| meta-llama/Llama-2-70b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-70b-chat-hf", messages)` | | meta-llama/Llama-2-70b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-70b-chat-hf", messages)` |
| meta-llama/Llama-2-7b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-7b-chat-hf", messages)` | | meta-llama/Llama-2-7b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-7b-chat-hf", messages)` |
| meta-llama/Llama-2-13b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-13b-chat-hf", messages)` | | meta-llama/Llama-2-13b-chat-hf | `completion(model="deepinfra/meta-llama/Llama-2-13b-chat-hf", messages)` |
| codellama/CodeLlama-34b-Instruct-hf | `completion(model="deepinfra/codellama/CodeLlama-34b-Instruct-hf", messages)` | | codellama/CodeLlama-34b-Instruct-hf | `completion(model="deepinfra/codellama/CodeLlama-34b-Instruct-hf", messages)` |
| mistralai/Mistral-7B-Instruct-v0.1 | `completion(model="deepinfra/mistralai/Mistral-7B-Instruct-v0.1", messages)` | | mistralai/Mistral-7B-Instruct-v0.1 | `completion(model="deepinfra/mistralai/Mistral-7B-Instruct-v0.1", messages)` |
| jondurbin/airoboros-l2-70b-gpt4-1.4.1 | `completion(model="deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", messages)` | | jondurbin/airoboros-l2-70b-gpt4-1.4.1 | `completion(model="deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", messages)` |


@ -45,6 +45,52 @@ response = completion(
) )
``` ```
## Tool Calling
```python
from litellm import completion
import os
# set env
os.environ["GEMINI_API_KEY"] = ".."
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="gemini/gemini-1.5-flash",
messages=messages,
tools=tools,
)
# Add any assertions here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
# Gemini-Pro-Vision # Gemini-Pro-Vision
LiteLLM Supports the following image types passed in `url` LiteLLM Supports the following image types passed in `url`
- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg - Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg


@ -1,7 +1,11 @@
# Groq # Groq
https://groq.com/ https://groq.com/
**We support ALL Groq models, just set `groq/` as a prefix when sending completion requests** :::tip
**We support ALL Groq models, just set `model=groq/<any-model-on-groq>` as a prefix when sending litellm requests**
:::
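For example, a minimal sketch of the prefix pattern (the model name is illustrative; `GROQ_API_KEY` matches the API Key section below):

```python
import os
from litellm import completion

os.environ["GROQ_API_KEY"] = "your-groq-api-key"

response = completion(
    model="groq/llama3-8b-8192",  # groq/<any-model-on-groq>
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```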
## API Key ## API Key
```python ```python


@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenAI (Text Completion) # OpenAI (Text Completion)
LiteLLM supports OpenAI text completion models LiteLLM supports OpenAI text completion models
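For context, a minimal sketch using litellm's `text_completion` interface (model name and parameters here are illustrative):

```python
import os
from litellm import text_completion

os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

# gpt-3.5-turbo-instruct is used as an illustrative OpenAI text completion model
response = text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Say this is a test",
    max_tokens=10,
)
```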


@ -208,7 +208,7 @@ print(response)
Instead of using the `custom_llm_provider` arg to specify which provider you're using (e.g. together ai), you can just pass the provider name as part of the model name, and LiteLLM will parse it out. Instead of using the `custom_llm_provider` arg to specify which provider you're using (e.g. together ai), you can just pass the provider name as part of the model name, and LiteLLM will parse it out.
Expected format: <custom_llm_provider>/<model_name> Expected format: `<custom_llm_provider>/<model_name>`
e.g. completion(model="together_ai/togethercomputer/Llama-2-7B-32K-Instruct", ...) e.g. completion(model="together_ai/togethercomputer/Llama-2-7B-32K-Instruct", ...)


@ -8,6 +8,152 @@ import TabItem from '@theme/TabItem';
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a> </a>
## 🆕 `vertex_ai_beta/` route
New `vertex_ai_beta/` route. Adds support for system messages, tool_choice params, etc. by moving to httpx client (instead of vertex sdk).
```python
from litellm import completion
import json
## GET CREDENTIALS
file_path = 'path/to/vertex_ai_service_account.json'
# Load the JSON file
with open(file_path, 'r') as file:
vertex_credentials = json.load(file)
# Convert to JSON string
vertex_credentials_json = json.dumps(vertex_credentials)
## COMPLETION CALL
response = completion(
model="vertex_ai_beta/gemini-pro",
messages=[{ "content": "Hello, how are you?","role": "user"}],
vertex_credentials=vertex_credentials_json
)
```
### **System Message**
```python
from litellm import completion
import json
## GET CREDENTIALS
file_path = 'path/to/vertex_ai_service_account.json'
# Load the JSON file
with open(file_path, 'r') as file:
vertex_credentials = json.load(file)
# Convert to JSON string
vertex_credentials_json = json.dumps(vertex_credentials)
response = completion(
model="vertex_ai_beta/gemini-pro",
messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}],
vertex_credentials=vertex_credentials_json
)
```
### **Function Calling**
Force Gemini to make tool calls with `tool_choice="required"`.
```python
from litellm import completion
import json
## GET CREDENTIALS
file_path = 'path/to/vertex_ai_service_account.json'
# Load the JSON file
with open(file_path, 'r') as file:
vertex_credentials = json.load(file)
# Convert to JSON string
vertex_credentials_json = json.dumps(vertex_credentials)
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location"],
},
},
}
]
data = {
"model": "vertex_ai_beta/gemini-1.5-pro-preview-0514"),
"messages": messages,
"tools": tools,
"tool_choice": "required",
"vertex_credentials": vertex_credentials_json
}
## COMPLETION CALL
print(completion(**data))
```
### **JSON Schema**
```python
from litellm import completion
import json
## GET CREDENTIALS
file_path = 'path/to/vertex_ai_service_account.json'
# Load the JSON file
with open(file_path, 'r') as file:
vertex_credentials = json.load(file)
# Convert to JSON string
vertex_credentials_json = json.dumps(vertex_credentials)
messages = [
{
"role": "user",
"content": """
List 5 popular cookie recipes.
Using this JSON schema:
Recipe = {"recipe_name": str}
Return a `list[Recipe]`
"""
}
]
completion(model="vertex_ai_beta/gemini-1.5-flash-preview-0514", messages=messages, response_format={ "type": "json_object" })
```
## Pre-requisites ## Pre-requisites
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image) * `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
* Authentication: * Authentication:
@ -140,7 +286,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
```python ```python
response = completion( response = completion(
model="gemini/gemini-pro", model="vertex_ai/gemini-pro",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}] messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
safety_settings=[ safety_settings=[
{ {
@ -363,8 +509,8 @@ response = completion(
## Gemini 1.5 Pro (and Vision) ## Gemini 1.5 Pro (and Vision)
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` | | gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-1.5-pro', messages)` |
| gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-pro', messages)` | | gemini-1.5-flash-preview-0514 | `completion('gemini-1.5-flash-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-flash-preview-0514', messages)` |
| gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` | | gemini-1.5-pro-preview-0514 | `completion('gemini-1.5-pro-preview-0514', messages)`, `completion('vertex_ai/gemini-1.5-pro-preview-0514', messages)` |
@ -449,6 +595,54 @@ print(response)
</TabItem> </TabItem>
</Tabs> </Tabs>
## Usage - Function Calling
LiteLLM supports Function Calling for Vertex AI gemini models.
```python
from litellm import completion
import os
# set env
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ".."
os.environ["VERTEX_AI_PROJECT"] = ".."
os.environ["VERTEX_AI_LOCATION"] = ".."
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="vertex_ai/gemini-pro-vision",
messages=messages,
tools=tools,
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
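To complete the tool-call loop, you can execute the requested function locally and send its result back as a `"role": "tool"` message. This is a minimal sketch, not a LiteLLM-specific API: `get_current_weather` below is a hypothetical stand-in for your own implementation, and the model name is simply kept from the example above.
```python
import json

# Hypothetical local implementation of the tool declared above
def get_current_weather(location, unit="fahrenheit"):
    return json.dumps({"location": location, "temperature": "72", "unit": unit})

tool_call = response.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)

# Echo the assistant's tool call, then append the tool result
messages.append(
    {
        "role": "assistant",
        "content": response.choices[0].message.content,
        "tool_calls": [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
        ],
    }
)
messages.append(
    {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "content": get_current_weather(**args),
    }
)

# Second call lets the model answer using the tool output
second_response = completion(
    model="vertex_ai/gemini-pro-vision",
    messages=messages,
    tools=tools,
)
print(second_response)
```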
## Chat Models ## Chat Models
| Model Name | Function Call | | Model Name | Function Call |
@ -500,6 +694,8 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| Model Name | Function Call | | Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| text-embedding-004 | `embedding(model="vertex_ai/text-embedding-004", input)` |
| text-multilingual-embedding-002 | `embedding(model="vertex_ai/text-multilingual-embedding-002", input)` |
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` | | textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` | | textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` | | textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
@ -508,6 +704,29 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` | | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` | | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
### Advanced Use `task_type` and `title` (Vertex Specific Params)
👉 `task_type` and `title` are vertex specific params
LiteLLM Supported Vertex Specific Params
```python
auto_truncate: Optional[bool] = None
task_type: Optional[Literal["RETRIEVAL_QUERY","RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]] = None
title: Optional[str] = None # The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
```
**Example Usage with LiteLLM**
```python
response = litellm.embedding(
model="vertex_ai/text-embedding-004",
input=["good morning from litellm", "gm"],
task_type="RETRIEVAL_DOCUMENT",
dimensions=1,
auto_truncate=True,
)
```
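The embedding response follows the OpenAI embedding format, so the vector and usage can be read back directly. A minimal sketch (assuming the call above succeeds and that, as in current LiteLLM versions, the entries in `response.data` behave like dicts):
```python
# First embedding vector; its length should match the requested `dimensions`
vector = response.data[0]["embedding"]
print(len(vector))

# Token usage for the embedding call
print(response.usage)
```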
## Image Generation Models ## Image Generation Models
Usage Usage
@ -607,6 +826,3 @@ s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial

View file

@ -1,3 +1,5 @@
import Image from '@theme/IdealImage';
# 🚨 Alerting / Webhooks # 🚨 Alerting / Webhooks
Get alerts for: Get alerts for:
@ -15,6 +17,11 @@ Get alerts for:
- **Spend** Weekly & Monthly spend per Team, Tag - **Spend** Weekly & Monthly spend per Team, Tag
Works across:
- [Slack](#quick-start)
- [Discord](#advanced---using-discord-webhooks)
- [Microsoft Teams](#advanced---using-ms-teams-webhooks)
## Quick Start ## Quick Start
Set up a slack alert channel to receive alerts from proxy. Set up a slack alert channel to receive alerts from proxy.
@ -25,40 +32,32 @@ Get a slack webhook url from https://api.slack.com/messaging/webhooks
You can also use Discord Webhooks, see [here](#using-discord-webhooks) You can also use Discord Webhooks, see [here](#using-discord-webhooks)
### Step 2: Update config.yaml
- Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts. Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
- Just for testing purposes, let's save a bad key to our proxy.
```bash
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/<>/<>/<>"
```
### Step 2: Setup Proxy
```yaml ```yaml
model_list:
model_name: "azure-model"
litellm_params:
model: "azure/gpt-35-turbo"
api_key: "my-bad-key" # 👈 bad key
general_settings: general_settings:
alerting: ["slack"] alerting: ["slack"]
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+ alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
environment_variables:
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
``` ```
Start proxy
### Step 3: Start proxy
```bash ```bash
$ litellm --config /path/to/config.yaml $ litellm --config /path/to/config.yaml
``` ```
## Testing Alerting is Setup Correctly
Make a GET request to `/health/services`, expect to see a test slack alert in your provided webhook slack channel ### Step 3: Test it!
```shell
curl -X GET 'http://localhost:4000/health/services?service=slack' \ ```bash
curl -X GET 'http://0.0.0.0:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234' -H 'Authorization: Bearer sk-1234'
``` ```
@ -77,7 +76,34 @@ litellm_settings:
``` ```
## Advanced - Add Metadata to alerts
Add alerting metadata to proxy calls for debugging.
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [],
extra_body={
"metadata": {
"alerting_metadata": {
"hello": "world"
}
}
}
)
```
**Expected Response**
<Image img={require('../../img/alerting_metadata.png')}/>
## Advanced - Opting into specific alert types ## Advanced - Opting into specific alert types
@ -108,6 +134,48 @@ AlertType = Literal[
``` ```
## Advanced - Using MS Teams Webhooks
MS Teams provides a slack compatible webhook url that you can use for alerting
##### Quick Start
1. [Get a webhook url](https://learn.microsoft.com/en-us/microsoftteams/platform/webhooks-and-connectors/how-to/add-incoming-webhook?tabs=newteams%2Cdotnet#create-an-incoming-webhook) for your Microsoft Teams channel
2. Add it to your .env
```bash
SLACK_WEBHOOK_URL="https://berriai.webhook.office.com/webhookb2/...6901/IncomingWebhook/b55fa0c2a48647be8e6effedcd540266/e04b1092-4a3e-44a2-ab6b-29a0a4854d1d"
```
3. Add it to your litellm config
```yaml
model_list:
model_name: "azure-model"
litellm_params:
model: "azure/gpt-35-turbo"
api_key: "my-bad-key" # 👈 bad key
general_settings:
alerting: ["slack"]
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
```
4. Run health check!
Call the proxy `/health/services` endpoint to test if your alerting connection is correctly setup.
```bash
curl --location 'http://0.0.0.0:4000/health/services?service=slack' \
--header 'Authorization: Bearer sk-1234'
```
**Expected Response**
<Image img={require('../../img/ms_teams_alerting.png')}/>
## Advanced - Using Discord Webhooks ## Advanced - Using Discord Webhooks
Discord provides a slack compatible webhook url that you can use for alerting Discord provides a slack compatible webhook url that you can use for alerting
@ -139,7 +207,6 @@ environment_variables:
SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack" SLACK_WEBHOOK_URL: "https://discord.com/api/webhooks/1240030362193760286/cTLWt5ATn1gKmcy_982rl5xmYHsrM1IWJdmCL1AyOmU9JdQXazrp8L1_PYgUtgxj8x4f/slack"
``` ```
That's it! You're ready to go!
## Advanced - [BETA] Webhooks for Budget Alerts ## Advanced - [BETA] Webhooks for Budget Alerts

View file

@ -252,6 +252,31 @@ $ litellm --config /path/to/config.yaml
``` ```
## Multiple OpenAI Organizations
Add all openai models across all OpenAI organizations with just 1 model definition
```yaml
- model_name: "*"
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
organization:
- org-1
- org-2
- org-3
```
LiteLLM will automatically create separate deployments for each org.
Confirm this via
```bash
curl --location 'http://0.0.0.0:4000/v1/model/info' \
--header 'Authorization: Bearer ${LITELLM_KEY}' \
--data ''
```
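You can also sanity-check routing by sending a request through the proxy. A minimal sketch with the OpenAI SDK, assuming the proxy runs on `http://0.0.0.0:4000` with the `sk-1234` key used elsewhere in these docs:
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# Any OpenAI model name matches the wildcard definition;
# the proxy picks one of the per-org deployments for the request
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello from the wildcard config"}],
)
print(response)
```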
## Load Balancing ## Load Balancing
:::info :::info

View file

@ -428,3 +428,22 @@ model_list:
## Custom Input/Output Pricing ## Custom Input/Output Pricing
👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to setup custom pricing or your models 👉 Head to [Custom Input/Output Pricing](https://docs.litellm.ai/docs/proxy/custom_pricing) to setup custom pricing or your models
## ✨ Custom k,v pairs
Log specific key,value pairs as part of the metadata for a spend log
:::info
Logging specific key,value pairs in spend logs metadata is an enterprise feature. [See here](./enterprise.md#tracking-spend-with-custom-metadata)
:::
## ✨ Custom Tags
:::info
Tracking spend with Custom tags is an enterprise feature. [See here](./enterprise.md#tracking-spend-for-custom-tags)
:::

View file

@ -1,5 +1,6 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
import Image from '@theme/IdealImage';
# 🐳 Docker, Deploying LiteLLM Proxy # 🐳 Docker, Deploying LiteLLM Proxy
@ -26,7 +27,7 @@ docker-compose up
<Tabs> <Tabs>
<TabItem value="basic" label="Basic"> <TabItem value="basic" label="Basic (No DB)">
### Step 1. CREATE config.yaml ### Step 1. CREATE config.yaml
@ -97,7 +98,13 @@ docker run ghcr.io/berriai/litellm:main-latest --port 8002 --num_workers 8
``` ```
</TabItem> </TabItem>
<TabItem value="terraform" label="Terraform">
s/o [Nicholas Cecere](https://www.linkedin.com/in/nicholas-cecere-24243549/) for his LiteLLM User Management Terraform
👉 [Go here for Terraform](https://github.com/ncecere/terraform-litellm-user-mgmt)
</TabItem>
<TabItem value="base-image" label="use litellm as a base image"> <TabItem value="base-image" label="use litellm as a base image">
```shell ```shell
@ -379,6 +386,7 @@ kubectl port-forward service/litellm-service 4000:4000
Your OpenAI proxy server is now running on `http://0.0.0.0:4000`. Your OpenAI proxy server is now running on `http://0.0.0.0:4000`.
</TabItem> </TabItem>
<TabItem value="helm-deploy" label="Helm"> <TabItem value="helm-deploy" label="Helm">
@ -424,7 +432,6 @@ If you need to set your litellm proxy config.yaml, you can find this in [values.
</TabItem> </TabItem>
<TabItem value="helm-oci" label="Helm OCI Registry (GHCR)"> <TabItem value="helm-oci" label="Helm OCI Registry (GHCR)">
:::info :::info
@ -537,7 +544,9 @@ ghcr.io/berriai/litellm-database:main-latest --config your_config.yaml
## Advanced Deployment Settings ## Advanced Deployment Settings
### Customization of the server root path ### 1. Customization of the server root path (custom Proxy base url)
💥 Use this when you want to serve LiteLLM on a custom base url path like `https://localhost:4000/api/v1`
:::info :::info
@ -548,9 +557,29 @@ In a Kubernetes deployment, it's possible to utilize a shared DNS to host multip
Customize the root path to eliminate the need for employing multiple DNS configurations during deployment. Customize the root path to eliminate the need for employing multiple DNS configurations during deployment.
👉 Set `SERVER_ROOT_PATH` in your .env and this will be set as your server root path 👉 Set `SERVER_ROOT_PATH` in your .env and this will be set as your server root path
```
export SERVER_ROOT_PATH="/api/v1"
```
**Step 1. Run Proxy with `SERVER_ROOT_PATH` set in your env**
### Setting SSL Certification ```shell
docker run --name litellm-proxy \
-e DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname> \
-e SERVER_ROOT_PATH="/api/v1" \
-p 4000:4000 \
ghcr.io/berriai/litellm-database:main-latest --config your_config.yaml
```
After running the proxy you can access it on `http://0.0.0.0:4000/api/v1/` (since we set `SERVER_ROOT_PATH="/api/v1"`)
**Step 2. Verify Running on correct path**
<Image img={require('../../img/custom_root_path.png')} />
**That's it**, that's all you need to run the proxy on a custom root path
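Clients then include the root path in their base URL. A minimal sketch with the OpenAI SDK, assuming the proxy from Step 1 and the `sk-1234` key used elsewhere in these docs:
```python
import openai

# Note the `/api/v1` suffix - it must match SERVER_ROOT_PATH
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000/api/v1")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
print(response)
```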
### 2. Setting SSL Certification
Use this, If you need to set ssl certificates for your on prem litellm proxy Use this, If you need to set ssl certificates for your on prem litellm proxy
@ -646,7 +675,7 @@ Once the stack is created, get the DatabaseURL of the Database resource, copy th
#### 3. Connect to the EC2 Instance and deploy litellm on the EC2 container #### 3. Connect to the EC2 Instance and deploy litellm on the EC2 container
From the EC2 console, connect to the instance created by the stack (e.g., using SSH). From the EC2 console, connect to the instance created by the stack (e.g., using SSH).
Run the following command, replacing <database_url> with the value you copied in step 2 Run the following command, replacing `<database_url>` with the value you copied in step 2
```shell ```shell
docker run --name litellm-proxy \ docker run --name litellm-proxy \

View file

@ -5,6 +5,7 @@ import Image from '@theme/IdealImage';
Send an Email to your users when: Send an Email to your users when:
- A Proxy API Key is created for them - A Proxy API Key is created for them
- Their API Key crosses its Budget - Their API Key crosses its Budget
- All Team members of a LiteLLM Team -> when the team crosses its budget
<Image img={require('../../img/email_notifs.png')} style={{ width: '500px' }}/> <Image img={require('../../img/email_notifs.png')} style={{ width: '500px' }}/>

View file

@ -205,6 +205,146 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
``` ```
## Tracking Spend with custom metadata
Requirements:
- Virtual Keys & a database should be set up, see [virtual keys](https://docs.litellm.ai/docs/proxy/virtual_keys)
#### Usage - /chat/completions requests with special spend logs metadata
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
Set `extra_body={"metadata": { }}` to the `metadata` you want to pass
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
extra_body={
"metadata": {
"spend_logs_metadata": {
"hello": "world"
}
}
}
)
print(response)
```
</TabItem>
<TabItem value="Curl" label="Curl Request">
Pass `metadata` as part of the request body
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"metadata": {
"spend_logs_metadata": {
"hello": "world"
}
}
}'
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",
model = "gpt-3.5-turbo",
temperature=0.1,
extra_body={
"metadata": {
"spend_logs_metadata": {
"hello": "world"
}
}
}
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
#### Viewing Spend w/ custom metadata
#### `/spend/logs` Request Format
```bash
curl -X GET "http://0.0.0.0:4000/spend/logs?request_id=<your-call-id>" \ # e.g.: chatcmpl-9ZKMURhVYSi9D6r6PJ9vLcayIK0Vm
-H "Authorization: Bearer sk-1234"
```
#### `/spend/logs` Response Format
```bash
[
{
"request_id": "chatcmpl-9ZKMURhVYSi9D6r6PJ9vLcayIK0Vm",
"call_type": "acompletion",
"metadata": {
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"user_api_key_alias": null,
"spend_logs_metadata": { # 👈 LOGGED CUSTOM METADATA
"hello": "world"
},
"user_api_key_team_id": null,
"user_api_key_user_id": "116544810872468347480",
"user_api_key_team_alias": null
},
}
]
```
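The same lookup can be scripted. A minimal sketch with `requests`, assuming the proxy and key from the examples above and a known request id (the request id is the `id` field of the chat completion response):
```python
import requests

resp = requests.get(
    "http://0.0.0.0:4000/spend/logs",
    params={"request_id": "chatcmpl-9ZKMURhVYSi9D6r6PJ9vLcayIK0Vm"},
    headers={"Authorization": "Bearer sk-1234"},
)
for log in resp.json():
    # 👈 the custom metadata logged with the request
    print(log["metadata"].get("spend_logs_metadata"))
```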
## Enforce Required Params for LLM Requests ## Enforce Required Params for LLM Requests
Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params. Use this when you want to enforce all requests to include certain params. Example you need all requests to include the `user` and `["metadata]["generation_name"]` params.

View file

@ -606,6 +606,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
** 🎉 Expect to see this trace logged in your OTEL collector** ** 🎉 Expect to see this trace logged in your OTEL collector**
### Context propagation across Services `Traceparent HTTP Header`
❓ Use this when you want to **pass information about the incoming request in a distributed tracing system**
✅ Key change: Pass the **`traceparent` header** in your requests. [Read more about traceparent headers here](https://uptrace.dev/opentelemetry/opentelemetry-traceparent.html#what-is-traceparent-header)
```curl
traceparent: 00-80e1afed08e019fc1110464cfa66635c-7a085853722dc6d2-01
```
Example Usage
1. Make Request to LiteLLM Proxy with `traceparent` header
```python
import openai
import uuid
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
example_traceparent = f"00-80e1afed08e019fc1110464cfa66635c-02e80198930058d4-01"
extra_headers = {
"traceparent": example_traceparent
}
_trace_id = example_traceparent.split("-")[1]
print("EXTRA HEADERS: ", extra_headers)
print("Trace ID: ", _trace_id)
response = client.chat.completions.create(
model="llama3",
messages=[
{"role": "user", "content": "this is a test request, write a short poem"}
],
extra_headers=extra_headers,
)
print(response)
```
```shell
# EXTRA HEADERS: {'traceparent': '00-80e1afed08e019fc1110464cfa66635c-02e80198930058d4-01'}
# Trace ID: 80e1afed08e019fc1110464cfa66635c
```
2. Lookup Trace ID on OTEL Logger
Search for Trace=`80e1afed08e019fc1110464cfa66635c` on your OTEL Collector
<Image img={require('../../img/otel_parent.png')} />

View file

@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Model Management # Model Management
Add new models + Get model info without restarting proxy. Add new models + Get model info without restarting proxy.

View file

@ -1,3 +1,5 @@
import Image from '@theme/IdealImage';
# LiteLLM Proxy Performance # LiteLLM Proxy Performance
### Throughput - 30% Increase ### Throughput - 30% Increase

View file

@ -1,4 +1,4 @@
# Grafana, Prometheus metrics [BETA] # 📈 Prometheus metrics [BETA]
LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
@ -54,6 +54,13 @@ http://localhost:4000/metrics
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` | | `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` | | `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
### Budget Metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_remaining_team_budget_metric` | Remaining Budget for Team (A team created on LiteLLM) |
| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
## Monitor System Health ## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do: To monitor the health of litellm adjacent services (redis / postgres), do:

View file

@ -155,10 +155,8 @@ response = client.chat.completions.create(
} }
], ],
extra_body={ extra_body={
"metadata": {
"fallbacks": ["gpt-3.5-turbo"] "fallbacks": ["gpt-3.5-turbo"]
} }
}
) )
print(response) print(response)
@ -180,9 +178,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
"content": "what llm are you" "content": "what llm are you"
} }
], ],
"metadata": {
"fallbacks": ["gpt-3.5-turbo"] "fallbacks": ["gpt-3.5-turbo"]
}
}' }'
``` ```
</TabItem> </TabItem>
@ -204,10 +200,8 @@ chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", openai_api_base="http://0.0.0.0:4000",
model="zephyr-beta", model="zephyr-beta",
extra_body={ extra_body={
"metadata": {
"fallbacks": ["gpt-3.5-turbo"] "fallbacks": ["gpt-3.5-turbo"]
} }
}
) )
messages = [ messages = [
@ -415,6 +409,28 @@ print(response)
</Tabs> </Tabs>
### Content Policy Fallbacks
Fallback across providers (e.g. from Azure OpenAI to Anthropic) if you hit content policy violation errors.
```yaml
model_list:
- model_name: gpt-3.5-turbo-small
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
- model_name: claude-opus
litellm_params:
model: claude-3-opus-20240229
api_key: os.environ/ANTHROPIC_API_KEY
litellm_settings:
content_policy_fallbacks: [{"gpt-3.5-turbo-small": ["claude-opus"]}]
```
### EU-Region Filtering (Pre-Call Checks) ### EU-Region Filtering (Pre-Call Checks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**. **Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.

View file

@ -124,3 +124,17 @@ LiteLLM Enterprise: Enable [SSO login](./ui.md#setup-ssoauth-for-ui)
<Image img={require('../../img/ui_self_serve_create_key.png')} style={{ width: '800px', height: 'auto' }} /> <Image img={require('../../img/ui_self_serve_create_key.png')} style={{ width: '800px', height: 'auto' }} />
## Advanced
### Setting custom logout URLs
Set `PROXY_LOGOUT_URL` in your .env if you want users to get redirected to a specific URL when they click logout
```
export PROXY_LOGOUT_URL="https://www.google.com"
```
<Image img={require('../../img/ui_logout.png')} style={{ width: '400px', height: 'auto' }} />

View file

@ -0,0 +1,123 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💰 Setting Team Budgets
Track spend, set budgets for your Internal Team
## Setting Monthly Team Budgets
### 1. Create a team
- Set `max_budget=0.000000001` ($ value the team is allowed to spend)
- Set `budget_duration="1d"` (How frequently the budget should update)
Create a new team and set `max_budget` and `budget_duration`
```shell
curl -X POST 'http://0.0.0.0:4000/team/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"team_alias": "QA Prod Bot",
"max_budget": 0.000000001,
"budget_duration": "1d"
}'
```
Response
```shell
{
"team_alias": "QA Prod Bot",
"team_id": "de35b29e-6ca8-4f47-b804-2b79d07aa99a",
"max_budget": 1e-09,
"budget_duration": "1d",
"budget_reset_at": "2024-06-14T22:48:36.594000Z"
}
```
Possible values for `budget_duration`
| `budget_duration` | When Budget will reset |
| --- | --- |
| `budget_duration="1s"` | every 1 second |
| `budget_duration="1m"` | every 1 min |
| `budget_duration="1h"` | every 1 hour |
| `budget_duration="1d"` | every 1 day |
| `budget_duration="1mo"` | every 1 month |
### 2. Create a key for the `team`
Create a key for `team_id="de35b29e-6ca8-4f47-b804-2b79d07aa99a"` from Step 1
💡 **The budget for Team="QA Prod Bot" will apply to this key**
```shell
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"team_id": "de35b29e-6ca8-4f47-b804-2b79d07aa99a"}'
```
Response
```shell
{"team_id":"de35b29e-6ca8-4f47-b804-2b79d07aa99a", "key":"sk-5qtncoYjzRcxMM4bDRktNQ"}
```
### 3. Test It
Use the key from step 2 and run this Request twice
```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Authorization: Bearer sk-5qtncoYjzRcxMM4bDRktNQ' \
-H 'Content-Type: application/json' \
-d ' {
"model": "llama3",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}'
```
On the 2nd response - expect to see the following exception
```shell
{
"error": {
"message": "Budget has been exceeded! Current cost: 3.5e-06, Max budget: 1e-09",
"type": "auth_error",
"param": null,
"code": 400
}
}
```
## Advanced
### Prometheus metrics for `remaining_budget`
[More info about Prometheus metrics here](https://docs.litellm.ai/docs/proxy/prometheus)
You'll need the following in your proxy config.yaml
```yaml
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
```
Expect to see this metric on prometheus to track the Remaining Budget for the team
```shell
litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-6ca8-4f47-b804-2b79d07aa99a"} 9.699999999999992e-06
```
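If you want a quick check without a dashboard, here is a minimal sketch that scrapes the `/metrics` endpoint with `requests` (this assumes the endpoint is reachable from where you run the script and does not require additional auth on your deployment):
```python
import requests

# Print the remaining-budget lines exposed by the proxy
metrics = requests.get("http://0.0.0.0:4000/metrics").text
for line in metrics.splitlines():
    if line.startswith("litellm_remaining_team_budget_metric"):
        print(line)
```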

View file

@ -77,6 +77,28 @@ litellm_settings:
#### Step 2: Setup Oauth Client #### Step 2: Setup Oauth Client
<Tabs> <Tabs>
<TabItem value="okta" label="Okta SSO">
1. Add Okta credentials to your .env
```bash
GENERIC_CLIENT_ID = "<your-okta-client-id>"
GENERIC_CLIENT_SECRET = "<your-okta-client-secret>"
GENERIC_AUTHORIZATION_ENDPOINT = "<your-okta-domain>/authorize" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/authorize
GENERIC_TOKEN_ENDPOINT = "<your-okta-domain>/token" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/oauth/token
GENERIC_USERINFO_ENDPOINT = "<your-okta-domain>/userinfo" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/userinfo
```
You can get your domain specific auth/token/userinfo endpoints at `<YOUR-OKTA-DOMAIN>/.well-known/openid-configuration` (a sketch for fetching these programmatically follows these steps).
2. Add proxy url as callback_url on Okta
On Okta, add the 'callback_url' as `<proxy_base_url>/sso/callback`
<Image img={require('../../img/okta_callback_url.png')} />
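A minimal sketch for pulling those endpoints out of the discovery document from step 1 (the domain below is the example Auth0-style domain from the comments above; replace it with your own Okta domain):
```python
import requests

okta_domain = "https://dev-2kqkcd6lx6kdkuzt.us.auth0.com"  # example domain, replace with yours
config = requests.get(f"{okta_domain}/.well-known/openid-configuration").json()

print(config["authorization_endpoint"])  # -> GENERIC_AUTHORIZATION_ENDPOINT
print(config["token_endpoint"])          # -> GENERIC_TOKEN_ENDPOINT
print(config["userinfo_endpoint"])       # -> GENERIC_USERINFO_ENDPOINT
```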
</TabItem>
<TabItem value="google" label="Google SSO"> <TabItem value="google" label="Google SSO">
- Create a new Oauth 2.0 Client on https://console.cloud.google.com/ - Create a new Oauth 2.0 Client on https://console.cloud.google.com/
@ -115,7 +137,6 @@ MICROSOFT_TENANT="5a39737
</TabItem> </TabItem>
<TabItem value="Generic" label="Generic SSO Provider"> <TabItem value="Generic" label="Generic SSO Provider">
A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code A generic OAuth client that can be used to quickly create support for any OAuth provider with close to no code

View file

@ -63,7 +63,7 @@ You can:
- Add budgets to Teams - Add budgets to Teams
#### **Add budgets to users** #### **Add budgets to teams**
```shell ```shell
curl --location 'http://localhost:4000/team/new' \ curl --location 'http://localhost:4000/team/new' \
--header 'Authorization: Bearer <your-master-key>' \ --header 'Authorization: Bearer <your-master-key>' \
@ -102,6 +102,22 @@ curl --location 'http://localhost:4000/team/new' \
"budget_reset_at": null "budget_reset_at": null
} }
``` ```
#### **Add budget duration to teams**
`budget_duration`: Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
```
curl 'http://0.0.0.0:4000/team/new' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"team_alias": "my-new-team_4",
"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}],
"budget_duration": "10s"
}'
```
</TabItem> </TabItem>
<TabItem value="per-team-member" label="For Team Members"> <TabItem value="per-team-member" label="For Team Members">
@ -397,6 +413,52 @@ curl 'http://0.0.0.0:4000/key/generate' \
</TabItem> </TabItem>
</Tabs> </Tabs>
### Reset Budgets
Reset budgets across keys/internal users/teams/customers
`budget_duration`: Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
<Tabs>
<TabItem value="users" label="Internal Users">
```bash
curl 'http://0.0.0.0:4000/user/new' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"max_budget": 10,
"budget_duration": "10s" # 👈 KEY CHANGE
}'
```
</TabItem>
<TabItem value="keys" label="Keys">
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"max_budget": 10,
"budget_duration": "10s" # 👈 KEY CHANGE
}'
```
</TabItem>
<TabItem value="teams" label="Teams">
```bash
curl 'http://0.0.0.0:4000/team/new' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"max_budget": 10,
"budget_duration": "10s" # 👈 KEY CHANGE
}'
```
</TabItem>
</Tabs>
## Set Rate Limits ## Set Rate Limits
You can set: You can set:

View file

@ -790,84 +790,204 @@ If the error is a context window exceeded error, fall back to a larger model gro
Fallbacks are done in-order - ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc. Fallbacks are done in-order - ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad. You can also set `default_fallbacks`, in case a specific model group is misconfigured / bad.
There are 3 types of fallbacks:
- `content_policy_fallbacks`: For litellm.ContentPolicyViolationError - LiteLLM maps content policy violation errors across providers [**See Code**](https://github.com/BerriAI/litellm/blob/89a43c872a1e3084519fb9de159bf52f5447c6c4/litellm/utils.py#L8495C27-L8495C54)
- `context_window_fallbacks`: For litellm.ContextWindowExceededErrors - LiteLLM maps context window error messages across providers [**See Code**](https://github.com/BerriAI/litellm/blob/89a43c872a1e3084519fb9de159bf52f5447c6c4/litellm/utils.py#L8469)
- `fallbacks`: For all remaining errors - e.g. litellm.RateLimitError
**Content Policy Violation Fallback**
Key change:
```python
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}]
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
from litellm import Router from litellm import Router
router = Router(
model_list=[ model_list=[
{ # list of model deployments {
"model_name": "azure/gpt-3.5-turbo", # openai model name "model_name": "claude-2",
"litellm_params": { # params for litellm completion/embedding call "litellm_params": {
"model": "azure/chatgpt-v-2", "model": "claude-2",
"api_key": "bad-key", "api_key": "",
"api_version": os.getenv("AZURE_API_VERSION"), "mock_response": Exception("content filtering policy"),
"api_base": os.getenv("AZURE_API_BASE")
}, },
"tpm": 240000,
"rpm": 1800
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
}, },
{ {
"model_name": "azure/gpt-3.5-turbo", # openai model name "model_name": "my-fallback-model",
"litellm_params": { # params for litellm completion/embedding call "litellm_params": {
"model": "azure/chatgpt-functioncalling", "model": "claude-2",
"api_key": "bad-key", "api_key": "",
"api_version": os.getenv("AZURE_API_VERSION"), "mock_response": "This works!",
"api_base": os.getenv("AZURE_API_BASE")
}, },
"tpm": 240000,
"rpm": 1800
}, },
{ ],
"model_name": "gpt-3.5-turbo", # openai model name content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
"litellm_params": { # params for litellm completion/embedding call # fallbacks=[..], # [OPTIONAL]
"model": "gpt-3.5-turbo", # context_window_fallbacks=[..], # [OPTIONAL]
"api_key": os.getenv("OPENAI_API_KEY"), )
},
"tpm": 1000000,
"rpm": 9000
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
]
response = router.completion(
router = Router(model_list=model_list, model="claude-2",
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
default_fallbacks=["gpt-3.5-turbo-16k"], )
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
set_verbose=True)
user_message = "Hello, whats the weather in San Francisco??"
messages = [{"content": user_message, "role": "user"}]
# normal fallback call
response = router.completion(model="azure/gpt-3.5-turbo", messages=messages)
# context window fallback call
response = router.completion(model="azure/gpt-3.5-turbo-context-fallback", messages=messages)
print(f"response: {response}")
``` ```
</TabItem>
<TabItem value="proxy" label="PROXY">
In your proxy config.yaml just add this line 👇
```yaml
router_settings:
content_policy_fallbacks: [{"claude-2": ["my-fallback-model"]}]
```
Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
**Context Window Exceeded Fallback**
Key change:
```python
context_window_fallbacks=[{"claude-2": ["my-fallback-model"]}]
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import Router
router = Router(
model_list=[
{
"model_name": "claude-2",
"litellm_params": {
"model": "claude-2",
"api_key": "",
"mock_response": Exception("prompt is too long"),
},
},
{
"model_name": "my-fallback-model",
"litellm_params": {
"model": "claude-2",
"api_key": "",
"mock_response": "This works!",
},
},
],
context_window_fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
# fallbacks=[..], # [OPTIONAL]
# content_policy_fallbacks=[..], # [OPTIONAL]
)
response = router.completion(
model="claude-2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
In your proxy config.yaml just add this line 👇
```yaml
router_settings:
context_window_fallbacks: [{"claude-2": ["my-fallback-model"]}]
```
Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
**Regular Fallbacks**
Key change:
```python
fallbacks=[{"claude-2": ["my-fallback-model"]}]
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import Router
router = Router(
model_list=[
{
"model_name": "claude-2",
"litellm_params": {
"model": "claude-2",
"api_key": "",
"mock_response": Exception("this is a rate limit error"),
},
},
{
"model_name": "my-fallback-model",
"litellm_params": {
"model": "claude-2",
"api_key": "",
"mock_response": "This works!",
},
},
],
fallbacks=[{"claude-2": ["my-fallback-model"]}], # 👈 KEY CHANGE
# context_window_fallbacks=[..], # [OPTIONAL]
# content_policy_fallbacks=[..], # [OPTIONAL]
)
response = router.completion(
model="claude-2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
In your proxy config.yaml just add this line 👇
```yaml
router_settings:
fallbacks: [{"claude-2": ["my-fallback-model"]}]
```
Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### Caching ### Caching

View file

@ -23,9 +23,13 @@ https://api.together.xyz/playground/chat?model=togethercomputer%2Fllama-2-70b-ch
model_name = "together_ai/togethercomputer/llama-2-70b-chat" model_name = "together_ai/togethercomputer/llama-2-70b-chat"
response = completion(model=model_name, messages=messages) response = completion(model=model_name, messages=messages)
print(response) print(response)
```
``` ```
{'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'role': 'assistant', 'content': "\n\nI'm not able to provide real-time weather information. However, I can suggest"}}], 'created': 1691629657.9288375, 'model': 'togethercomputer/llama-2-70b-chat', 'usage': {'prompt_tokens': 9, 'completion_tokens': 17, 'total_tokens': 26}} {'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'role': 'assistant', 'content': "\n\nI'm not able to provide real-time weather information. However, I can suggest"}}], 'created': 1691629657.9288375, 'model': 'togethercomputer/llama-2-70b-chat', 'usage': {'prompt_tokens': 9, 'completion_tokens': 17, 'total_tokens': 26}}
```
LiteLLM handles the prompt formatting for Together AI's Llama2 models as well, converting your message to the LiteLLM handles the prompt formatting for Together AI's Llama2 models as well, converting your message to the

View file

@ -138,8 +138,8 @@ const config = {
title: 'Docs', title: 'Docs',
items: [ items: [
{ {
label: 'Tutorial', label: 'Getting Started',
to: '/docs/index', to: 'https://docs.litellm.ai/docs/',
}, },
], ],
}, },

Binary file not shown.

After

Width:  |  Height:  |  Size: 207 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 151 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 241 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 279 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 200 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 168 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

File diff suppressed because it is too large Load diff

View file

@ -23,8 +23,8 @@
"docusaurus": "^1.14.7", "docusaurus": "^1.14.7",
"docusaurus-lunr-search": "^2.4.1", "docusaurus-lunr-search": "^2.4.1",
"prism-react-renderer": "^1.3.5", "prism-react-renderer": "^1.3.5",
"react": "^17.0.2", "react": "^18.1.0",
"react-dom": "^17.0.2", "react-dom": "^18.1.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"uuid": "^9.0.1" "uuid": "^9.0.1"
}, },

View file

@ -44,6 +44,7 @@ const sidebars = {
"proxy/self_serve", "proxy/self_serve",
"proxy/users", "proxy/users",
"proxy/customers", "proxy/customers",
"proxy/team_budgets",
"proxy/billing", "proxy/billing",
"proxy/user_keys", "proxy/user_keys",
"proxy/virtual_keys", "proxy/virtual_keys",
@ -54,6 +55,7 @@ const sidebars = {
items: ["proxy/logging", "proxy/streaming_logging"], items: ["proxy/logging", "proxy/streaming_logging"],
}, },
"proxy/ui", "proxy/ui",
"proxy/prometheus",
"proxy/email", "proxy/email",
"proxy/multiple_admins", "proxy/multiple_admins",
"proxy/team_based_routing", "proxy/team_based_routing",
@ -70,7 +72,6 @@ const sidebars = {
"proxy/pii_masking", "proxy/pii_masking",
"proxy/prompt_injection", "proxy/prompt_injection",
"proxy/caching", "proxy/caching",
"proxy/prometheus",
"proxy/call_hooks", "proxy/call_hooks",
"proxy/rules", "proxy/rules",
"proxy/cli", "proxy/cli",
@ -133,10 +134,11 @@ const sidebars = {
"providers/vertex", "providers/vertex",
"providers/palm", "providers/palm",
"providers/gemini", "providers/gemini",
"providers/mistral",
"providers/anthropic", "providers/anthropic",
"providers/aws_sagemaker", "providers/aws_sagemaker",
"providers/bedrock", "providers/bedrock",
"providers/mistral",
"providers/codestral",
"providers/cohere", "providers/cohere",
"providers/anyscale", "providers/anyscale",
"providers/huggingface", "providers/huggingface",
@ -170,10 +172,8 @@ const sidebars = {
"proxy/custom_pricing", "proxy/custom_pricing",
"routing", "routing",
"scheduler", "scheduler",
"rules",
"set_keys", "set_keys",
"budget_manager", "budget_manager",
"contributing",
"secret", "secret",
"completion/token_usage", "completion/token_usage",
"load_test", "load_test",
@ -181,10 +181,11 @@ const sidebars = {
type: "category", type: "category",
label: "Logging & Observability", label: "Logging & Observability",
items: [ items: [
"debugging/local_debugging",
"observability/callbacks",
"observability/custom_callback",
"observability/langfuse_integration", "observability/langfuse_integration",
"observability/logfire_integration",
"debugging/local_debugging",
"observability/raw_request_response",
"observability/custom_callback",
"observability/sentry", "observability/sentry",
"observability/lago", "observability/lago",
"observability/openmeter", "observability/openmeter",
@ -222,14 +223,16 @@ const sidebars = {
}, },
{ {
type: "category", type: "category",
label: "LangChain, LlamaIndex Integration", label: "LangChain, LlamaIndex, Instructor Integration",
items: ["langchain/langchain"], items: ["langchain/langchain", "tutorials/instructor"],
}, },
{ {
type: "category", type: "category",
label: "Extras", label: "Extras",
items: [ items: [
"extras/contributing", "extras/contributing",
"contributing",
"rules",
"proxy_server", "proxy_server",
{ {
type: "category", type: "category",

File diff suppressed because it is too large Load diff

View file

@ -93,7 +93,7 @@ class _ENTERPRISE_BannedKeywords(CustomLogger):
response.choices[0], litellm.utils.Choices response.choices[0], litellm.utils.Choices
): ):
for word in self.banned_keywords_list: for word in self.banned_keywords_list:
self.test_violation(test_str=response.choices[0].message.content) self.test_violation(test_str=response.choices[0].message.content or "")
async def async_post_call_streaming_hook( async def async_post_call_streaming_hook(
self, self,

View file

@ -122,236 +122,6 @@ async def ui_get_spend_by_tags(
return {"spend_per_tag": ui_tags} return {"spend_per_tag": ui_tags}
async def view_spend_logs_from_clickhouse(
api_key=None, user_id=None, request_id=None, start_date=None, end_date=None
):
verbose_logger.debug("Reading logs from Clickhouse")
import os
# if user has setup clickhouse
# TODO: Move this to be a helper function
# querying clickhouse for this data
import clickhouse_connect
from datetime import datetime
port = os.getenv("CLICKHOUSE_PORT")
if port is not None and isinstance(port, str):
port = int(port)
client = clickhouse_connect.get_client(
host=os.getenv("CLICKHOUSE_HOST"),
port=port,
username=os.getenv("CLICKHOUSE_USERNAME", ""),
password=os.getenv("CLICKHOUSE_PASSWORD", ""),
)
if (
start_date is not None
and isinstance(start_date, str)
and end_date is not None
and isinstance(end_date, str)
):
# Convert the date strings to datetime objects
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
# get top spend per day
response = client.query(
f"""
SELECT
toDate(startTime) AS day,
sum(spend) AS total_spend
FROM
spend_logs
WHERE
toDate(startTime) BETWEEN toDate('2024-02-01') AND toDate('2024-02-29')
GROUP BY
day
ORDER BY
total_spend
"""
)
results = []
result_rows = list(response.result_rows)
for response in result_rows:
current_row = {}
current_row["users"] = {"example": 0.0}
current_row["models"] = {}
current_row["spend"] = float(response[1])
current_row["startTime"] = str(response[0])
# stubbed api_key
current_row[""] = 0.0 # type: ignore
results.append(current_row)
return results
else:
# check if spend logs exist, if it does then return last 10 logs, sorted in descending order of startTime
response = client.query(
"""
SELECT
*
FROM
default.spend_logs
ORDER BY
startTime DESC
LIMIT
10
"""
)
# get size of spend logs
num_rows = client.query("SELECT count(*) FROM default.spend_logs")
num_rows = num_rows.result_rows[0][0]
# safely access num_rows.result_rows[0][0]
if num_rows is None:
num_rows = 0
raw_rows = list(response.result_rows)
response_data = {
"logs": raw_rows,
"log_count": num_rows,
}
return response_data
def _create_clickhouse_material_views(client=None, table_names=[]):
# Create Materialized Views if they don't exist
# Materialized Views send new inserted rows to the aggregate tables
verbose_logger.debug("Clickhouse: Creating Materialized Views")
if "daily_aggregated_spend_per_model_mv" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_model_mv")
client.command(
"""
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_model_mv
TO daily_aggregated_spend_per_model
AS
SELECT
toDate(startTime) as day,
sumState(spend) AS DailySpend,
model as model
FROM spend_logs
GROUP BY
day, model
"""
)
if "daily_aggregated_spend_per_api_key_mv" not in table_names:
verbose_logger.debug(
"Clickhouse: Creating daily_aggregated_spend_per_api_key_mv"
)
client.command(
"""
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_api_key_mv
TO daily_aggregated_spend_per_api_key
AS
SELECT
toDate(startTime) as day,
sumState(spend) AS DailySpend,
api_key as api_key
FROM spend_logs
GROUP BY
day, api_key
"""
)
if "daily_aggregated_spend_per_user_mv" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_user_mv")
client.command(
"""
CREATE MATERIALIZED VIEW daily_aggregated_spend_per_user_mv
TO daily_aggregated_spend_per_user
AS
SELECT
toDate(startTime) as day,
sumState(spend) AS DailySpend,
user as user
FROM spend_logs
GROUP BY
day, user
"""
)
if "daily_aggregated_spend_mv" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_mv")
client.command(
"""
CREATE MATERIALIZED VIEW daily_aggregated_spend_mv
TO daily_aggregated_spend
AS
SELECT
toDate(startTime) as day,
sumState(spend) AS DailySpend
FROM spend_logs
GROUP BY
day
"""
)
def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
# Basic Logging works without this - this is only used for low latency reporting apis
verbose_logger.debug("Clickhouse: Creating Aggregate Tables")
# Create Aggregeate Tables if they don't exist
if "daily_aggregated_spend_per_model" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_model")
client.command(
"""
CREATE TABLE daily_aggregated_spend_per_model
(
`day` Date,
`DailySpend` AggregateFunction(sum, Float64),
`model` String
)
ENGINE = SummingMergeTree()
ORDER BY (day, model);
"""
)
if "daily_aggregated_spend_per_api_key" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_api_key")
client.command(
"""
CREATE TABLE daily_aggregated_spend_per_api_key
(
`day` Date,
`DailySpend` AggregateFunction(sum, Float64),
`api_key` String
)
ENGINE = SummingMergeTree()
ORDER BY (day, api_key);
"""
)
if "daily_aggregated_spend_per_user" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend_per_user")
client.command(
"""
CREATE TABLE daily_aggregated_spend_per_user
(
`day` Date,
`DailySpend` AggregateFunction(sum, Float64),
`user` String
)
ENGINE = SummingMergeTree()
ORDER BY (day, user);
"""
)
if "daily_aggregated_spend" not in table_names:
verbose_logger.debug("Clickhouse: Creating daily_aggregated_spend")
client.command(
"""
CREATE TABLE daily_aggregated_spend
(
`day` Date,
`DailySpend` AggregateFunction(sum, Float64),
)
ENGINE = SummingMergeTree()
ORDER BY (day);
"""
)
return
def _forecast_daily_cost(data: list): def _forecast_daily_cost(data: list):
import requests # type: ignore import requests # type: ignore
from datetime import datetime, timedelta from datetime import datetime, timedelta

View file

@ -13,7 +13,10 @@ from litellm._logging import (
verbose_logger, verbose_logger,
json_logs, json_logs,
_turn_on_json, _turn_on_json,
log_level,
) )
from litellm.proxy._types import ( from litellm.proxy._types import (
KeyManagementSystem, KeyManagementSystem,
KeyManagementSettings, KeyManagementSettings,
@ -34,7 +37,7 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = []
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"] _custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[ _langfuse_default_tags: Optional[
List[ List[
@ -60,6 +63,7 @@ _async_failure_callback: List[Callable] = (
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False turn_off_message_logging: Optional[bool] = False
log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False redact_messages_in_exceptions: Optional[bool] = False
store_audit_logs = False # Enterprise feature, allow users to see audit logs store_audit_logs = False # Enterprise feature, allow users to see audit logs
## end of callbacks ############# ## end of callbacks #############
@ -72,7 +76,7 @@ token: Optional[str] = (
) )
telemetry = True telemetry = True
max_tokens = 256 # OpenAI Defaults max_tokens = 256 # OpenAI Defaults
drop_params = False drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
modify_params = False modify_params = False
retry = True retry = True
### AUTH ### ### AUTH ###
@ -239,6 +243,7 @@ num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 0 allowed_fails: int = 0
num_retries_per_request: Optional[int] = ( num_retries_per_request: Optional[int] = (
None # for the request overall (incl. fallbacks + model retries) None # for the request overall (incl. fallbacks + model retries)
@ -336,6 +341,7 @@ bedrock_models: List = []
deepinfra_models: List = [] deepinfra_models: List = []
perplexity_models: List = [] perplexity_models: List = []
watsonx_models: List = [] watsonx_models: List = []
gemini_models: List = []
for key, value in model_cost.items(): for key, value in model_cost.items():
if value.get("litellm_provider") == "openai": if value.get("litellm_provider") == "openai":
open_ai_chat_completion_models.append(key) open_ai_chat_completion_models.append(key)
@ -382,13 +388,16 @@ for key, value in model_cost.items():
perplexity_models.append(key) perplexity_models.append(key)
elif value.get("litellm_provider") == "watsonx": elif value.get("litellm_provider") == "watsonx":
watsonx_models.append(key) watsonx_models.append(key)
elif value.get("litellm_provider") == "gemini":
gemini_models.append(key)
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [ openai_compatible_endpoints: List = [
"api.perplexity.ai", "api.perplexity.ai",
"api.endpoints.anyscale.com/v1", "api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai", "api.deepinfra.com/v1/openai",
"api.mistral.ai/v1", "api.mistral.ai/v1",
"codestral.mistral.ai/v1/chat/completions",
"codestral.mistral.ai/v1/fim/completions",
"api.groq.com/openai/v1", "api.groq.com/openai/v1",
"api.deepseek.com/v1", "api.deepseek.com/v1",
"api.together.xyz/v1", "api.together.xyz/v1",
@ -399,12 +408,14 @@ openai_compatible_providers: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"codestral",
"deepseek", "deepseek",
"deepinfra", "deepinfra",
"perplexity", "perplexity",
"xinference", "xinference",
"together_ai", "together_ai",
"fireworks_ai", "fireworks_ai",
"azure_ai",
] ]
@ -588,6 +599,7 @@ model_list = (
+ maritalk_models + maritalk_models
+ vertex_language_models + vertex_language_models
+ watsonx_models + watsonx_models
+ gemini_models
) )
provider_list: List = [ provider_list: List = [
@ -603,12 +615,14 @@ provider_list: List = [
"together_ai", "together_ai",
"openrouter", "openrouter",
"vertex_ai", "vertex_ai",
"vertex_ai_beta",
"palm", "palm",
"gemini", "gemini",
"ai21", "ai21",
"baseten", "baseten",
"azure", "azure",
"azure_text", "azure_text",
"azure_ai",
"sagemaker", "sagemaker",
"bedrock", "bedrock",
"vllm", "vllm",
@ -622,6 +636,8 @@ provider_list: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"codestral",
"text-completion-codestral",
"deepseek", "deepseek",
"maritalk", "maritalk",
"voyage", "voyage",
@ -658,6 +674,7 @@ models_by_provider: dict = {
"perplexity": perplexity_models, "perplexity": perplexity_models,
"maritalk": maritalk_models, "maritalk": maritalk_models,
"watsonx": watsonx_models, "watsonx": watsonx_models,
"gemini": gemini_models,
} }
# mapping for those models which have larger equivalents # mapping for those models which have larger equivalents
@ -710,6 +727,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout from .timeout import timeout
from .cost_calculator import completion_cost from .cost_calculator import completion_cost
from litellm.litellm_core_utils.litellm_logging import Logging
from .utils import ( from .utils import (
client, client,
exception_type, exception_type,
@ -718,12 +736,11 @@ from .utils import (
token_counter, token_counter,
create_pretrained_tokenizer, create_pretrained_tokenizer,
create_tokenizer, create_tokenizer,
cost_per_token,
supports_function_calling, supports_function_calling,
supports_parallel_function_calling, supports_parallel_function_calling,
supports_vision, supports_vision,
supports_system_messages,
get_litellm_params, get_litellm_params,
Logging,
acreate, acreate,
get_model_list, get_model_list,
get_max_tokens, get_max_tokens,
@ -743,9 +760,10 @@ from .utils import (
get_first_chars_messages, get_first_chars_messages,
ModelResponse, ModelResponse,
ImageResponse, ImageResponse,
ImageObject,
get_provider_fields, get_provider_fields,
) )
from .types.utils import ImageObject
from .llms.huggingface_restapi import HuggingfaceConfig from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@ -762,7 +780,8 @@ from .llms.gemini import GeminiConfig
from .llms.nlp_cloud import NLPCloudConfig from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig from .llms.vertex_httpx import VertexGeminiConfig
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
@ -784,8 +803,11 @@ from .llms.openai import (
OpenAIConfig, OpenAIConfig,
OpenAITextCompletionConfig, OpenAITextCompletionConfig,
MistralConfig, MistralConfig,
MistralEmbeddingConfig,
DeepInfraConfig, DeepInfraConfig,
AzureAIStudioConfig,
) )
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.azure import ( from .llms.azure import (
AzureOpenAIConfig, AzureOpenAIConfig,
AzureOpenAIError, AzureOpenAIError,
@ -819,4 +841,4 @@ from .router import Router
from .assistants.main import * from .assistants.main import *
from .batches.main import * from .batches.main import *
from .scheduler import * from .scheduler import *
from .cost_calculator import response_cost_calculator from .cost_calculator import response_cost_calculator, cost_per_token
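Editor's note: the provider additions above (codestral, text-completion-codestral, vertex_ai_beta, azure_ai) and the gemini model registration mean these backends are addressed with provider-prefixed model names. A minimal sketch of such a call follows; the model name and credential env var are assumptions about your account, not something this diff guarantees.

import os
import litellm

# Placeholder credential; the env var name follows litellm's <PROVIDER>_API_KEY
# convention and is an assumption here, not confirmed by this diff.
os.environ.setdefault("CODESTRAL_API_KEY", "sk-...")

response = litellm.completion(
    model="codestral/codestral-latest",  # provider-prefixed model name (illustrative)
    messages=[{"role": "user", "content": "Write a one-line docstring for a sort function."}],
)
print(response.choices[0].message.content)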

View file

@ -1,12 +1,21 @@
import logging, os, json import json
from logging import Formatter import logging
import os
import traceback import traceback
from logging import Formatter
set_verbose = False set_verbose = False
if set_verbose is True:
logging.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
)
json_logs = bool(os.getenv("JSON_LOGS", False)) json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs) # Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: int = getattr(logging, log_level.upper())
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG) handler.setLevel(numeric_level)
class JsonFormatter(Formatter): class JsonFormatter(Formatter):
@ -14,8 +23,12 @@ class JsonFormatter(Formatter):
super(JsonFormatter, self).__init__() super(JsonFormatter, self).__init__()
def format(self, record): def format(self, record):
json_record = {} json_record = {
json_record["message"] = record.getMessage() "message": record.getMessage(),
"level": record.levelname,
"timestamp": self.formatTime(record, self.datefmt),
}
return json.dumps(json_record) return json.dumps(json_record)
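Editor's note: a standalone sketch of how the LITELLM_LOG env var and the JSON formatter above interact. It mirrors the hunk but is its own tiny script, so the logger name is illustrative rather than the module's internals.

import json
import logging
import os

log_level = os.getenv("LITELLM_LOG", "DEBUG")        # e.g. "INFO", "ERROR"
numeric_level = getattr(logging, log_level.upper())  # logging.DEBUG == 10, logging.INFO == 20, ...

class JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        # One JSON object per log line: message, level and timestamp.
        return json.dumps(
            {
                "message": record.getMessage(),
                "level": record.levelname,
                "timestamp": self.formatTime(record, self.datefmt),
            }
        )

handler = logging.StreamHandler()
handler.setLevel(numeric_level)
handler.setFormatter(JsonFormatter())

logger = logging.getLogger("example")
logger.setLevel(numeric_level)
logger.addHandler(handler)
logger.info("hello")  # -> {"message": "hello", "level": "INFO", "timestamp": "..."}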

View file

@ -1192,7 +1192,7 @@ class S3Cache(BaseCache):
return cached_response return cached_response
except botocore.exceptions.ClientError as e: except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey": if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.error( verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket." f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
) )
return None return None

View file

@ -1,21 +1,252 @@
# What is this? # What is this?
## File for 'response_cost' calculation in Logging ## File for 'response_cost' calculation in Logging
from typing import Optional, Union, Literal, List import time
from typing import List, Literal, Optional, Tuple, Union
import litellm
import litellm._logging import litellm._logging
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.utils import ( from litellm.utils import (
ModelResponse, CallTypes,
CostPerToken,
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
TranscriptionResponse, ModelResponse,
TextCompletionResponse, TextCompletionResponse,
CallTypes, TranscriptionResponse,
cost_per_token,
print_verbose, print_verbose,
CostPerToken,
token_counter, token_counter,
) )
import litellm
from litellm import verbose_logger
def _cost_per_token_custom_pricing_helper(
prompt_tokens=0,
completion_tokens=0,
response_time_ms=None,
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
custom_cost_per_second: Optional[float] = None,
) -> Optional[Tuple[float, float]]:
"""Internal helper function for calculating cost, if custom pricing given"""
if custom_cost_per_token is None and custom_cost_per_second is None:
return None
if custom_cost_per_token is not None:
input_cost = custom_cost_per_token["input_cost_per_token"] * prompt_tokens
output_cost = custom_cost_per_token["output_cost_per_token"] * completion_tokens
return input_cost, output_cost
elif custom_cost_per_second is not None:
output_cost = custom_cost_per_second * response_time_ms / 1000 # type: ignore
return 0, output_cost
return None
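Editor's note: a worked example of the custom_cost_per_second branch above, since it is the less obvious one: the whole call is billed as output cost. The price is illustrative.

response_time_ms = 2_500          # a 2.5 second call
custom_cost_per_second = 0.01     # illustrative price: $0.01 per second
output_cost = custom_cost_per_second * response_time_ms / 1000
print((0.0, output_cost))         # (0.0, 0.025), mirroring the helper's (input, output) return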
def cost_per_token(
model: str = "",
prompt_tokens: float = 0,
completion_tokens: float = 0,
response_time_ms=None,
custom_llm_provider: Optional[str] = None,
region_name=None,
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
custom_cost_per_second: Optional[float] = None,
) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Parameters:
model (str): The name of the model to use. Default is ""
prompt_tokens (int): The number of tokens in the prompt.
completion_tokens (int): The number of tokens in the completion.
response_time (float): The amount of time, in milliseconds, it took the call to complete.
custom_llm_provider (str): The llm provider to whom the call was made (see init.py for full list)
custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
Returns:
tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
"""
args = locals()
if model is None:
raise Exception("Invalid arg. Model cannot be none.")
## CUSTOM PRICING ##
response_cost = _cost_per_token_custom_pricing_helper(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
response_time_ms=response_time_ms,
custom_cost_per_second=custom_cost_per_second,
custom_cost_per_token=custom_cost_per_token,
)
if response_cost is not None:
return response_cost[0], response_cost[1]
# given
prompt_tokens_cost_usd_dollar: float = 0
completion_tokens_cost_usd_dollar: float = 0
model_cost_ref = litellm.model_cost
model_with_provider = model
if custom_llm_provider is not None:
model_with_provider = custom_llm_provider + "/" + model
if region_name is not None:
model_with_provider_and_region = (
f"{custom_llm_provider}/{region_name}/{model}"
)
if (
model_with_provider_and_region in model_cost_ref
): # use region based pricing, if it's available
model_with_provider = model_with_provider_and_region
else:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
model_without_prefix = model
model_parts = model.split("/")
if len(model_parts) > 1:
model_without_prefix = model_parts[1]
else:
model_without_prefix = model
"""
    Code block that normalizes the model name before looking it up in litellm.model_cost
    Option1. model = "bedrock/ap-northeast-1/anthropic.claude-instant-v1". This is the most accurate, since it is region based, and should always be preferred.
Option2. model = "openai/gpt-4" - model = provider/model
Option3. model = "anthropic.claude-3" - model = model
"""
if (
model_with_provider in model_cost_ref
): # Option 2. use model with provider, model = "openai/gpt-4"
model = model_with_provider
elif model in model_cost_ref: # Option 1. use model passed, model="gpt-4"
model = model
elif (
model_without_prefix in model_cost_ref
): # Option 3. if user passed model="bedrock/anthropic.claude-3", use model="anthropic.claude-3"
model = model_without_prefix
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
if custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
elif model in model_cost_ref:
print_verbose(f"Success: model={model} in model_cost_map")
print_verbose(
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
)
if (
model_cost_ref[model].get("input_cost_per_token", None) is not None
and model_cost_ref[model].get("output_cost_per_token", None) is not None
):
## COST PER TOKEN ##
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
elif (
model_cost_ref[model].get("output_cost_per_second", None) is not None
and response_time_ms is not None
):
print_verbose(
f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}"
)
## COST PER SECOND ##
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_second"]
* response_time_ms
/ 1000
)
elif (
model_cost_ref[model].get("input_cost_per_second", None) is not None
and response_time_ms is not None
):
print_verbose(
f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
)
## COST PER SECOND ##
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
)
completion_tokens_cost_usd_dollar = 0.0
print_verbose(
f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:davinci-002" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:davinci-002:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:davinci-002"]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref["ft:davinci-002"]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:babbage-002" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:babbage-002:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:babbage-002"]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref["ft:babbage-002"]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_llms:
verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
model = litellm.azure_llms[model]
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_embedding_models:
verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
model = litellm.azure_embedding_models[model]
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
else:
        # If the model is not in model_prices_and_context_window.json, raise an exception so users know to register pricing for it
error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}. Register pricing for model - https://docs.litellm.ai/docs/proxy/custom_pricing\n"
raise litellm.exceptions.NotFoundError( # type: ignore
message=error_str,
model=model,
llm_provider="",
)
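Editor's note: a hedged usage sketch for cost_per_token() as defined above. The first call relies on the bundled model_cost map; the second shows custom pricing short-circuiting the lookup, so the model name there is a made-up deployment name and the per-token prices are illustrative.

import litellm

# Standard lookup against litellm.model_cost
prompt_cost, completion_cost = litellm.cost_per_token(
    model="gpt-3.5-turbo",
    prompt_tokens=1200,
    completion_tokens=300,
)
print(prompt_cost, completion_cost)

# Custom pricing returns before any model_cost lookup happens
prompt_cost, completion_cost = litellm.cost_per_token(
    model="my-internal-model",  # hypothetical deployment name
    prompt_tokens=1000,
    completion_tokens=500,
    custom_cost_per_token={
        "input_cost_per_token": 0.0000005,
        "output_cost_per_token": 0.0000015,
    },
)
print(prompt_cost, completion_cost)  # 0.0005 0.00075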
# Extract the number of billion parameters from the model name # Extract the number of billion parameters from the model name
@ -337,8 +568,6 @@ def response_cost_calculator(
and custom_llm_provider is True and custom_llm_provider is True
): # override defaults if custom pricing is set ): # override defaults if custom pricing is set
base_model = model base_model = model
elif base_model is None:
base_model = model
# base_model defaults to None if not set on model_info # base_model defaults to None if not set on model_info
response_cost = completion_cost( response_cost = completion_cost(
completion_response=response_object, completion_response=response_object,

View file

@ -26,7 +26,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 401 self.status_code = 401
self.message = message self.message = "litellm.AuthenticationError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -72,7 +72,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 404 self.status_code = 404
self.message = message self.message = "litellm.NotFoundError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -117,7 +117,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = message self.message = "litellm.BadRequestError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -162,7 +162,7 @@ class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 422 self.status_code = 422
self.message = message self.message = "litellm.UnprocessableEntityError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -204,7 +204,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
request=request request=request
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
self.status_code = 408 self.status_code = 408
self.message = message self.message = "litellm.Timeout: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -241,7 +241,7 @@ class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 403 self.status_code = 403
self.message = message self.message = "litellm.PermissionDeniedError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -280,7 +280,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 429 self.status_code = 429
self.message = message self.message = "litellm.RateLimitError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -324,19 +324,22 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
message, message,
model, model,
llm_provider, llm_provider,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = message self.message = "litellm.ContextWindowExceededError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
self.response = response or httpx.Response(status_code=400, request=request)
super().__init__( super().__init__(
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore llm_provider=self.llm_provider, # type: ignore
response=response, response=self.response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -367,7 +370,7 @@ class RejectedRequestError(BadRequestError): # type: ignore
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = message self.message = "litellm.RejectedRequestError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -379,6 +382,7 @@ class RejectedRequestError(BadRequestError): # type: ignore
model=self.model, # type: ignore model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore llm_provider=self.llm_provider, # type: ignore
response=response, response=response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -405,19 +409,22 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
message, message,
model, model,
llm_provider, llm_provider,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = message self.message = "litellm.ContentPolicyViolationError: {}".format(message)
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
self.response = response or httpx.Response(status_code=500, request=request)
super().__init__( super().__init__(
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore llm_provider=self.llm_provider, # type: ignore
response=response, response=self.response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -449,7 +456,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 503 self.status_code = 503
self.message = message self.message = "litellm.ServiceUnavailableError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -498,7 +505,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = 500 self.status_code = 500
self.message = message self.message = "litellm.InternalServerError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -549,7 +556,7 @@ class APIError(openai.APIError): # type: ignore
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = "litellm.APIError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
@ -586,7 +593,7 @@ class APIConnectionError(openai.APIConnectionError): # type: ignore
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.message = message self.message = "litellm.APIConnectionError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.status_code = 500 self.status_code = 500
@ -623,7 +630,7 @@ class APIResponseValidationError(openai.APIResponseValidationError): # type: ig
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
): ):
self.message = message self.message = "litellm.APIResponseValidationError: {}".format(message)
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
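Editor's note: a short sketch of what the exception changes above mean for callers. Messages now carry a litellm exception-name prefix, and ContextWindowExceededError / ContentPolicyViolationError can be raised without an httpx.Response because a default one is synthesized. Values below are illustrative.

import litellm

err = litellm.ContextWindowExceededError(
    message="prompt exceeds the model's context window",
    model="gpt-4o",
    llm_provider="openai",
)
print(err.message)               # now includes the "litellm.ContextWindowExceededError: ..." prefix
print(err.status_code)           # 400
print(err.response.status_code)  # 400, from the synthesized default response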

View file

@ -226,14 +226,6 @@ def _start_clickhouse():
response = client.query("DESCRIBE default.spend_logs") response = client.query("DESCRIBE default.spend_logs")
verbose_logger.debug(f"spend logs schema ={response.result_rows}") verbose_logger.debug(f"spend logs schema ={response.result_rows}")
# RUN Enterprise Clickhouse Setup
# TLDR: For Enterprise - we create views / aggregate tables for low latency reporting APIs
from litellm.proxy.enterprise.utils import _create_clickhouse_aggregate_tables
from litellm.proxy.enterprise.utils import _create_clickhouse_material_views
_create_clickhouse_aggregate_tables(client=client, table_names=table_names)
_create_clickhouse_material_views(client=client, table_names=table_names)
class ClickhouseLogger: class ClickhouseLogger:
# Class variables or attributes # Class variables or attributes

View file

@ -10,7 +10,7 @@ import traceback
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self) -> None:
pass pass
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):

View file

@ -0,0 +1,136 @@
"""
Functions for sending Email Alerts
"""
import os
from typing import Optional, List
from litellm.proxy._types import WebhookEvent
import asyncio
from litellm._logging import verbose_logger, verbose_proxy_logger
# Used for the email header. If you change this, send a test email and verify it renders correctly.
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
LITELLM_SUPPORT_CONTACT = "support@berri.ai"
async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
verbose_logger.debug(
"Email Alerting: Getting all team members for team_id=%s", team_id
)
if team_id is None:
return []
from litellm.proxy.proxy_server import premium_user, prisma_client
if prisma_client is None:
raise Exception("Not connected to DB!")
team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={
"team_id": team_id,
}
)
if team_row is None:
return []
_team_members = team_row.members_with_roles
verbose_logger.debug(
"Email Alerting: Got team members for team_id=%s Team Members: %s",
team_id,
_team_members,
)
_team_member_user_ids: List[str] = []
for member in _team_members:
if member and isinstance(member, dict) and member.get("user_id") is not None:
_team_member_user_ids.append(member.get("user_id"))
sql_query = """
SELECT user_email
FROM "LiteLLM_UserTable"
WHERE user_id = ANY($1::TEXT[]);
"""
_result = await prisma_client.db.query_raw(sql_query, _team_member_user_ids)
verbose_logger.debug("Email Alerting: Got all Emails for team, emails=%s", _result)
if _result is None:
return []
emails = []
for user in _result:
if user and isinstance(user, dict) and user.get("user_email", None) is not None:
emails.append(user.get("user_email"))
return emails
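Editor's note: a standalone sketch of the filtering get_all_team_member_emails() performs, using hypothetical rows instead of Prisma results: user_ids are pulled out of members_with_roles, then only non-null user_emails from the LiteLLM_UserTable query are kept.

# Hypothetical team members, shaped like members_with_roles above
members_with_roles = [
    {"role": "admin", "user_id": "u1"},
    {"role": "user", "user_id": None},   # skipped: no user_id
    {"role": "user", "user_id": "u3"},
]
team_member_user_ids = [
    m["user_id"]
    for m in members_with_roles
    if isinstance(m, dict) and m.get("user_id") is not None
]

# Hypothetical result of the LiteLLM_UserTable query
user_rows = [
    {"user_email": "a@example.com"},
    {"user_email": None},                # skipped: no email on record
]
emails = [
    r["user_email"]
    for r in user_rows
    if isinstance(r, dict) and r.get("user_email") is not None
]
print(team_member_user_ids, emails)      # ['u1', 'u3'] ['a@example.com']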
async def send_team_budget_alert(webhook_event: WebhookEvent) -> bool:
"""
Send an Email Alert to All Team Members when the Team Budget is crossed
Returns -> True if sent, False if not.
"""
from litellm.proxy.utils import send_email
from litellm.proxy.proxy_server import premium_user, prisma_client
_team_id = webhook_event.team_id
team_alias = webhook_event.team_alias
verbose_logger.debug(
"Email Alerting: Sending Team Budget Alert for team=%s", team_alias
)
email_logo_url = os.getenv("SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None))
email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
# await self._check_if_using_premium_email_feature(
# premium_user, email_logo_url, email_support_contact
# )
if email_logo_url is None:
email_logo_url = LITELLM_LOGO_URL
if email_support_contact is None:
email_support_contact = LITELLM_SUPPORT_CONTACT
recipient_emails = await get_all_team_member_emails(_team_id)
recipient_emails_str: str = ",".join(recipient_emails)
verbose_logger.debug(
"Email Alerting: Sending team budget alert to %s", recipient_emails_str
)
event_name = webhook_event.event_message
max_budget = webhook_event.max_budget
email_html_content = "Alert from LiteLLM Server"
    if not recipient_emails_str:
verbose_proxy_logger.error(
"Email Alerting: Trying to send email alert to no recipient, got recipient_emails=%s",
recipient_emails_str,
)
email_html_content = f"""
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" /> <br/><br/><br/>
Budget Crossed for Team <b> {team_alias} </b> <br/> <br/>
    Your Team's LLM API usage has crossed its <b>budget of ${max_budget}</b>; current spend is <b>${webhook_event.spend}</b><br /> <br />
API requests will be rejected until either (a) you increase your budget or (b) your budget gets reset <br /> <br />
If you have any questions, please send an email to {email_support_contact} <br /> <br />
Best, <br />
The LiteLLM team <br />
"""
email_event = {
"to": recipient_emails_str,
"subject": f"LiteLLM {event_name} for Team {team_alias}",
"html": email_html_content,
}
await send_email(
receiver_email=email_event["to"],
subject=email_event["subject"],
html=email_event["html"],
)
    return True

View file

@ -1,13 +1,19 @@
# What is this? # What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639 ## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import dotenv, os, json import json
import os
import traceback
import uuid
from typing import Literal, Optional
import dotenv
import httpx
import litellm import litellm
import traceback, httpx from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
from typing import Optional, Literal
def get_utc_datetime(): def get_utc_datetime():
@ -143,6 +149,7 @@ class LagoLogger(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
verbose_logger.debug("ENTERS LAGO CALLBACK")
_url = os.getenv("LAGO_API_BASE") _url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance( assert _url is not None and isinstance(
_url, str _url, str

View file

@ -1,11 +1,13 @@
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import os
import copy import copy
import os
import traceback import traceback
from packaging.version import Version from packaging.version import Version
from litellm._logging import verbose_logger
import litellm import litellm
from litellm._logging import verbose_logger
class LangFuseLogger: class LangFuseLogger:
@ -14,8 +16,8 @@ class LangFuseLogger:
self, langfuse_public_key=None, langfuse_secret=None, flush_interval=1 self, langfuse_public_key=None, langfuse_secret=None, flush_interval=1
): ):
try: try:
from langfuse import Langfuse
import langfuse import langfuse
from langfuse import Langfuse
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m" f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
@ -167,7 +169,7 @@ class LangFuseLogger:
or isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(response_obj, litellm.EmbeddingResponse)
): ):
input = prompt input = prompt
output = response_obj["data"] output = None
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse response_obj, litellm.ModelResponse
): ):
@ -251,7 +253,7 @@ class LangFuseLogger:
input, input,
response_obj, response_obj,
): ):
from langfuse.model import CreateTrace, CreateGeneration from langfuse.model import CreateGeneration, CreateTrace
verbose_logger.warning( verbose_logger.warning(
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1" "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
@ -528,31 +530,14 @@ class LangFuseLogger:
"version": clean_metadata.pop("version", None), "version": clean_metadata.pop("version", None),
} }
parent_observation_id = metadata.get("parent_observation_id", None)
if parent_observation_id is not None:
generation_params["parent_observation_id"] = parent_observation_id
if supports_prompt: if supports_prompt:
user_prompt = clean_metadata.pop("prompt", None) generation_params = _add_prompt_to_generation_params(
if user_prompt is None: generation_params=generation_params, clean_metadata=clean_metadata
pass
elif isinstance(user_prompt, dict):
from langfuse.model import (
TextPromptClient,
ChatPromptClient,
Prompt_Text,
Prompt_Chat,
) )
if user_prompt.get("type", "") == "chat":
_prompt_chat = Prompt_Chat(**user_prompt)
generation_params["prompt"] = ChatPromptClient(
prompt=_prompt_chat
)
elif user_prompt.get("type", "") == "text":
_prompt_text = Prompt_Text(**user_prompt)
generation_params["prompt"] = TextPromptClient(
prompt=_prompt_text
)
else:
generation_params["prompt"] = user_prompt
if output is not None and isinstance(output, str) and level == "ERROR": if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["status_message"] = output generation_params["status_message"] = output
@ -565,5 +550,58 @@ class LangFuseLogger:
return generation_client.trace_id, generation_id return generation_client.trace_id, generation_id
except Exception as e: except Exception as e:
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None return None, None
def _add_prompt_to_generation_params(
generation_params: dict, clean_metadata: dict
) -> dict:
from langfuse.model import (
ChatPromptClient,
Prompt_Chat,
Prompt_Text,
TextPromptClient,
)
user_prompt = clean_metadata.pop("prompt", None)
if user_prompt is None:
pass
elif isinstance(user_prompt, dict):
if user_prompt.get("type", "") == "chat":
_prompt_chat = Prompt_Chat(**user_prompt)
generation_params["prompt"] = ChatPromptClient(prompt=_prompt_chat)
elif user_prompt.get("type", "") == "text":
_prompt_text = Prompt_Text(**user_prompt)
generation_params["prompt"] = TextPromptClient(prompt=_prompt_text)
elif "version" in user_prompt and "prompt" in user_prompt:
# prompts
if isinstance(user_prompt["prompt"], str):
_prompt_obj = Prompt_Text(
name=user_prompt["name"],
prompt=user_prompt["prompt"],
version=user_prompt["version"],
config=user_prompt.get("config", None),
)
generation_params["prompt"] = TextPromptClient(prompt=_prompt_obj)
elif isinstance(user_prompt["prompt"], list):
_prompt_obj = Prompt_Chat(
name=user_prompt["name"],
prompt=user_prompt["prompt"],
version=user_prompt["version"],
config=user_prompt.get("config", None),
)
generation_params["prompt"] = ChatPromptClient(prompt=_prompt_obj)
else:
verbose_logger.error(
"[Non-blocking] Langfuse Logger: Invalid prompt format"
)
else:
verbose_logger.error(
"[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
)
else:
generation_params["prompt"] = user_prompt
return generation_params
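Editor's note: for reference, these are the metadata["prompt"] shapes the helper above accepts. The dicts are illustrative and follow the branches in the code rather than any Langfuse documentation.

# Explicitly typed text / chat prompts (first two branches)
text_prompt_metadata = {
    "prompt": {"type": "text", "name": "greet", "prompt": "Say hi to {{name}}", "version": 1}
}
chat_prompt_metadata = {
    "prompt": {
        "type": "chat",
        "name": "greet-chat",
        "prompt": [{"role": "system", "content": "You greet people."}],
        "version": 1,
    }
}
# No "type" key, but "prompt" + "version" present (third branch); a string prompt
# becomes a TextPromptClient, a list of messages becomes a ChatPromptClient.
versioned_prompt_metadata = {
    "prompt": {"name": "greet", "prompt": "Say hi", "version": 2, "config": None}
}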

View file

@ -105,7 +105,6 @@ class LunaryLogger:
end_time=datetime.now(timezone.utc), end_time=datetime.now(timezone.utc),
error=None, error=None,
): ):
# Method definition
try: try:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
@ -114,10 +113,9 @@ class LunaryLogger:
metadata = litellm_params.get("metadata", {}) or {} metadata = litellm_params.get("metadata", {}) or {}
if optional_params: if optional_params:
# merge into extra
extra = {**extra, **optional_params} extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or [] tags = metadata.get("tags", None)
if extra: if extra:
extra.pop("extra_body", None) extra.pop("extra_body", None)

View file

@ -1,22 +1,29 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import litellm from functools import wraps
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm.integrations.custom_logger import CustomLogger import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.services import ServiceLoggerPayload from litellm.types.services import ServiceLoggerPayload
from typing import Union, Optional, TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
from litellm.proxy._types import (
ManagementEndpointLoggingPayload as _ManagementEndpointLoggingPayload,
)
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span Span = _Span
UserAPIKeyAuth = _UserAPIKeyAuth UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
else: else:
Span = Any Span = Any
UserAPIKeyAuth = Any UserAPIKeyAuth = Any
ManagementEndpointLoggingPayload = Any
LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm") LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
@ -101,8 +108,9 @@ class OpenTelemetry(CustomLogger):
start_time: Optional[datetime] = None, start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
): ):
from opentelemetry import trace
from datetime import datetime from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode
_start_time_ns = start_time _start_time_ns = start_time
@ -139,8 +147,9 @@ class OpenTelemetry(CustomLogger):
start_time: Optional[datetime] = None, start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
): ):
from opentelemetry import trace
from datetime import datetime from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode
_start_time_ns = start_time _start_time_ns = start_time
@ -173,8 +182,8 @@ class OpenTelemetry(CustomLogger):
async def async_post_call_failure_hook( async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
): ):
from opentelemetry.trace import Status, StatusCode
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
parent_otel_span = user_api_key_dict.parent_otel_span parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None: if parent_otel_span is not None:
@ -196,8 +205,8 @@ class OpenTelemetry(CustomLogger):
parent_otel_span.end(end_time=self._to_ns(datetime.now())) parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_sucess(self, kwargs, response_obj, start_time, end_time): def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
verbose_logger.debug( verbose_logger.debug(
"OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s", "OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s",
@ -247,9 +256,10 @@ class OpenTelemetry(CustomLogger):
span.end(end_time=self._to_ns(end_time)) span.end(end_time=self._to_ns(end_time))
def set_tools_attributes(self, span: Span, tools): def set_tools_attributes(self, span: Span, tools):
from opentelemetry.semconv.ai import SpanAttributes
import json import json
from litellm.proxy._types import SpanAttributes
if not tools: if not tools:
return return
@ -272,7 +282,7 @@ class OpenTelemetry(CustomLogger):
pass pass
def set_attributes(self, span: Span, kwargs, response_obj): def set_attributes(self, span: Span, kwargs, response_obj):
from opentelemetry.semconv.ai import SpanAttributes from litellm.proxy._types import SpanAttributes
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
@ -314,7 +324,7 @@ class OpenTelemetry(CustomLogger):
) )
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_IS_STREAMING, optional_params.get("stream", False) SpanAttributes.LLM_IS_STREAMING, str(optional_params.get("stream", False))
) )
if optional_params.get("tools"): if optional_params.get("tools"):
@ -407,7 +417,7 @@ class OpenTelemetry(CustomLogger):
) )
def set_raw_request_attributes(self, span: Span, kwargs, response_obj): def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
from opentelemetry.semconv.ai import SpanAttributes from litellm.proxy._types import SpanAttributes
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
@ -433,7 +443,7 @@ class OpenTelemetry(CustomLogger):
############################################# #############################################
########## LLM Response Attributes ########## ########## LLM Response Attributes ##########
############################################# #############################################
if _raw_response: if _raw_response and isinstance(_raw_response, str):
# cast sr -> dict # cast sr -> dict
import json import json
@ -454,11 +464,28 @@ class OpenTelemetry(CustomLogger):
def _get_span_name(self, kwargs): def _get_span_name(self, kwargs):
return LITELLM_REQUEST_SPAN_NAME return LITELLM_REQUEST_SPAN_NAME
def _get_span_context(self, kwargs): def get_traceparent_from_header(self, headers):
if headers is None:
return None
_traceparent = headers.get("traceparent", None)
if _traceparent is None:
return None
from opentelemetry.trace.propagation.tracecontext import ( from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator, TraceContextTextMapPropagator,
) )
verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
propagator = TraceContextTextMapPropagator()
_parent_context = propagator.extract(carrier={"traceparent": _traceparent})
verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
return _parent_context
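Editor's note: a standalone sketch of the W3C traceparent extraction done above; the header value is made up but well formed, and only opentelemetry-api is needed.

from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

# version-traceid-spanid-flags, per the W3C Trace Context spec (illustrative value)
traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"

propagator = TraceContextTextMapPropagator()
parent_context = propagator.extract(carrier={"traceparent": traceparent})
print(parent_context)  # a Context carrying the remote parent span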
def _get_span_context(self, kwargs):
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
@ -482,17 +509,17 @@ class OpenTelemetry(CustomLogger):
return TraceContextTextMapPropagator().extract(carrier=carrier), None return TraceContextTextMapPropagator().extract(carrier=carrier), None
def _get_span_processor(self): def _get_span_processor(self):
from opentelemetry.sdk.trace.export import ( from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
SpanExporter, OTLPSpanExporter as OTLPSpanExporterGRPC,
SimpleSpanProcessor,
BatchSpanProcessor,
ConsoleSpanExporter,
) )
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP, OTLPSpanExporter as OTLPSpanExporterHTTP,
) )
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( from opentelemetry.sdk.trace.export import (
OTLPSpanExporter as OTLPSpanExporterGRPC, BatchSpanProcessor,
ConsoleSpanExporter,
SimpleSpanProcessor,
SpanExporter,
) )
verbose_logger.debug( verbose_logger.debug(
@ -545,3 +572,93 @@ class OpenTelemetry(CustomLogger):
self.OTEL_EXPORTER, self.OTEL_EXPORTER,
) )
return BatchSpanProcessor(ConsoleSpanExporter()) return BatchSpanProcessor(ConsoleSpanExporter())
async def async_management_endpoint_success_hook(
self,
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = logging_payload.start_time
_end_time_ns = logging_payload.end_time
start_time = logging_payload.start_time
end_time = logging_payload.end_time
if isinstance(start_time, float):
_start_time_ns = int(int(start_time) * 1e9)
else:
_start_time_ns = self._to_ns(start_time)
if isinstance(end_time, float):
_end_time_ns = int(int(end_time) * 1e9)
else:
_end_time_ns = self._to_ns(end_time)
if parent_otel_span is not None:
_span_name = logging_payload.route
management_endpoint_span = self.tracer.start_span(
name=_span_name,
context=trace.set_span_in_context(parent_otel_span),
start_time=_start_time_ns,
)
_request_data = logging_payload.request_data
if _request_data is not None:
for key, value in _request_data.items():
management_endpoint_span.set_attribute(f"request.{key}", value)
_response = logging_payload.response
if _response is not None:
for key, value in _response.items():
management_endpoint_span.set_attribute(f"response.{key}", value)
management_endpoint_span.set_status(Status(StatusCode.OK))
management_endpoint_span.end(end_time=_end_time_ns)
async def async_management_endpoint_failure_hook(
self,
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = logging_payload.start_time
_end_time_ns = logging_payload.end_time
start_time = logging_payload.start_time
end_time = logging_payload.end_time
if isinstance(start_time, float):
_start_time_ns = int(int(start_time) * 1e9)
else:
_start_time_ns = self._to_ns(start_time)
if isinstance(end_time, float):
_end_time_ns = int(int(end_time) * 1e9)
else:
_end_time_ns = self._to_ns(end_time)
if parent_otel_span is not None:
_span_name = logging_payload.route
management_endpoint_span = self.tracer.start_span(
name=_span_name,
context=trace.set_span_in_context(parent_otel_span),
start_time=_start_time_ns,
)
_request_data = logging_payload.request_data
if _request_data is not None:
for key, value in _request_data.items():
management_endpoint_span.set_attribute(f"request.{key}", value)
_exception = logging_payload.exception
management_endpoint_span.set_attribute(f"exception", str(_exception))
management_endpoint_span.set_status(Status(StatusCode.ERROR))
management_endpoint_span.end(end_time=_end_time_ns)

View file

@ -8,6 +8,7 @@ import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm, uuid import litellm, uuid
from litellm._logging import print_verbose, verbose_logger from litellm._logging import print_verbose, verbose_logger
from typing import Optional, Union
class PrometheusLogger: class PrometheusLogger:
@ -17,33 +18,76 @@ class PrometheusLogger:
**kwargs, **kwargs,
): ):
try: try:
from prometheus_client import Counter from prometheus_client import Counter, Gauge
self.litellm_llm_api_failed_requests_metric = Counter( self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric", name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm", documentation="Total number of failed LLM API calls via litellm",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"], labelnames=[
"end_user",
"hashed_api_key",
"model",
"team",
"team_alias",
"user",
],
) )
self.litellm_requests_metric = Counter( self.litellm_requests_metric = Counter(
name="litellm_requests_metric", name="litellm_requests_metric",
documentation="Total number of LLM calls to litellm", documentation="Total number of LLM calls to litellm",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"], labelnames=[
"end_user",
"hashed_api_key",
"model",
"team",
"team_alias",
"user",
],
) )
# Counter for spend # Counter for spend
self.litellm_spend_metric = Counter( self.litellm_spend_metric = Counter(
"litellm_spend_metric", "litellm_spend_metric",
"Total spend on LLM requests", "Total spend on LLM requests",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"], labelnames=[
"end_user",
"hashed_api_key",
"model",
"team",
"team_alias",
"user",
],
) )
# Counter for total_output_tokens # Counter for total_output_tokens
self.litellm_tokens_metric = Counter( self.litellm_tokens_metric = Counter(
"litellm_total_tokens", "litellm_total_tokens",
"Total number of input + output tokens from LLM requests", "Total number of input + output tokens from LLM requests",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"], labelnames=[
"end_user",
"hashed_api_key",
"model",
"team",
"team_alias",
"user",
],
) )
# Remaining Budget for Team
self.litellm_remaining_team_budget_metric = Gauge(
"litellm_remaining_team_budget_metric",
"Remaining budget for team",
labelnames=["team_id", "team_alias"],
)
# Remaining Budget for API Key
self.litellm_remaining_api_key_budget_metric = Gauge(
"litellm_remaining_api_key_budget_metric",
"Remaining budget for api key",
labelnames=["hashed_api_key", "api_key_alias"],
)
except Exception as e: except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}") print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e raise e
@ -51,7 +95,9 @@ class PrometheusLogger:
async def _async_log_event( async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, user_id self, kwargs, response_obj, start_time, end_time, print_verbose, user_id
): ):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) self.log_event(
kwargs, response_obj, start_time, end_time, user_id, print_verbose
)
def log_event( def log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
@ -72,9 +118,36 @@ class PrometheusLogger:
"user_api_key_user_id", None "user_api_key_user_id", None
) )
user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None) user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
user_api_key_alias = litellm_params.get("metadata", {}).get(
"user_api_key_alias", None
)
user_api_team = litellm_params.get("metadata", {}).get( user_api_team = litellm_params.get("metadata", {}).get(
"user_api_key_team_id", None "user_api_key_team_id", None
) )
user_api_team_alias = litellm_params.get("metadata", {}).get(
"user_api_key_team_alias", None
)
_team_spend = litellm_params.get("metadata", {}).get(
"user_api_key_team_spend", None
)
_team_max_budget = litellm_params.get("metadata", {}).get(
"user_api_key_team_max_budget", None
)
_remaining_team_budget = safe_get_remaining_budget(
max_budget=_team_max_budget, spend=_team_spend
)
_api_key_spend = litellm_params.get("metadata", {}).get(
"user_api_key_spend", None
)
_api_key_max_budget = litellm_params.get("metadata", {}).get(
"user_api_key_max_budget", None
)
_remaining_api_key_budget = safe_get_remaining_budget(
max_budget=_api_key_max_budget, spend=_api_key_spend
)
if response_obj is not None: if response_obj is not None:
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0) tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
else: else:
@ -94,19 +167,47 @@ class PrometheusLogger:
user_api_key = hash_token(user_api_key) user_api_key = hash_token(user_api_key)
self.litellm_requests_metric.labels( self.litellm_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id end_user_id,
user_api_key,
model,
user_api_team,
user_api_team_alias,
user_id,
).inc() ).inc()
self.litellm_spend_metric.labels( self.litellm_spend_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id end_user_id,
user_api_key,
model,
user_api_team,
user_api_team_alias,
user_id,
).inc(response_cost) ).inc(response_cost)
self.litellm_tokens_metric.labels( self.litellm_tokens_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id end_user_id,
user_api_key,
model,
user_api_team,
user_api_team_alias,
user_id,
).inc(tokens_used) ).inc(tokens_used)
self.litellm_remaining_team_budget_metric.labels(
user_api_team, user_api_team_alias
).set(_remaining_team_budget)
self.litellm_remaining_api_key_budget_metric.labels(
user_api_key, user_api_key_alias
).set(_remaining_api_key_budget)
### FAILURE INCREMENT ### ### FAILURE INCREMENT ###
if "exception" in kwargs: if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels( self.litellm_llm_api_failed_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id end_user_id,
user_api_key,
model,
user_api_team,
user_api_team_alias,
user_id,
).inc() ).inc()
except Exception as e: except Exception as e:
verbose_logger.error( verbose_logger.error(
@ -114,3 +215,15 @@ class PrometheusLogger:
) )
verbose_logger.debug(traceback.format_exc()) verbose_logger.debug(traceback.format_exc())
pass pass
def safe_get_remaining_budget(
max_budget: Optional[float], spend: Optional[float]
) -> float:
if max_budget is None:
return float("inf")
if spend is None:
return max_budget
return max_budget - spend
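Editor's note: a standalone sketch tying the new budget gauges to the helper above. The team id, alias and numbers are illustrative, and prometheus_client is already a dependency of this logger.

from typing import Optional

from prometheus_client import Gauge

def remaining_budget(max_budget: Optional[float], spend: Optional[float]) -> float:
    # Same fallback rules as safe_get_remaining_budget() above.
    if max_budget is None:
        return float("inf")
    if spend is None:
        return max_budget
    return max_budget - spend

team_budget_gauge = Gauge(
    "example_remaining_team_budget",
    "Remaining budget for team",
    labelnames=["team_id", "team_alias"],
)
team_budget_gauge.labels("team-123", "prod-team").set(remaining_budget(100.0, 42.5))  # 57.5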

View file

@ -330,6 +330,7 @@ class SlackAlerting(CustomLogger):
messages = "Message not logged. litellm.redact_messages_in_exceptions=True" messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
alerting_metadata: dict = {}
if time_difference_float > self.alerting_threshold: if time_difference_float > self.alerting_threshold:
# add deployment latencies to alert # add deployment latencies to alert
if ( if (
@ -337,7 +338,7 @@ class SlackAlerting(CustomLogger):
and "litellm_params" in kwargs and "litellm_params" in kwargs
and "metadata" in kwargs["litellm_params"] and "metadata" in kwargs["litellm_params"]
): ):
_metadata = kwargs["litellm_params"]["metadata"] _metadata: dict = kwargs["litellm_params"]["metadata"]
request_info = litellm.utils._add_key_name_and_team_to_alert( request_info = litellm.utils._add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata request_info=request_info, metadata=_metadata
) )
@ -349,10 +350,14 @@ class SlackAlerting(CustomLogger):
request_info += ( request_info += (
f"\nAvailable Deployment Latencies\n{_deployment_latency_map}" f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
) )
if "alerting_metadata" in _metadata:
alerting_metadata = _metadata["alerting_metadata"]
await self.send_alert( await self.send_alert(
message=slow_message + request_info, message=slow_message + request_info,
level="Low", level="Low",
alert_type="llm_too_slow", alert_type="llm_too_slow",
alerting_metadata=alerting_metadata,
) )
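Editor's note: one way the alerting_metadata read above could be supplied by a caller, sketched with illustrative keys. It travels in the request's litellm metadata; whether an alert actually fires still depends on the proxy's alerting configuration.

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    # Free-form keys; they get attached to slow / hanging-request alerts above.
    metadata={"alerting_metadata": {"environment": "staging", "service": "checkout"}},
)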
async def async_update_daily_reports( async def async_update_daily_reports(
@ -540,7 +545,12 @@ class SlackAlerting(CustomLogger):
message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s" message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"
# send alert # send alert
await self.send_alert(message=message, level="Low", alert_type="daily_reports") await self.send_alert(
message=message,
level="Low",
alert_type="daily_reports",
alerting_metadata={},
)
return True return True
@ -582,6 +592,7 @@ class SlackAlerting(CustomLogger):
await asyncio.sleep( await asyncio.sleep(
self.alerting_threshold self.alerting_threshold
) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests ) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
alerting_metadata: dict = {}
if ( if (
request_data is not None request_data is not None
and request_data.get("litellm_status", "") != "success" and request_data.get("litellm_status", "") != "success"
@ -606,7 +617,7 @@ class SlackAlerting(CustomLogger):
): ):
# In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data`` # In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
# in that case we fallback to the api base set in the request metadata # in that case we fallback to the api base set in the request metadata
_metadata = request_data["metadata"] _metadata: dict = request_data["metadata"]
_api_base = _metadata.get("api_base", "") _api_base = _metadata.get("api_base", "")
request_info = litellm.utils._add_key_name_and_team_to_alert( request_info = litellm.utils._add_key_name_and_team_to_alert(
@ -615,6 +626,9 @@ class SlackAlerting(CustomLogger):
if _api_base is None: if _api_base is None:
_api_base = "" _api_base = ""
if "alerting_metadata" in _metadata:
alerting_metadata = _metadata["alerting_metadata"]
request_info += f"\nAPI Base: `{_api_base}`" request_info += f"\nAPI Base: `{_api_base}`"
# only alert hanging responses if they have not been marked as success # only alert hanging responses if they have not been marked as success
alerting_message = ( alerting_message = (
@ -640,6 +654,7 @@ class SlackAlerting(CustomLogger):
message=alerting_message + request_info, message=alerting_message + request_info,
level="Medium", level="Medium",
alert_type="llm_requests_hanging", alert_type="llm_requests_hanging",
alerting_metadata=alerting_metadata,
) )
async def failed_tracking_alert(self, error_message: str): async def failed_tracking_alert(self, error_message: str):
@ -650,7 +665,10 @@ class SlackAlerting(CustomLogger):
result = await _cache.async_get_cache(key=_cache_key) result = await _cache.async_get_cache(key=_cache_key)
if result is None: if result is None:
await self.send_alert( await self.send_alert(
message=message, level="High", alert_type="budget_alerts" message=message,
level="High",
alert_type="budget_alerts",
alerting_metadata={},
) )
await _cache.async_set_cache( await _cache.async_set_cache(
key=_cache_key, key=_cache_key,
@ -680,7 +698,7 @@ class SlackAlerting(CustomLogger):
return return
if "budget_alerts" not in self.alert_types: if "budget_alerts" not in self.alert_types:
return return
_id: str = "default_id" # used for caching _id: Optional[str] = "default_id" # used for caching
user_info_json = user_info.model_dump(exclude_none=True) user_info_json = user_info.model_dump(exclude_none=True)
for k, v in user_info_json.items(): for k, v in user_info_json.items():
user_info_str = "\n{}: {}\n".format(k, v) user_info_str = "\n{}: {}\n".format(k, v)
@ -751,6 +769,7 @@ class SlackAlerting(CustomLogger):
level="High", level="High",
alert_type="budget_alerts", alert_type="budget_alerts",
user_info=webhook_event, user_info=webhook_event,
alerting_metadata={},
) )
await _cache.async_set_cache( await _cache.async_set_cache(
key=_cache_key, key=_cache_key,
@ -769,7 +788,13 @@ class SlackAlerting(CustomLogger):
response_cost: Optional[float], response_cost: Optional[float],
max_budget: Optional[float], max_budget: Optional[float],
): ):
if end_user_id is not None and token is not None and response_cost is not None: if (
self.alerting is not None
and "webhook" in self.alerting
and end_user_id is not None
and token is not None
and response_cost is not None
):
# log customer spend # log customer spend
event = WebhookEvent( event = WebhookEvent(
spend=response_cost, spend=response_cost,
@ -941,7 +966,10 @@ class SlackAlerting(CustomLogger):
) )
# send minor alert # send minor alert
await self.send_alert( await self.send_alert(
message=msg, level="Medium", alert_type="outage_alerts" message=msg,
level="Medium",
alert_type="outage_alerts",
alerting_metadata={},
) )
# set to true # set to true
outage_value["minor_alert_sent"] = True outage_value["minor_alert_sent"] = True
@ -963,7 +991,12 @@ class SlackAlerting(CustomLogger):
) )
# send minor alert # send minor alert
await self.send_alert(message=msg, level="High", alert_type="outage_alerts") await self.send_alert(
message=msg,
level="High",
alert_type="outage_alerts",
alerting_metadata={},
)
# set to true # set to true
outage_value["major_alert_sent"] = True outage_value["major_alert_sent"] = True
@ -1062,7 +1095,10 @@ class SlackAlerting(CustomLogger):
) )
# send minor alert # send minor alert
await self.send_alert( await self.send_alert(
message=msg, level="Medium", alert_type="outage_alerts" message=msg,
level="Medium",
alert_type="outage_alerts",
alerting_metadata={},
) )
# set to true # set to true
outage_value["minor_alert_sent"] = True outage_value["minor_alert_sent"] = True
@ -1081,7 +1117,10 @@ class SlackAlerting(CustomLogger):
) )
# send minor alert # send minor alert
await self.send_alert( await self.send_alert(
message=msg, level="High", alert_type="outage_alerts" message=msg,
level="High",
alert_type="outage_alerts",
alerting_metadata={},
) )
# set to true # set to true
outage_value["major_alert_sent"] = True outage_value["major_alert_sent"] = True
@ -1143,7 +1182,10 @@ Model Info:
""" """
alert_val = self.send_alert( alert_val = self.send_alert(
message=message, level="Low", alert_type="new_model_added" message=message,
level="Low",
alert_type="new_model_added",
alerting_metadata={},
) )
if alert_val is not None and asyncio.iscoroutine(alert_val): if alert_val is not None and asyncio.iscoroutine(alert_val):
@ -1159,6 +1201,9 @@ Model Info:
Currently only implemented for budget alerts Currently only implemented for budget alerts
Returns -> True if sent, False if not. Returns -> True if sent, False if not.
Raises Exception
- if WEBHOOK_URL is not set
""" """
webhook_url = os.getenv("WEBHOOK_URL", None) webhook_url = os.getenv("WEBHOOK_URL", None)
@ -1297,7 +1342,9 @@ Model Info:
verbose_proxy_logger.error("Error sending email alert %s", str(e)) verbose_proxy_logger.error("Error sending email alert %s", str(e))
return False return False
async def send_email_alert_using_smtp(self, webhook_event: WebhookEvent) -> bool: async def send_email_alert_using_smtp(
self, webhook_event: WebhookEvent, alert_type: str
) -> bool:
""" """
Sends structured Email alert to an SMTP server Sends structured Email alert to an SMTP server
@ -1306,7 +1353,6 @@ Model Info:
Returns -> True if sent, False if not. Returns -> True if sent, False if not.
""" """
from litellm.proxy.utils import send_email from litellm.proxy.utils import send_email
from litellm.proxy.proxy_server import premium_user, prisma_client from litellm.proxy.proxy_server import premium_user, prisma_client
email_logo_url = os.getenv( email_logo_url = os.getenv(
@ -1360,6 +1406,10 @@ Model Info:
subject=email_event["subject"], subject=email_event["subject"],
html=email_event["html"], html=email_event["html"],
) )
if webhook_event.event_group == "team":
from litellm.integrations.email_alerting import send_team_budget_alert
await send_team_budget_alert(webhook_event=webhook_event)
return False return False
@ -1368,6 +1418,7 @@ Model Info:
message: str, message: str,
level: Literal["Low", "Medium", "High"], level: Literal["Low", "Medium", "High"],
alert_type: Literal[AlertType], alert_type: Literal[AlertType],
alerting_metadata: dict,
user_info: Optional[WebhookEvent] = None, user_info: Optional[WebhookEvent] = None,
**kwargs, **kwargs,
): ):
@ -1401,7 +1452,9 @@ Model Info:
and user_info is not None and user_info is not None
): ):
# only send budget alerts over Email # only send budget alerts over Email
await self.send_email_alert_using_smtp(webhook_event=user_info) await self.send_email_alert_using_smtp(
webhook_event=user_info, alert_type=alert_type
)
if "slack" not in self.alerting: if "slack" not in self.alerting:
return return
@ -1425,6 +1478,9 @@ Model Info:
if kwargs: if kwargs:
for key, value in kwargs.items(): for key, value in kwargs.items():
formatted_message += f"\n\n{key}: `{value}`\n\n" formatted_message += f"\n\n{key}: `{value}`\n\n"
if alerting_metadata:
for key, value in alerting_metadata.items():
formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
if _proxy_base_url is not None: if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`" formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
@ -1440,7 +1496,7 @@ Model Info:
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None) slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
if slack_webhook_url is None: if slack_webhook_url is None:
raise Exception("Missing SLACK_WEBHOOK_URL from environment") raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
payload = {"text": formatted_message} payload = {"text": formatted_message}
headers = {"Content-type": "application/json"} headers = {"Content-type": "application/json"}
@ -1453,7 +1509,7 @@ Model Info:
pass pass
else: else:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Error sending slack alert. Error=", response.text "Error sending slack alert. Error={}".format(response.text)
) )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -1622,6 +1678,7 @@ Model Info:
message=_weekly_spend_message, message=_weekly_spend_message,
level="Low", level="Low",
alert_type="spend_reports", alert_type="spend_reports",
alerting_metadata={},
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e) verbose_proxy_logger.error("Error sending weekly spend report", e)
@ -1673,6 +1730,7 @@ Model Info:
message=_spend_message, message=_spend_message,
level="Low", level="Low",
alert_type="spend_reports", alert_type="spend_reports",
alerting_metadata={},
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e) verbose_proxy_logger.error("Error sending weekly spend report", e)
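A minimal usage sketch of the updated `send_alert` signature above. It assumes an already-initialized SlackAlerting instance (construction is not part of this diff) and an illustrative metadata value; `alerting_metadata` is now a required argument, and each key/value pair is appended to the Slack message under "*Alerting Metadata*".

# Sketch only: `slack_alerting` is assumed to be an existing, configured instance.
from litellm.integrations.slack_alerting import SlackAlerting  # assumed module path

async def _send_custom_alert(slack_alerting: SlackAlerting) -> None:
    await slack_alerting.send_alert(
        message="`Responses are slow - 12.3s response time > Alerting threshold: 6.0s`",
        level="Low",
        alert_type="llm_too_slow",
        # new required argument - rendered as "*Alerting Metadata*: \nenvironment: `staging`"
        alerting_metadata={"environment": "staging"},  # illustrative value
    )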

View file

@ -0,0 +1,41 @@
# What is this?
## Helper utilities for the model response objects
def map_finish_reason(
finish_reason: str,
): # openai supports 5 finish reasons - 'stop', 'length', 'function_call', 'content_filter', 'null'
# anthropic mapping
if finish_reason == "stop_sequence":
return "stop"
# cohere mapping - https://docs.cohere.com/reference/generate
elif finish_reason == "COMPLETE":
return "stop"
elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
return "length"
elif finish_reason == "ERROR_TOXIC":
return "content_filter"
elif (
finish_reason == "ERROR"
): # openai currently doesn't support an 'error' finish reason
return "stop"
# huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
return "stop"
elif (
finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"
): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
return "stop"
elif finish_reason == "SAFETY": # vertex ai
return "content_filter"
elif finish_reason == "STOP": # vertex ai
return "stop"
elif finish_reason == "end_turn" or finish_reason == "stop_sequence": # anthropic
return "stop"
elif finish_reason == "max_tokens": # anthropic
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason
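A quick illustration of the normalization above; the expected values follow directly from the mapping in this helper, and the import path matches how other files in this diff consume it.

from litellm.litellm_core_utils.core_helpers import map_finish_reason

assert map_finish_reason("MAX_TOKENS") == "length"      # cohere / vertex ai
assert map_finish_reason("end_turn") == "stop"          # anthropic
assert map_finish_reason("tool_use") == "tool_calls"    # anthropic
assert map_finish_reason("some_custom_reason") == "some_custom_reason"  # unknown values pass through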

File diff suppressed because it is too large

View file

@ -0,0 +1,82 @@
# What is this?
## Cost calculation for Google AI Studio / Vertex AI models
from typing import Literal, Tuple
import litellm
"""
Gemini pricing covers:
- token
- image
- audio
- video
"""
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]
def _is_above_128k(tokens: float) -> bool:
if tokens > 128000:
return True
return False
def cost_per_token(
model: str,
custom_llm_provider: str,
prompt_tokens: float,
completion_tokens: float,
) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Input:
- model: str, the model name without provider prefix
- custom_llm_provider: str, either "vertex_ai-*" or "gemini"
- prompt_tokens: float, the number of input tokens
- completion_tokens: float, the number of output tokens
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
Raises:
Exception if model requires >128k pricing, but model cost not mapped
"""
## GET MODEL INFO
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
## CALCULATE INPUT COST
if (
_is_above_128k(tokens=prompt_tokens)
and model not in models_without_dynamic_pricing
):
assert (
model_info["input_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model, model_info
)
prompt_cost = (
prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
)
else:
prompt_cost = prompt_tokens * model_info["input_cost_per_token"]
## CALCULATE OUTPUT COST
if (
_is_above_128k(tokens=completion_tokens)
and model not in models_without_dynamic_pricing
):
assert (
model_info["output_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model, model_info
)
completion_cost = (
completion_tokens * model_info["output_cost_per_token_above_128k_tokens"]
)
else:
completion_cost = completion_tokens * model_info["output_cost_per_token"]
return prompt_cost, completion_cost
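A worked example of the two pricing branches above, using hypothetical per-token rates rather than the real Gemini price sheet: the `*_above_128k_tokens` rate only applies to whichever side (prompt or completion) crosses 128k tokens.

# Hypothetical rates, for illustration only - real values come from litellm.get_model_info().
model_info = {
    "input_cost_per_token": 0.35e-6,
    "input_cost_per_token_above_128k_tokens": 0.70e-6,
    "output_cost_per_token": 1.05e-6,
    "output_cost_per_token_above_128k_tokens": 2.10e-6,
}

prompt_tokens, completion_tokens = 200_000, 1_000
prompt_cost = prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]  # >128k branch
completion_cost = completion_tokens * model_info["output_cost_per_token"]           # <=128k branch
print(round(prompt_cost, 6), round(completion_cost, 6))  # 0.14 0.00105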

View file

@ -0,0 +1,28 @@
from typing import Dict, Optional
def _ensure_extra_body_is_safe(extra_body: Optional[Dict]) -> Optional[Dict]:
"""
Ensure that the extra_body sent in the request is safe, otherwise users will see this error
"Object of type TextPromptClient is not JSON serializable
Relevant Issue: https://github.com/BerriAI/litellm/issues/4140
"""
if extra_body is None:
return None
if not isinstance(extra_body, dict):
return extra_body
if "metadata" in extra_body and isinstance(extra_body["metadata"], dict):
if "prompt" in extra_body["metadata"]:
_prompt = extra_body["metadata"].get("prompt")
# users can send Langfuse TextPromptClient objects, so we need to convert them to dicts
# Langfuse TextPromptClients have .__dict__ attribute
if _prompt is not None and hasattr(_prompt, "__dict__"):
extra_body["metadata"]["prompt"] = _prompt.__dict__
return extra_body
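A small sketch of what the conversion above does, using a stand-in object instead of a real Langfuse TextPromptClient (anything exposing `__dict__` is handled the same way); `_ensure_extra_body_is_safe` is assumed to be in scope from the helper above.

class _FakePromptClient:  # stand-in for a Langfuse TextPromptClient
    def __init__(self) -> None:
        self.name = "my-prompt"
        self.version = 3

extra_body = {"metadata": {"prompt": _FakePromptClient(), "trace_id": "abc-123"}}
safe_body = _ensure_extra_body_is_safe(extra_body=extra_body)
print(safe_body["metadata"]["prompt"])  # {'name': 'my-prompt', 'version': 3} -> now JSON serializable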

View file

@ -0,0 +1,71 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import copy
from typing import TYPE_CHECKING, Any
import litellm
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as _LiteLLMLoggingObject,
)
LiteLLMLoggingObject = _LiteLLMLoggingObject
else:
LiteLLMLoggingObject = Any
def redact_message_input_output_from_logging(
litellm_logging_obj: LiteLLMLoggingObject, result
):
"""
Removes messages, prompts, input, response from logging. This modifies the data in-place
only redacts when litellm.turn_off_message_logging == True
"""
# check if user opted out of logging message/response to callbacks
if litellm.turn_off_message_logging is not True:
return result
# remove messages, prompts, input, response from logging
litellm_logging_obj.model_call_details["messages"] = [
{"role": "user", "content": "redacted-by-litellm"}
]
litellm_logging_obj.model_call_details["prompt"] = ""
litellm_logging_obj.model_call_details["input"] = ""
# response cleaning
# ChatCompletion Responses
if (
litellm_logging_obj.stream is True
and "complete_streaming_response" in litellm_logging_obj.model_call_details
):
_streaming_response = litellm_logging_obj.model_call_details[
"complete_streaming_response"
]
for choice in _streaming_response.choices:
if isinstance(choice, litellm.Choices):
choice.message.content = "redacted-by-litellm"
elif isinstance(choice, litellm.utils.StreamingChoices):
choice.delta.content = "redacted-by-litellm"
else:
if result is not None:
if isinstance(result, litellm.ModelResponse):
# only deep copy litellm.ModelResponse
_result = copy.deepcopy(result)
if hasattr(_result, "choices") and _result.choices is not None:
for choice in _result.choices:
if isinstance(choice, litellm.Choices):
choice.message.content = "redacted-by-litellm"
elif isinstance(choice, litellm.utils.StreamingChoices):
choice.delta.content = "redacted-by-litellm"
return _result
# by default return result
return result
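A short sketch of the opt-out flag this helper keys on: nothing is redacted by default, and flipping `litellm.turn_off_message_logging` makes logging callbacks see placeholder content instead of real prompts and responses.

import litellm

litellm.turn_off_message_logging = True  # opt out of logging message/response content

# Any success/failure callback fired after this point receives redacted payloads, e.g.
# model_call_details["messages"] == [{"role": "user", "content": "redacted-by-litellm"}]
# and choice.message.content == "redacted-by-litellm" on the logged ModelResponse.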

View file

@ -5,10 +5,16 @@ import requests, copy # type: ignore
import time import time
from functools import partial from functools import partial
from typing import Callable, Optional, List, Union from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper import litellm.litellm_core_utils
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
from .base import BaseLLM from .base import BaseLLM
import httpx # type: ignore import httpx # type: ignore
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
@ -171,7 +177,7 @@ async def make_call(
logging_obj, logging_obj,
): ):
if client is None: if client is None:
client = AsyncHTTPHandler() # Create a new client if none provided client = _get_async_httpx_client() # Create a new client if none provided
response = await client.post(api_base, headers=headers, data=data, stream=True) response = await client.post(api_base, headers=headers, data=data, stream=True)
@ -201,7 +207,7 @@ class AnthropicChatCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response], response: Union[requests.Response, httpx.Response],
model_response: ModelResponse, model_response: ModelResponse,
stream: bool, stream: bool,
logging_obj: litellm.utils.Logging, logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: Union[dict, str], data: Union[dict, str],
@ -316,7 +322,7 @@ class AnthropicChatCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response], response: Union[requests.Response, httpx.Response],
model_response: ModelResponse, model_response: ModelResponse,
stream: bool, stream: bool,
logging_obj: litellm.utils.Logging, logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: Union[dict, str], data: Union[dict, str],
@ -463,9 +469,7 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=None, logger_fn=None,
headers={}, headers={},
) -> Union[ModelResponse, CustomStreamWrapper]: ) -> Union[ModelResponse, CustomStreamWrapper]:
async_handler = AsyncHTTPHandler( async_handler = _get_async_httpx_client()
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(api_base, headers=headers, json=data) response = await async_handler.post(api_base, headers=headers, json=data)
if stream and _is_function_call: if stream and _is_function_call:
return self.process_streaming_response( return self.process_streaming_response(

View file

@ -1,41 +1,58 @@
from typing import Optional, Union, Any, Literal, Coroutine, Iterable import asyncio
from typing_extensions import overload import json
import types, requests
from .base import BaseLLM
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
UnsupportedParamsError,
)
from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig
import litellm, json
import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os import os
import types
import uuid
from typing import (
Any,
BinaryIO,
Callable,
Coroutine,
Iterable,
List,
Literal,
Optional,
Union,
)
import httpx # type: ignore
import requests
from openai import AsyncAzureOpenAI, AzureOpenAI
from typing_extensions import overload
import litellm
from litellm import OpenAIConfig
from litellm.caching import DualCache
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TranscriptionResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
)
from ..types.llms.openai import ( from ..types.llms.openai import (
AsyncCursorPage,
AssistantToolParam,
SyncCursorPage,
Assistant, Assistant,
MessageData,
OpenAIMessage,
OpenAICreateThreadParamsMessage,
Thread,
AssistantToolParam,
Run,
AssistantEventHandler, AssistantEventHandler,
AssistantStreamManager,
AssistantToolParam,
AsyncAssistantEventHandler, AsyncAssistantEventHandler,
AsyncAssistantStreamManager, AsyncAssistantStreamManager,
AssistantStreamManager, AsyncCursorPage,
MessageData,
OpenAICreateThreadParamsMessage,
OpenAIMessage,
Run,
SyncCursorPage,
Thread,
) )
from .base import BaseLLM
from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
azure_ad_cache = DualCache()
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -309,9 +326,12 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
def get_azure_ad_token_from_oidc(azure_ad_token: str): def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None) azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None) azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant is None: if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError( raise AzureOpenAIError(
status_code=422, status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
@ -325,8 +345,21 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
message="OIDC token could not be retrieved from secret manager.", message="OIDC token could not be retrieved from secret manager.",
) )
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token
req_token = httpx.post( req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token", f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={ data={
"client_id": azure_client_id, "client_id": azure_client_id,
"grant_type": "client_credentials", "grant_type": "client_credentials",
@ -342,12 +375,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
message=req_token.text, message=req_token.text,
) )
possible_azure_ad_token = req_token.json().get("access_token", None) azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
if possible_azure_ad_token is None: if azure_ad_token_access_token is None:
raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned") raise AzureOpenAIError(
status_code=422, message="Azure AD Token access_token not returned"
)
return possible_azure_ad_token if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
@ -619,6 +667,8 @@ class AzureChatCompletion(BaseLLM):
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except asyncio.CancelledError as e:
raise AzureOpenAIError(status_code=500, message=str(e))
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): if hasattr(e, "status_code"):
raise e raise e
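A usage sketch for the cached OIDC exchange above. The import path and the `oidc/...` token reference are assumptions (they are not shown in this diff); the point is that repeat calls with the same client/tenant/authority/token are served from `azure_ad_cache` until the returned `expires_in` lapses.

import os
from litellm.llms.azure import get_azure_ad_token_from_oidc  # assumed module path

os.environ["AZURE_CLIENT_ID"] = "00000000-0000-0000-0000-000000000000"  # placeholder
os.environ["AZURE_TENANT_ID"] = "11111111-1111-1111-1111-111111111111"  # placeholder
# Optional override; defaults to https://login.microsoftonline.com
os.environ["AZURE_AUTHORITY_HOST"] = "https://login.microsoftonline.com"

# First call posts to {AZURE_AUTHORITY_HOST}/{tenant}/oauth2/v2.0/token; subsequent calls
# with the same cache key return the cached access token (TTL = expires_in).
token = get_azure_ad_token_from_oidc("oidc/github/https://example-audience")  # hypothetical token ref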

View file

@ -2,7 +2,7 @@
import litellm import litellm
import httpx, requests import httpx, requests
from typing import Optional, Union from typing import Optional, Union
from litellm.utils import Logging from litellm.litellm_core_utils.litellm_logging import Logging
class BaseLLM: class BaseLLM:
@ -27,6 +27,25 @@ class BaseLLM:
""" """
return model_response return model_response
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.TextCompletionResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session

View file

@ -1,25 +1,27 @@
import json, copy, types import copy
import json
import os import os
import time
import types
import uuid
from enum import Enum from enum import Enum
import time, uuid from typing import Any, Callable, List, Optional, Union
from typing import Callable, Optional, Any, Union, List
import httpx
import litellm import litellm
from litellm.utils import ( from litellm.litellm_core_utils.core_helpers import map_finish_reason
ModelResponse, from litellm.types.utils import ImageResponse, ModelResponse, Usage
get_secret, from litellm.utils import get_secret
Usage,
ImageResponse,
map_finish_reason,
)
from .prompt_templates.factory import ( from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt, construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags, extract_between_tags,
parse_xml_params, parse_xml_params,
contains_tag, prompt_factory,
) )
import httpx
class BedrockError(Exception): class BedrockError(Exception):
@ -633,7 +635,11 @@ def init_bedrock_client(
config = boto3.session.Config() config = boto3.session.Config()
### CHECK STS ### ### CHECK STS ###
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None: if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token) oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None: if oidc_token is None:
@ -642,9 +648,7 @@ def init_bedrock_client(
status_code=401, status_code=401,
) )
sts_client = boto3.client( sts_client = boto3.client("sts")
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
@ -726,7 +730,7 @@ def init_bedrock_client(
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts
if provider == "anthropic" or provider == "amazon": chat_template_provider = ["anthropic", "amazon", "mistral", "meta"]
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
@ -737,14 +741,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
messages=messages, messages=messages,
) )
else: else:
prompt = prompt_factory( if provider in chat_template_provider:
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory( prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock" model=model, messages=messages, custom_llm_provider="bedrock"
) )
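A brief sketch of the consolidated dispatch above: anthropic, amazon, mistral and meta now share one `prompt_factory` path instead of separate `elif` branches. The model id and messages are illustrative, and `convert_messages_to_prompt` is assumed to be in scope from this module.

prompt = convert_messages_to_prompt(
    model="meta.llama3-8b-instruct-v1:0",             # hypothetical Bedrock model id
    messages=[{"role": "user", "content": "hello"}],
    provider="meta",                                  # any of: anthropic, amazon, mistral, meta
    custom_prompt_dict={},                            # no registered custom prompt -> prompt_factory path
)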

View file

@ -22,13 +22,12 @@ from typing import (
from litellm.utils import ( from litellm.utils import (
ModelResponse, ModelResponse,
Usage, Usage,
map_finish_reason,
CustomStreamWrapper, CustomStreamWrapper,
Message,
Choices,
get_secret, get_secret,
Logging,
) )
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.utils import Message, Choices
import litellm, uuid import litellm, uuid
from .prompt_templates.factory import ( from .prompt_templates.factory import (
prompt_factory, prompt_factory,
@ -41,7 +40,12 @@ from .prompt_templates.factory import (
_bedrock_converse_messages_pt, _bedrock_converse_messages_pt,
_bedrock_tools_pt, _bedrock_tools_pt,
) )
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
from .base import BaseLLM from .base import BaseLLM
import httpx # type: ignore import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
@ -51,7 +55,11 @@ from litellm.types.llms.openai import (
ChatCompletionResponseMessage, ChatCompletionResponseMessage,
ChatCompletionToolCallChunk, ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk,
) )
from litellm.caching import DualCache
iam_cache = DualCache()
class AmazonCohereChatConfig: class AmazonCohereChatConfig:
@ -164,7 +172,7 @@ async def make_call(
logging_obj, logging_obj,
): ):
if client is None: if client is None:
client = AsyncHTTPHandler() # Create a new client if none provided client = _get_async_httpx_client() # Create a new client if none provided
response = await client.post(api_base, headers=headers, data=data, stream=True) response = await client.post(api_base, headers=headers, data=data, stream=True)
@ -195,7 +203,7 @@ def make_sync_call(
logging_obj, logging_obj,
): ):
if client is None: if client is None:
client = HTTPHandler() # Create a new client if none provided client = _get_httpx_client() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True) response = client.post(api_base, headers=headers, data=data, stream=True)
@ -329,6 +337,17 @@ class BedrockLLM(BaseLLM):
and aws_role_name is not None and aws_role_name is not None
and aws_session_name is not None and aws_session_name is not None
): ):
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
}
)
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token) oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None: if oidc_token is None:
@ -337,7 +356,11 @@ class BedrockLLM(BaseLLM):
status_code=401, status_code=401,
) )
sts_client = boto3.client("sts") sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com",
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
@ -348,14 +371,26 @@ class BedrockLLM(BaseLLM):
DurationSeconds=3600, DurationSeconds=3600,
) )
session = boto3.Session( iam_creds_dict = {
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], "aws_secret_access_key": sts_response["Credentials"][
aws_session_token=sts_response["Credentials"]["SessionToken"], "SecretAccessKey"
region_name=aws_region_name, ],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
) )
return session.get_credentials() session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None: elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
@ -958,7 +993,7 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int): if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout _params["timeout"] = timeout
self.client = HTTPHandler(**_params) # type: ignore self.client = _get_httpx_client(_params) # type: ignore
else: else:
self.client = client self.client = client
if (stream is not None and stream == True) and provider != "ai21": if (stream is not None and stream == True) and provider != "ai21":
@ -1040,7 +1075,7 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int): if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout _params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore client = _get_async_httpx_client(_params) # type: ignore
else: else:
client = client # type: ignore client = client # type: ignore
@ -1420,6 +1455,17 @@ class BedrockConverseLLM(BaseLLM):
and aws_role_name is not None and aws_role_name is not None
and aws_session_name is not None and aws_session_name is not None
): ):
iam_creds_cache_key = json.dumps(
{
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
}
)
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token) oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None: if oidc_token is None:
@ -1428,7 +1474,11 @@ class BedrockConverseLLM(BaseLLM):
status_code=401, status_code=401,
) )
sts_client = boto3.client("sts") sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com",
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
@ -1439,14 +1489,26 @@ class BedrockConverseLLM(BaseLLM):
DurationSeconds=3600, DurationSeconds=3600,
) )
session = boto3.Session( iam_creds_dict = {
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], "aws_secret_access_key": sts_response["Credentials"][
aws_session_token=sts_response["Credentials"]["SessionToken"], "SecretAccessKey"
region_name=aws_region_name, ],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
iam_cache.set_cache(
key=iam_creds_cache_key,
value=json.dumps(iam_creds_dict),
ttl=3600 - 60,
) )
return session.get_credentials() session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None: elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
@ -1542,7 +1604,7 @@ class BedrockConverseLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int): if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout _params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore client = _get_async_httpx_client(_params) # type: ignore
else: else:
client = client # type: ignore client = client # type: ignore
@ -1814,7 +1876,7 @@ class BedrockConverseLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int): if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout) timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout _params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore client = _get_httpx_client(_params) # type: ignore
else: else:
client = client client = client
try: try:
@ -1859,29 +1921,59 @@ class AWSEventStreamDecoder:
self.parser = EventStreamJSONParser() self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
try:
text = "" text = ""
tool_str = "" tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False is_finished = False
finish_reason = "" finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
index = int(chunk_data.get("contentBlockIndex", 0))
if "start" in chunk_data:
start_obj = ContentBlockStartEvent(**chunk_data["start"])
if (
start_obj is not None
and "toolUse" in start_obj
and start_obj["toolUse"] is not None
):
tool_use = {
"id": start_obj["toolUse"]["toolUseId"],
"type": "function",
"function": {
"name": start_obj["toolUse"]["name"],
"arguments": "",
},
}
elif "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"]) delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj: if "text" in delta_obj:
text = delta_obj["text"] text = delta_obj["text"]
elif "toolUse" in delta_obj: elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"] tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
}
elif "stopReason" in chunk_data: elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop")) finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
is_finished = True
elif "usage" in chunk_data: elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk( response = GenericStreamingChunk(
text=text, text=text,
tool_str=tool_str, tool_use=tool_use,
is_finished=is_finished, is_finished=is_finished,
finish_reason=finish_reason, finish_reason=finish_reason,
usage=usage, usage=usage,
index=index,
) )
return response return response
except Exception as e:
raise Exception("Received streaming error - {}".format(str(e)))
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = "" text = ""
@ -1890,12 +1982,16 @@ class AWSEventStreamDecoder:
if "outputText" in chunk_data: if "outputText" in chunk_data:
text = chunk_data["outputText"] text = chunk_data["outputText"]
# ai21 mapping # ai21 mapping
if "ai21" in self.model: # fake ai21 streaming elif "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True is_finished = True
finish_reason = "stop" finish_reason = "stop"
######## bedrock.anthropic mappings ############### ######## bedrock.anthropic mappings ###############
elif "delta" in chunk_data: elif (
"contentBlockIndex" in chunk_data
or "stopReason" in chunk_data
or "metrics" in chunk_data
):
return self.converse_chunk_parser(chunk_data=chunk_data) return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ############### ######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data: elif "outputs" in chunk_data:
@ -1905,7 +2001,7 @@ class AWSEventStreamDecoder:
): ):
text = chunk_data["outputs"][0]["text"] text = chunk_data["outputs"][0]["text"]
stop_reason = chunk_data.get("stop_reason", None) stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None: if stop_reason is not None:
is_finished = True is_finished = True
finish_reason = stop_reason finish_reason = stop_reason
######## bedrock.cohere mappings ############### ######## bedrock.cohere mappings ###############
@ -1926,8 +2022,9 @@ class AWSEventStreamDecoder:
text=text, text=text,
is_finished=is_finished, is_finished=is_finished,
finish_reason=finish_reason, finish_reason=finish_reason,
tool_str="",
usage=None, usage=None,
index=0,
tool_use=None,
) )
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
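For reference, the shape of the streaming chunks the parser above now emits for tool calls (field values are illustrative): a `contentBlockStart` carries the tool id and name, and later `contentBlockDelta` events stream argument fragments with id/name set to None.

start_chunk = {
    "text": "",
    "is_finished": False,
    "finish_reason": "",
    "usage": None,
    "index": 0,
    "tool_use": {
        "id": "tooluse_abc123",  # from start_obj["toolUse"]["toolUseId"]
        "type": "function",
        "function": {"name": "get_current_weather", "arguments": ""},
    },
}
delta_chunk = {
    "text": "",
    "is_finished": False,
    "finish_reason": "",
    "usage": None,
    "index": 0,
    "tool_use": {
        "id": None,
        "type": "function",
        "function": {"name": None, "arguments": '{"location": "Paris"}'},
    },
}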

View file

@ -139,6 +139,7 @@ def process_response(
def convert_model_to_url(model: str, api_base: str): def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".") user_id, app_id, model_id = model.split(".")
model_id = model_id.lower()
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs" return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
@ -171,19 +172,55 @@ async def async_completion(
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await async_handler.post( response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data) url=model, headers=headers, data=json.dumps(data)
) )
return process_response( logging_obj.post_call(
model=model, input=prompt,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key, api_key=api_key,
data=data, original_response=response.text,
encoding=encoding, additional_args={"complete_input_dict": data},
logging_obj=logging_obj,
) )
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
)
# print(completion_response)
try:
choices_list = []
for idx, item in enumerate(completion_response["outputs"]):
if len(item["data"]["text"]["raw"]) > 0:
message_obj = Message(content=item["data"]["text"]["raw"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
index=idx + 1, # check
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
# Calculate Usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
model_response["model"] = model
model_response["usage"] = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
def completion( def completion(
@ -241,7 +278,7 @@ def completion(
additional_args={ additional_args={
"complete_input_dict": data, "complete_input_dict": data,
"headers": headers, "headers": headers,
"api_base": api_base, "api_base": model,
}, },
) )
if acompletion == True: if acompletion == True:

View file

@ -12,6 +12,15 @@ class AsyncHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000, concurrent_limit=1000,
): ):
self.timeout = timeout
self.client = self.create_client(
timeout=timeout, concurrent_limit=concurrent_limit
)
def create_client(
self, timeout: Optional[Union[float, httpx.Timeout]], concurrent_limit: int
) -> httpx.AsyncClient:
async_proxy_mounts = None async_proxy_mounts = None
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
http_proxy = os.getenv("HTTP_PROXY", None) http_proxy = os.getenv("HTTP_PROXY", None)
@ -39,7 +48,8 @@ class AsyncHTTPHandler:
if timeout is None: if timeout is None:
timeout = _DEFAULT_TIMEOUT timeout = _DEFAULT_TIMEOUT
# Create a client with a connection pool # Create a client with a connection pool
self.client = httpx.AsyncClient(
return httpx.AsyncClient(
timeout=timeout, timeout=timeout,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=concurrent_limit, max_connections=concurrent_limit,
@ -83,11 +93,48 @@ class AsyncHTTPHandler:
response = await self.client.send(req, stream=stream) response = await self.client.send(req, stream=stream)
response.raise_for_status() response.raise_for_status()
return response return response
except httpx.RemoteProtocolError:
# Retry the request with a new session if there is a connection error
new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
try:
return await self.single_connection_post_request(
url=url,
client=new_client,
data=data,
json=json,
params=params,
headers=headers,
stream=stream,
)
finally:
await new_client.aclose()
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
raise e raise e
except Exception as e: except Exception as e:
raise e raise e
async def single_connection_post_request(
self,
url: str,
client: httpx.AsyncClient,
data: Optional[Union[dict, str]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
"""
Making POST request for a single connection client.
Used for retrying requests that hit client connection errors.
"""
req = client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = await client.send(req, stream=stream)
response.raise_for_status()
return response
def __del__(self) -> None: def __del__(self) -> None:
try: try:
asyncio.get_running_loop().create_task(self.close()) asyncio.get_running_loop().create_task(self.close())
@ -172,3 +219,60 @@ class HTTPHandler:
self.close() self.close()
except Exception: except Exception:
pass pass
def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
"""
Retrieves the async HTTP client from the cache
If not present, creates a new client
Caches the new client and returns it.
"""
_params_key_name = ""
if params is not None:
for key, value in params.items():
try:
_params_key_name += f"{key}_{value}"
except Exception:
pass
_cache_key_name = "async_httpx_client" + _params_key_name
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
if params is not None:
_new_client = AsyncHTTPHandler(**params)
else:
_new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
return _new_client
def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
"""
Retrieves the HTTP client from the cache
If not present, creates a new client
Caches the new client and returns it.
"""
_params_key_name = ""
if params is not None:
for key, value in params.items():
try:
_params_key_name += f"{key}_{value}"
except Exception:
pass
_cache_key_name = "httpx_client" + _params_key_name
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
if params is not None:
_new_client = HTTPHandler(**params)
else:
_new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
return _new_client
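A usage sketch for the client caches above: identical params map to the same cache key, so repeated calls reuse one pooled client (and its connections) instead of constructing a new handler per request.

import httpx
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client, _get_httpx_client

sync_a = _get_httpx_client()
sync_b = _get_httpx_client()
assert sync_a is sync_b  # served from litellm.in_memory_llm_clients_cache

# Params are folded into the cache key, so a custom timeout gets its own cached client.
async_client = _get_async_httpx_client(
    params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)}
)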

View file

@ -10,10 +10,10 @@ from typing import Callable, Optional, List, Union, Tuple, Literal
from litellm.utils import ( from litellm.utils import (
ModelResponse, ModelResponse,
Usage, Usage,
map_finish_reason,
CustomStreamWrapper, CustomStreamWrapper,
EmbeddingResponse, EmbeddingResponse,
) )
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@ -289,7 +289,7 @@ class DatabricksChatCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response], response: Union[requests.Response, httpx.Response],
model_response: ModelResponse, model_response: ModelResponse,
stream: bool, stream: bool,
logging_obj: litellm.utils.Logging, logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: Union[dict, str], data: Union[dict, str],

View file

@ -1,14 +1,22 @@
import types ####################################
import traceback ######### DEPRECATED FILE ##########
####################################
# logic moved to `vertex_httpx.py` #
import copy import copy
import time import time
import traceback
import types
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version from packaging.version import Version
import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
class GeminiError(Exception): class GeminiError(Exception):
@ -186,8 +194,8 @@ def completion(
if _system_instruction and len(system_prompt) > 0: if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt _params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params) _model = genai.GenerativeModel(**_params)
if stream == True: if stream is True:
if acompletion == True: if acompletion is True:
async def async_streaming(): async def async_streaming():
try: try:

View file

@ -1,33 +1,41 @@
import hashlib
import json
import time
import traceback
import types
from typing import ( from typing import (
Optional,
Union,
Any, Any,
BinaryIO, BinaryIO,
Literal, Callable,
Coroutine,
Iterable, Iterable,
Literal,
Optional,
Union,
) )
import hashlib
from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx import httpx
from .base import BaseLLM
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional, Coroutine
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
import openai import openai
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm.types.utils import ProviderField
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
Usage,
convert_to_model_response_object,
)
from ..types.llms.openai import *
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class OpenAIError(Exception): class OpenAIError(Exception):
@ -164,6 +172,68 @@ class MistralConfig:
return optional_params return optional_params
class MistralEmbeddingConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createEmbedding
"""
def __init__(
self,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"encoding_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "encoding_format":
optional_params["encoding_format"] = value
return optional_params
class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
class DeepInfraConfig: class DeepInfraConfig:
""" """
Reference: https://deepinfra.com/docs/advanced/openai_api Reference: https://deepinfra.com/docs/advanced/openai_api
@ -243,8 +313,12 @@ class DeepInfraConfig:
] ]
def map_openai_params( def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str self,
): non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_openai_params = self.get_supported_openai_params() supported_openai_params = self.get_supported_openai_params()
for param, value in non_default_params.items(): for param, value in non_default_params.items():
if ( if (
@ -253,7 +327,22 @@ class DeepInfraConfig:
and model == "mistralai/Mistral-7B-Instruct-v0.1" and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # this model does not support temperature == 0 ): # this model does not support temperature == 0
value = 0.0001 # close to 0 value = 0.0001 # close to 0
if param == "tool_choice":
if (
value != "auto" and value != "none"
): # https://deepinfra.com/docs/advanced/function_calling
## UNSUPPORTED TOOL CHOICE VALUE
if litellm.drop_params is True or drop_params is True:
value = None
else:
raise litellm.utils.UnsupportedParamsError(
message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
value
),
status_code=400,
)
if param in supported_openai_params: if param in supported_openai_params:
if value is not None:
optional_params[param] = value optional_params[param] = value
return optional_params return optional_params
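A sketch of the new `tool_choice` handling above, with an illustrative model id and assuming this module is importable as `litellm.llms.openai`: with `drop_params=True` an unsupported value (anything other than "auto"/"none") is silently dropped, otherwise an `UnsupportedParamsError` is raised.

from litellm.llms.openai import DeepInfraConfig  # assumed module path

config = DeepInfraConfig()
optional_params = config.map_openai_params(
    non_default_params={"tool_choice": "required"},    # not supported by DeepInfra function calling
    optional_params={},
    model="meta-llama/Meta-Llama-3-8B-Instruct",       # hypothetical DeepInfra model id
    drop_params=True,
)
print(optional_params)  # {} -> the unsupported tool_choice was dropped instead of raising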

View file

@ -12,11 +12,11 @@ from typing import Callable, Optional, List, Literal, Union
from litellm.utils import ( from litellm.utils import (
ModelResponse, ModelResponse,
Usage, Usage,
map_finish_reason,
CustomStreamWrapper, CustomStreamWrapper,
Message, Message,
Choices, Choices,
) )
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@ -198,7 +198,7 @@ class PredibaseChatCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response], response: Union[requests.Response, httpx.Response],
model_response: ModelResponse, model_response: ModelResponse,
stream: bool, stream: bool,
logging_obj: litellm.utils.Logging, logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: Union[dict, str], data: Union[dict, str],

View file

@ -1,24 +1,30 @@
import json
import re
import traceback
import uuid
import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import requests
from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment
import litellm import litellm
import litellm.types import litellm.types
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
import litellm.types.llms import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai import litellm.types.llms.vertex_ai
from litellm.types.completion import (
ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.utils import GenericImageParsingChunk
def default_pt(messages): def default_pt(messages):
@@ -622,9 +628,10 @@ def construct_tool_use_system_prompt(
def convert_url_to_base64(url): def convert_url_to_base64(url):
import requests
import base64 import base64
import requests
for _ in range(3): for _ in range(3):
try: try:
response = requests.get(url) response = requests.get(url)
@@ -654,7 +661,7 @@ def convert_url_to_base64(url):
raise Exception(f"Error: Unable to fetch image from URL. url={url}") raise Exception(f"Error: Unable to fetch image from URL. url={url}")
def convert_to_anthropic_image_obj(openai_image_url: str): def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
""" """
Input: Input:
"image_url": "data:image/jpeg;base64,{base64_image}", "image_url": "data:image/jpeg;base64,{base64_image}",
@@ -675,11 +682,11 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
# Infer image format from the URL # Infer image format from the URL
image_format = openai_image_url.split("data:image/")[1].split(";base64,")[0] image_format = openai_image_url.split("data:image/")[1].split(";base64,")[0]
return { return GenericImageParsingChunk(
"type": "base64", type="base64",
"media_type": f"image/{image_format}", media_type=f"image/{image_format}",
"data": base64_data, data=base64_data,
} )
except Exception as e: except Exception as e:
if "Error: Unable to fetch image from URL" in str(e): if "Error: Unable to fetch image from URL" in str(e):
raise e raise e
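An illustrative call (an assumption, not from the diff) showing the new typed return value: a data-URL image now comes back as a GenericImageParsingChunk rather than a plain dict.
# Hypothetical example; the data URL below is the common 1x1 transparent PNG.
chunk = convert_to_anthropic_image_obj(
    "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)
# expected keys: chunk["type"] == "base64", chunk["media_type"] == "image/png",
# chunk["data"] holds the base64 payload.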
@@ -1606,19 +1613,23 @@ def azure_text_pt(messages: list):
###### AMAZON BEDROCK ####### ###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
from litellm.types.llms.bedrock import ( from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock, ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
) )
from litellm.types.llms.bedrock import ToolConfigBlock as BedrockToolConfigBlock
from litellm.types.llms.bedrock import (
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
)
from litellm.types.llms.bedrock import ToolResultBlock as BedrockToolResultBlock
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
)
from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
def get_image_details(image_url) -> Tuple[str, str]: def get_image_details(image_url) -> Tuple[str, str]:
@@ -1655,7 +1666,8 @@ def get_image_details(image_url) -> Tuple[str, str]:
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url: if "base64" in image_url:
# Case 1: Images with base64 encoding # Case 1: Images with base64 encoding
import base64, re import base64
import re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image> # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",") image_metadata, img_without_base_64 = image_url.split(",")

View file

@@ -0,0 +1,532 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/
from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
TextCompletionResponse,
Usage,
CustomStreamWrapper,
Message,
Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.types.llms.databricks import GenericStreamingChunk
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TextCompletionCodestralError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(
method="POST",
url="https://docs.codestral.com/user-guide/inference/rest_api",
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
async def make_call(
client: AsyncHTTPHandler,
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise TextCompletionCodestralError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class MistralTextCompletionConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
suffix: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stream: Optional[bool] = None,
random_seed: Optional[int] = None,
stop: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"suffix",
"temperature",
"top_p",
"max_tokens",
"stream",
"seed",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "seed":
optional_params["random_seed"] = value
if param == "min_tokens":
optional_params["min_tokens"] = value
return optional_params
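For illustration (values are assumed), the OpenAI-style params are translated into the Codestral FIM request shape, with `seed` renamed to `random_seed`:
# Hypothetical usage of MistralTextCompletionConfig.map_openai_params:
fim_config = MistralTextCompletionConfig()
mapped = fim_config.map_openai_params(
    non_default_params={"suffix": "\n", "max_tokens": 64, "seed": 42, "stop": "###"},
    optional_params={},
)
# mapped == {"suffix": "\n", "max_tokens": 64, "random_seed": 42, "stop": "###"}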
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = None
logprobs = None
chunk_data = chunk_data.replace("data:", "")
chunk_data = chunk_data.strip()
if len(chunk_data) == 0 or chunk_data == "[DONE]":
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
chunk_data_dict = json.loads(chunk_data)
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or []
_choice = _choices[0]
text = _choice.get("delta", {}).get("content", "")
if _choice.get("finish_reason") is not None:
is_finished = True
finish_reason = _choice.get("finish_reason")
logprobs = _choice.get("logprobs")
return GenericStreamingChunk(
text=text,
original_chunk=original_chunk,
is_finished=is_finished,
finish_reason=finish_reason,
logprobs=logprobs,
)
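An illustrative parse of one server-sent-event line; the payload below is hand-written to mirror the shape of a Codestral streaming chunk and is an assumption, not captured output.
# Hypothetical SSE line fed through MistralTextCompletionConfig._chunk_parser:
sample_line = (
    'data: {"id": "cmpl-1", "object": "text_completion", "created": 0, '
    '"model": "codestral-latest", "choices": [{"index": 0, '
    '"delta": {"content": "def "}, "finish_reason": null}]}'
)
parsed = MistralTextCompletionConfig()._chunk_parser(sample_line)
# parsed["text"] == "def ", parsed["is_finished"] is False, parsed["finish_reason"] is None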
class CodestralTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(
self,
api_key: Optional[str],
user_headers: dict,
) -> dict:
if api_key is None:
raise ValueError(
"Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def output_parser(self, generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = [
"<|assistant|>",
"<|system|>",
"<|user|>",
"<s>",
"</s>",
]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
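For example (illustrative), a wrapping BOS/EOS token pair is stripped once from the front and once from the back:
# Hypothetical round-trip through CodestralTextCompletion.output_parser:
llm = CodestralTextCompletion()
cleaned = llm.output_parser("<s>assert is_odd(1)</s>")
# cleaned == "assert is_odd(1)"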
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: TextCompletionResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> TextCompletionResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"codestral api: raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:
raise TextCompletionCodestralError(
message=str(response.text),
status_code=response.status_code,
)
try:
completion_response = response.json()
except:
raise TextCompletionCodestralError(message=response.text, status_code=422)
_original_choices = completion_response.get("choices", [])
_choices: List[litellm.utils.TextChoices] = []
for choice in _original_choices:
# This is what 1 choice looks like from codestral API
# {
# "index": 0,
# "message": {
# "role": "assistant",
# "content": "\n assert is_odd(1)\n assert",
# "tool_calls": null
# },
# "finish_reason": "length",
# "logprobs": null
# }
_finish_reason = None
_index = 0
_text = None
_logprobs = None
_choice_message = choice.get("message", {})
_choice = litellm.utils.TextChoices(
finish_reason=choice.get("finish_reason"),
index=choice.get("index"),
text=_choice_message.get("content"),
logprobs=choice.get("logprobs"),
)
_choices.append(_choice)
_response = litellm.TextCompletionResponse(
id=completion_response.get("id"),
choices=_choices,
created=completion_response.get("created"),
model=completion_response.get("model"),
usage=completion_response.get("usage"),
stream=False,
object=completion_response.get("object"),
)
return _response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key: str,
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## Load Config
config = litellm.MistralTextCompletionConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
data = {
"prompt": prompt,
**optional_params,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
else:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
### SYNC STREAMING
if stream is True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=stream,
)
_response = CustomStreamWrapper(
response.iter_lines(),
model,
custom_llm_provider="codestral",
logging_obj=logging_obj,
)
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=optional_params.get("stream", False),
logging_obj=logging_obj, # type: ignore
optional_params=optional_params,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params=None,
logger_fn=None,
headers={},
) -> TextCompletionResponse:
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
except httpx.HTTPStatusError as e:
raise TextCompletionCodestralError(
status_code=e.response.status_code,
message="HTTPStatusError - {}".format(e.response.text),
)
except Exception as e:
raise TextCompletionCodestralError(
status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data: dict,
timeout: Union[float, httpx.Timeout],
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
data["stream"] = True
streamwrapper = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="text-completion-codestral",
logging_obj=logging_obj,
)
return streamwrapper
def embedding(self, *args, **kwargs):
pass
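End to end, the intended entry point is litellm's text_completion with a codestral FIM model. The snippet below is an assumed usage sketch; the provider prefix follows the custom_llm_provider string above and the model name follows Mistral's public naming, neither is taken from this diff.
# Hypothetical end-user call hitting the new provider (requires CODESTRAL_API_KEY).
import litellm

response = litellm.text_completion(
    model="text-completion-codestral/codestral-latest",
    prompt="def is_odd(n):\n",
    suffix="\n    return result",
    max_tokens=32,
)
print(response.choices[0].text)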

File diff suppressed because it is too large

View file

@@ -4,7 +4,6 @@ from enum import Enum
import requests, copy # type: ignore import requests, copy # type: ignore
import time import time
from typing import Callable, Optional, List from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

View file

@@ -1,18 +1,31 @@
import os, types import inspect
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
from typing import Callable, Optional, Union, List, Literal, Any import types
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason import uuid
import litellm, uuid from enum import Enum
import httpx, inspect # type: ignore from typing import Any, Callable, List, Literal, Optional, Union
from litellm.types.llms.vertex_ai import *
import httpx # type: ignore
import requests # type: ignore
from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.prompt_templates.factory import ( from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result, convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke, convert_to_gemini_tool_call_invoke,
convert_to_gemini_tool_call_result,
) )
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
class VertexAIError(Exception): class VertexAIError(Exception):
@@ -267,28 +280,6 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
raise Exception(f"An exception occurs with this image - {str(e)}") raise Exception(f"An exception occurs with this image - {str(e)}")
def _load_image_from_url(image_url: str):
"""
Loads an image from a URL.
Args:
image_url (str): The URL of the image.
Returns:
Image: The loaded image.
"""
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(data=image_bytes)
def _convert_gemini_role(role: str) -> Literal["user", "model"]: def _convert_gemini_role(role: str) -> Literal["user", "model"]:
if role == "user": if role == "user":
return "user" return "user"
@@ -316,28 +307,9 @@ def _process_gemini_image(image_url: str) -> PartType:
return PartType(file_data=file_data) return PartType(file_data=file_data)
# Direct links # Direct links
elif "https:/" in image_url: elif "https:/" in image_url or "base64" in image_url:
image = _load_image_from_url(image_url) image = convert_to_anthropic_image_obj(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type) _blob = BlobType(data=image["data"], mime_type=image["media_type"])
return PartType(inline_data=_blob)
# Base64 encoding
elif "base64" in image_url:
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
else:
mime_type = "image/jpeg"
decoded_img = base64.b64decode(img_without_base_64)
_blob = BlobType(data=decoded_img, mime_type=mime_type)
return PartType(inline_data=_blob) return PartType(inline_data=_blob)
raise Exception("Invalid image received - {}".format(image_url)) raise Exception("Invalid image received - {}".format(image_url))
except Exception as e: except Exception as e:
@@ -473,23 +445,25 @@ def completion(
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
) )
try: try:
import google.auth # type: ignore
import proto # type: ignore
from google.cloud import aiplatform # type: ignore
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types, # type: ignore
)
from google.protobuf import json_format # type: ignore
from google.protobuf.struct_pb2 import Value # type: ignore
from vertexai.language_models import CodeGenerationModel, TextGenerationModel
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Part,
)
from vertexai.preview.language_models import ( from vertexai.preview.language_models import (
ChatModel, ChatModel,
CodeChatModel, CodeChatModel,
InputOutputTextPair, InputOutputTextPair,
) )
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
)
from google.cloud import aiplatform # type: ignore
from google.protobuf import json_format # type: ignore
from google.protobuf.struct_pb2 import Value # type: ignore
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
import google.auth # type: ignore
import proto # type: ignore
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
print_verbose( print_verbose(
@@ -611,7 +585,7 @@ def completion(
llm_model = None llm_model = None
# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
if acompletion == True: if acompletion is True:
data = { data = {
"llm_model": llm_model, "llm_model": llm_model,
"mode": mode, "mode": mode,
@@ -643,7 +617,7 @@ def completion(
tools = optional_params.pop("tools", None) tools = optional_params.pop("tools", None)
content = _gemini_convert_messages_with_history(messages=messages) content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False) stream = optional_params.pop("stream", False)
if stream == True: if stream is True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@@ -1293,6 +1267,95 @@ async def async_streaming(
return streamwrapper return streamwrapper
class VertexAITextEmbeddingConfig(BaseModel):
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
Args:
auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
"""
auto_truncate: Optional[bool] = None
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None
title: Optional[str] = None
def __init__(
self,
auto_truncate: Optional[bool] = None,
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
title: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"dimensions",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "dimensions":
optional_params["output_dimensionality"] = value
return optional_params
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {"project": "vertex_project", "region_name": "vertex_location"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
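A small sketch (assumed values, not from the diff) of the translation this config performs: `dimensions` becomes `output_dimensionality`, and the shared auth params map onto `vertex_project` / `vertex_location`.
# Hypothetical usage of VertexAITextEmbeddingConfig:
embed_cfg = VertexAITextEmbeddingConfig()
params = embed_cfg.map_openai_params(
    non_default_params={"dimensions": 256}, optional_params={}
)
# params == {"output_dimensionality": 256}
auth = embed_cfg.map_special_auth_params(
    non_default_params={"project": "my-gcp-project", "region_name": "us-central1"},
    optional_params={},
)
# auth == {"vertex_project": "my-gcp-project", "vertex_location": "us-central1"}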
def embedding( def embedding(
model: str, model: str,
input: Union[list, str], input: Union[list, str],
@@ -1316,8 +1379,8 @@ def embedding(
message="vertexai import failed please run `pip install google-cloud-aiplatform`", message="vertexai import failed please run `pip install google-cloud-aiplatform`",
) )
from vertexai.language_models import TextEmbeddingModel
import google.auth # type: ignore import google.auth # type: ignore
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
try: try:
@@ -1347,6 +1410,16 @@ def embedding(
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
if optional_params is not None and isinstance(optional_params, dict):
if optional_params.get("task_type") or optional_params.get("title"):
# if user passed task_type or title, cast to TextEmbeddingInput
_task_type = optional_params.pop("task_type", None)
_title = optional_params.pop("title", None)
input = [
TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
for x in input
]
try: try:
llm_model = TextEmbeddingModel.from_pretrained(model) llm_model = TextEmbeddingModel.from_pretrained(model)
except Exception as e: except Exception as e:
@@ -1363,7 +1436,8 @@ def embedding(
encoding=encoding, encoding=encoding,
) )
request_str = f"""embeddings = llm_model.get_embeddings({input})""" _input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL ## LOGGING PRE-CALL
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@@ -1375,7 +1449,7 @@ def embedding(
) )
try: try:
embeddings = llm_model.get_embeddings(input) embeddings = llm_model.get_embeddings(**_input_dict)
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
@@ -1383,6 +1457,7 @@ def embedding(
logging_obj.post_call(input=input, api_key=None, original_response=embeddings) logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary ## Populate OpenAI compliant dictionary
embedding_response = [] embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
embedding_response.append( embedding_response.append(
{ {
@@ -1391,14 +1466,10 @@ def embedding(
"embedding": embedding.values, "embedding": embedding.values,
} }
) )
input_tokens += embedding.statistics.token_count
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = embedding_response model_response["data"] = embedding_response
model_response["model"] = model model_response["model"] = model
input_tokens = 0
input_str = "".join(input)
input_tokens += len(encoding.encode(input_str))
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
@@ -1420,7 +1491,8 @@ async def async_embedding(
""" """
Async embedding implementation Async embedding implementation
""" """
request_str = f"""embeddings = llm_model.get_embeddings({input})""" _input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL ## LOGGING PRE-CALL
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@@ -1432,7 +1504,7 @@ async def async_embedding(
) )
try: try:
embeddings = await client.get_embeddings_async(input) embeddings = await client.get_embeddings_async(**_input_dict)
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
@@ -1440,6 +1512,7 @@ async def async_embedding(
logging_obj.post_call(input=input, api_key=None, original_response=embeddings) logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary ## Populate OpenAI compliant dictionary
embedding_response = [] embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
embedding_response.append( embedding_response.append(
{ {
@@ -1448,18 +1521,13 @@ async def async_embedding(
"embedding": embedding.values, "embedding": embedding.values,
} }
) )
input_tokens += embedding.statistics.token_count
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = embedding_response model_response["data"] = embedding_response
model_response["model"] = model model_response["model"] = model
input_tokens = 0
input_str = "".join(input)
input_tokens += len(encoding.encode(input_str))
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response

View file

@@ -6,7 +6,8 @@ from enum import Enum
import requests, copy # type: ignore import requests, copy # type: ignore
import time, uuid import time, uuid
from typing import Callable, Optional, List from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .prompt_templates.factory import ( from .prompt_templates.factory import (

View file

@@ -1,16 +1,325 @@
import os, types # What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
import inspect
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
from typing import Callable, Optional, Union, List, Any, Tuple import types
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason import uuid
import litellm, uuid from enum import Enum
import httpx, inspect # type: ignore from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import ijson
import requests # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import convert_url_to_base64
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
GenerateContentResponseBody,
GenerationConfig,
PartType,
RequestBody,
SafetSettingsConfig,
SystemInstructions,
ToolConfig,
Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM from .base import BaseLLM
class VertexGeminiConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
The class `VertexGeminiConfig` provides configuration for Vertex AI's Gemini API interface. Below are the parameters:
- `temperature` (float): This controls the degree of randomness in token selection.
- `max_output_tokens` (integer): This sets the maximum number of tokens in the text output. The default value is 256.
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
- `candidate_count` (int): Number of generated responses to return.
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
- `frequency_penalty` (float): This parameter is used to penalize the model for repeating the same output. The default value is 0.0.
- `presence_penalty` (float): This parameter is used to penalize the model for generating the same output as the input. The default value is 0.0.
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
response_mime_type: Optional[str] = None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
response_mime_type: Optional[str] = None,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict]
) -> Optional[ToolConfig]:
if tool_choice == "none":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
elif tool_choice == "required":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
elif tool_choice == "auto":
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
name = tool_choice.get("function", {}).get("name", "")
return ToolConfig(
functionCallingConfig=FunctionCallingConfig(
mode="ANY", allowed_function_names=[name]
)
)
else:
raise litellm.utils.UnsupportedParamsError(
message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
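Illustratively (an assumed call, not from the diff), a named-function tool_choice restricts Gemini to that one function via a FunctionCallingConfig with mode "ANY":
# Hypothetical mapping of an OpenAI-style tool_choice dict:
gemini_cfg = VertexGeminiConfig()
tool_cfg = gemini_cfg.map_tool_choice_values(
    model="gemini-1.5-pro",
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
# tool_cfg == {"functionCallingConfig": {"mode": "ANY", "allowed_function_names": ["get_weather"]}}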
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if (
param == "stream" and value is True
): # sending stream = False, can cause it to get passed unchecked and raise issues
optional_params["stream"] = value
if param == "n":
optional_params["candidate_count"] = value
if param == "stop":
if isinstance(value, str):
optional_params["stop_sequences"] = [value]
elif isinstance(value, list):
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object": # type: ignore
optional_params["response_mime_type"] = "application/json"
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
for tool in value:
gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
Tools(function_declarations=gtool_func_declarations)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
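And a combined sketch (assumed values) showing several OpenAI params translated at once; note `n` becomes `candidate_count`, `stop` becomes `stop_sequences`, `max_tokens` becomes `max_output_tokens`, and a json_object response_format sets `response_mime_type`.
# Hypothetical call to VertexGeminiConfig.map_openai_params:
mapped = VertexGeminiConfig().map_openai_params(
    model="gemini-1.5-flash",
    non_default_params={
        "temperature": 0.1,
        "max_tokens": 256,
        "n": 2,
        "stop": ["###"],
        "response_format": {"type": "json_object"},
    },
    optional_params={},
)
# mapped == {"temperature": 0.1, "max_output_tokens": 256, "candidate_count": 2,
#            "stop_sequences": ["###"], "response_mime_type": "application/json"}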
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {"project": "vertex_project", "region_name": "vertex_location"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
"""
return [
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
]
def get_flagged_finish_reasons(self) -> Dict[str, str]:
"""
Return Dictionary of finish reasons which indicate response was flagged
and what it means
"""
return {
"SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.",
"RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
"BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.",
"PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.",
"SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.",
}
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = AsyncHTTPHandler() # Create a new client if none provided
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
class VertexAIError(Exception): class VertexAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@@ -33,16 +342,159 @@ class VertexLLM(BaseLLM):
self.project_id: Optional[str] = None self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None self.async_handler: Optional[AsyncHTTPHandler] = None
def load_auth(self) -> Tuple[Any, str]: def _process_response(
from google.auth.transport.requests import Request # type: ignore[import-untyped] self,
from google.auth.credentials import Credentials # type: ignore[import-untyped] model: str,
import google.auth as google_auth response: httpx.Response,
model_response: ModelResponse,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
credentials, project_id = google_auth.default( ## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
## CHECK IF RESPONSE FLAGGED
if len(completion_response["candidates"]) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
)
if (
"finishReason" in completion_response["candidates"][0]
and completion_response["candidates"][0]["finishReason"]
in content_policy_violations.keys()
):
## CONTENT POLICY VIOLATION ERROR
raise VertexAIError(
status_code=400,
message="The response was blocked. Reason={}. Raw Response={}".format(
content_policy_violations[
completion_response["candidates"][0]["finishReason"]
],
completion_response,
),
)
model_response.choices = [] # type: ignore
## GET MODEL ##
model_response.model = model
try:
## GET TEXT ##
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant"
}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
for idx, candidate in enumerate(completion_response["candidates"]):
if "content" not in candidate:
continue
if "text" in candidate["content"]["parts"][0]:
content_str = candidate["content"]["parts"][0]["text"]
if "functionCall" in candidate["content"]["parts"][0]:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=candidate["content"]["parts"][0]["functionCall"]["name"],
arguments=json.dumps(
candidate["content"]["parts"][0]["functionCall"]["args"]
),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=f"call_{str(uuid.uuid4())}",
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=chat_completion_message, # type: ignore
logprobs=None,
enhancements=None,
)
model_response.choices.append(choice)
## GET USAGE ##
usage = litellm.Usage(
prompt_tokens=completion_response["usageMetadata"]["promptTokenCount"],
completion_tokens=completion_response["usageMetadata"][
"candidatesTokenCount"
],
total_tokens=completion_response["usageMetadata"]["totalTokenCount"],
)
setattr(model_response, "usage", usage)
except Exception as e:
raise VertexAIError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
completion_response, str(e)
),
status_code=422,
)
return model_response
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
return vertex_region or "us-central1"
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
import google.auth as google_auth
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
json_obj = json.loads(credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"], scopes=["https://www.googleapis.com/auth/cloud-platform"],
) )
credentials.refresh(Request()) if project_id is None:
project_id = creds.project_id
else:
creds, project_id = google_auth.default(
quota_project_id=project_id,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
creds.refresh(Request())
if not project_id: if not project_id:
raise ValueError("Could not resolve project_id") raise ValueError("Could not resolve project_id")
@@ -52,38 +504,364 @@ class VertexLLM(BaseLLM):
f"Expected project_id to be a str but got {type(project_id)}" f"Expected project_id to be a str but got {type(project_id)}"
) )
return credentials, project_id return creds, project_id
def refresh_auth(self, credentials: Any) -> None: def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped] from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request()) credentials.refresh(Request())
def _prepare_request(self, request: httpx.Request) -> None: def _ensure_access_token(
access_token = self._ensure_access_token() self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[str, str]:
if request.headers.get("Authorization"): """
# already authenticated, nothing for us to do Returns auth token and project id
return """
if self.access_token is not None and self.project_id is not None:
request.headers["Authorization"] = f"Bearer {access_token}" return self.access_token, self.project_id
def _ensure_access_token(self) -> str:
if self.access_token is not None:
return self.access_token
if not self._credentials: if not self._credentials:
self._credentials, project_id = self.load_auth() self._credentials, project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
if not self.project_id: if not self.project_id:
self.project_id = project_id self.project_id = project_id
else: else:
self.refresh_auth(self._credentials) self.refresh_auth(self._credentials)
if not self._credentials.token: if not self.project_id:
self.project_id = self._credentials.project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")
if not self._credentials or not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment") raise RuntimeError("Could not resolve API token from the environment")
assert isinstance(self._credentials.token, str) return self._credentials.token, self.project_id
return self._credentials.token
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return auth_header, url
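For the Google AI Studio path (illustrative placeholder key, not from the diff), no GCP auth is resolved and the key is embedded in the query string:
# Hypothetical call; no network or GCP credentials are needed for the "gemini" branch.
vertex_llm = VertexLLM()
auth_header, url = vertex_llm._get_token_and_url(
    model="gemini-1.5-pro",
    gemini_api_key="AIza-placeholder",
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    stream=True,
    custom_llm_provider="gemini",
)
# auth_header is None; url ==
# "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:streamGenerateContent?key=AIza-placeholder"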
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="vertex_ai_beta",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
model=model,
response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
gemini_api_key: Optional[str],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
)
## TRANSFORMATION ##
try:
supports_system_message = litellm.supports_system_messages(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception as e:
verbose_logger.error(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
str(e)
)
)
supports_system_message = False
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
if supports_system_message is True:
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
"safety_settings", None
) # type: ignore
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)
data = RequestBody(contents=content)
if len(system_content_blocks) > 0:
system_instructions = SystemInstructions(parts=system_content_blocks)
data["system_instruction"] = system_instructions
if tools is not None:
data["tools"] = tools
if tool_choice is not None:
data["toolConfig"] = tool_choice
if safety_settings is not None:
data["safetySettings"] = safety_settings
if generation_config is not None:
data["generationConfig"] = generation_config
headers = {
"Content-Type": "application/json; charset=utf-8",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": url,
"headers": headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=json.dumps(data), # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data, # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
## SYNC STREAMING CALL ##
if stream is not None and stream is True:
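            # The HTTP request is deferred: CustomStreamWrapper uses
            # make_sync_call to open the stream once the caller starts iterating.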
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=url,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="vertex_ai_beta",
logging_obj=logging_obj,
)
return streaming_response
## COMPLETION CALL ##
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=url, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
model=model,
response=response,
model_response=model_response,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data, # type: ignore
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
    def image_generation(
        self,
@@ -163,7 +941,7 @@ class VertexLLM(BaseLLM):
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
-        auth_header = self._ensure_access_token()
+        auth_header, _ = self._ensure_access_token(credentials=None, project_id=None)
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params
@@ -222,3 +1000,110 @@ class VertexLLM(BaseLLM):
        model_response.data = _response_data
        return model_response
class ModelResponseIterator:
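    """
    Iterator over a streamed Gemini / Vertex AI response.

    The response body is a JSON array delivered in arbitrary chunks, so an
    ijson coroutine parses it incrementally; each completed top-level element
    is converted into a GenericStreamingChunk by chunk_parser.
    """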
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
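        """Map one parsed Gemini response object onto a GenericStreamingChunk (text, finish reason, usage)."""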
try:
processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
gemini_chunk = processed_chunk["candidates"][0]
if (
"content" in gemini_chunk
and "text" in gemini_chunk["content"]["parts"][0]
):
text = gemini_chunk["content"]["parts"][0]["text"]
if "finishReason" in gemini_chunk:
finish_reason = map_finish_reason(
finish_reason=gemini_chunk["finishReason"]
)
is_finished = True
if "usageMetadata" in processed_chunk:
usage = ChatCompletionUsageBlock(
prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
completion_tokens=processed_chunk["usageMetadata"][
"candidatesTokenCount"
],
total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
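            # Feed the raw chunk into the ijson coroutine; any top-level array
            # elements completed by this chunk are appended to self.events.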
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
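            # No complete JSON element parsed yet; return an empty placeholder
            # chunk so the surrounding stream wrapper keeps iterating.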
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")