Merge branch 'main' into fix/reset-end-user-budget-by-duration

Laurien 2025-02-12 08:24:06 +01:00 committed by GitHub
commit 9f12cba8bc
231 changed files with 6692 additions and 1224 deletions

View file

@ -72,6 +72,7 @@ jobs:
pip install "jsonschema==4.22.0" pip install "jsonschema==4.22.0"
pip install "pytest-xdist==3.6.1" pip install "pytest-xdist==3.6.1"
pip install "websockets==10.4" pip install "websockets==10.4"
pip uninstall posthog -y
- save_cache: - save_cache:
paths: paths:
- ./venv - ./venv
@ -1517,6 +1518,117 @@ jobs:
- store_test_results: - store_test_results:
path: test-results path: test-results
proxy_multi_instance_tests:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
working_directory: ~/project
steps:
- checkout
- run:
name: Install Docker CLI (In case it's not already installed)
command: |
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io
- run:
name: Install Python 3.9
command: |
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
conda init bash
source ~/.bashrc
conda create -n myenv python=3.9 -y
conda activate myenv
python --version
- run:
name: Install Dependencies
command: |
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
- run:
name: Run Docker container 1
# intentionally give bad redis credentials here
# the OTEL test - should get this as a trace
command: |
docker run -d \
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e LITELLM_MASTER_KEY="sk-1234" \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e USE_DDTRACE=True \
-e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
my-app:latest \
--config /app/config.yaml \
--port 4000 \
--detailed_debug
- run:
name: Run Docker container 2
command: |
docker run -d \
-p 4001:4001 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e LITELLM_MASTER_KEY="sk-1234" \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e USE_DDTRACE=True \
-e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \
--name my-app-2 \
-v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
my-app:latest \
--config /app/config.yaml \
--port 4001 \
--detailed_debug
- run:
name: Install curl and dockerize
command: |
sudo apt-get update
sudo apt-get install -y curl
sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
- run:
name: Start outputting logs
command: docker logs -f my-app
background: true
- run:
name: Wait for instance 1 to be ready
command: dockerize -wait http://localhost:4000 -timeout 5m
- run:
name: Wait for instance 2 to be ready
command: dockerize -wait http://localhost:4001 -timeout 5m
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/multi_instance_e2e_tests -x --junitxml=test-results/junit.xml --durations=5
no_output_timeout:
120m
# Clean up first container
# Store test results
- store_test_results:
path: test-results
proxy_store_model_in_db_tests: proxy_store_model_in_db_tests:
machine: machine:
image: ubuntu-2204:2023.10.1 image: ubuntu-2204:2023.10.1
@ -1552,6 +1664,7 @@ jobs:
pip install "pytest-retry==1.6.3" pip install "pytest-retry==1.6.3"
pip install "pytest-mock==3.12.0" pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1" pip install "pytest-asyncio==0.21.1"
pip install "assemblyai==0.37.0"
- run: - run:
name: Build Docker image name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database . command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@ -2171,6 +2284,12 @@ workflows:
only: only:
- main - main
- /litellm_.*/ - /litellm_.*/
- proxy_multi_instance_tests:
filters:
branches:
only:
- main
- /litellm_.*/
- proxy_store_model_in_db_tests: - proxy_store_model_in_db_tests:
filters: filters:
branches: branches:
@ -2302,6 +2421,7 @@ workflows:
- installing_litellm_on_python - installing_litellm_on_python
- installing_litellm_on_python_3_13 - installing_litellm_on_python_3_13
- proxy_logging_guardrails_model_info_tests - proxy_logging_guardrails_model_info_tests
- proxy_multi_instance_tests
- proxy_store_model_in_db_tests - proxy_store_model_in_db_tests
- proxy_build_from_pip_tests - proxy_build_from_pip_tests
- proxy_pass_through_endpoint_tests - proxy_pass_through_endpoint_tests

View file

@ -20,3 +20,8 @@ REPLICATE_API_TOKEN = ""
ANTHROPIC_API_KEY = "" ANTHROPIC_API_KEY = ""
# Infisical # Infisical
INFISICAL_TOKEN = "" INFISICAL_TOKEN = ""
# Development Configs
LITELLM_MASTER_KEY = "sk-1234"
DATABASE_URL = "postgresql://llmproxy:dbpassword9090@db:5432/litellm"
STORE_MODEL_IN_DB = "True"

View file

@ -52,6 +52,39 @@ def interpret_results(csv_file):
return markdown_table return markdown_table
def _get_docker_run_command_stable_release(release_version):
return f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm_stable_release_branch-{release_version}
"""
def _get_docker_run_command(release_version):
return f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
"""
def get_docker_run_command(release_version):
if "stable" in release_version:
return _get_docker_run_command_stable_release(release_version)
else:
return _get_docker_run_command(release_version)
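# Usage sketch (illustrative; the version strings below are hypothetical):
# get_docker_run_command() returns the stable-release image command when the
# release tag contains "stable", otherwise the main-branch image command.
#   >>> "litellm_stable_release_branch" in get_docker_run_command("v1.61.3-stable")
#   True
#   >>> "litellm:main-" in get_docker_run_command("v1.61.3")
#   True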
if __name__ == "__main__": if __name__ == "__main__":
csv_file = "load_test_stats.csv" # Change this to the path of your CSV file csv_file = "load_test_stats.csv" # Change this to the path of your CSV file
markdown_table = interpret_results(csv_file) markdown_table = interpret_results(csv_file)
@ -79,17 +112,7 @@ if __name__ == "__main__":
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results") start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
existing_release_body = latest_release.body[:start_index] existing_release_body = latest_release.body[:start_index]
docker_run_command = f""" docker_run_command = get_docker_run_command(release_version)
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
print("docker run command: ", docker_run_command) print("docker run command: ", docker_run_command)
new_release_body = ( new_release_body = (

View file

@ -451,3 +451,20 @@ If you have suggestions on how to improve the code quality feel free to open an
<a href="https://github.com/BerriAI/litellm/graphs/contributors"> <a href="https://github.com/BerriAI/litellm/graphs/contributors">
<img src="https://contrib.rocks/image?repo=BerriAI/litellm" /> <img src="https://contrib.rocks/image?repo=BerriAI/litellm" />
</a> </a>
## Run in Developer mode
### Services
1. Set up a `.env` file in the root
2. Run dependent services `docker-compose up db prometheus`
### Backend
1. (In root) create virtual environment `python -m venv .venv`
2. Activate virtual environment `source .venv/bin/activate`
3. Install dependencies `pip install -e ".[all]"`
4. Start proxy backend `uvicorn litellm.proxy.proxy_server:app --host localhost --port 4000 --reload`
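Once the proxy backend is up, a quick smoke test is to call its OpenAI-compatible API with the master key from the example `.env` above (a minimal sketch; the use of the `openai` client and the `/models` listing here are illustrative, assuming the defaults `localhost:4000` and `sk-1234`):
```python
# quick sanity check against the locally running proxy (assumes the example .env above)
import openai

client = openai.OpenAI(
    base_url="http://localhost:4000",  # proxy started by uvicorn in step 4
    api_key="sk-1234",                 # LITELLM_MASTER_KEY from the example .env
)

# list whichever models are currently configured on the proxy
print(client.models.list())
```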
### Frontend
1. Navigate to `ui/litellm-dashboard`
2. Install dependencies `npm install`
3. Run `npm run dev` to start the dashboard

View file

@ -0,0 +1,172 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "4FbDOmcj2VkM"
},
"source": [
"## Use LiteLLM with Arize\n",
"https://docs.litellm.ai/docs/observability/arize_integration\n",
"\n",
"This method uses the litellm proxy to send the data to Arize. The callback is set in the litellm config below, instead of using OpenInference tracing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "21W8Woog26Ns"
},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "xrjKLBxhxu2L"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: litellm in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (1.54.1)\n",
"Requirement already satisfied: aiohttp in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (3.11.10)\n",
"Requirement already satisfied: click in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (8.1.7)\n",
"Requirement already satisfied: httpx<0.28.0,>=0.23.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.27.2)\n",
"Requirement already satisfied: importlib-metadata>=6.8.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (8.5.0)\n",
"Requirement already satisfied: jinja2<4.0.0,>=3.1.2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (3.1.4)\n",
"Requirement already satisfied: jsonschema<5.0.0,>=4.22.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (4.23.0)\n",
"Requirement already satisfied: openai>=1.55.3 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (1.57.1)\n",
"Requirement already satisfied: pydantic<3.0.0,>=2.0.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (2.10.3)\n",
"Requirement already satisfied: python-dotenv>=0.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (1.0.1)\n",
"Requirement already satisfied: requests<3.0.0,>=2.31.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (2.32.3)\n",
"Requirement already satisfied: tiktoken>=0.7.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.7.0)\n",
"Requirement already satisfied: tokenizers in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.21.0)\n",
"Requirement already satisfied: anyio in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (4.7.0)\n",
"Requirement already satisfied: certifi in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (2024.8.30)\n",
"Requirement already satisfied: httpcore==1.* in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (1.0.7)\n",
"Requirement already satisfied: idna in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (3.10)\n",
"Requirement already satisfied: sniffio in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (1.3.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.23.0->litellm) (0.14.0)\n",
"Requirement already satisfied: zipp>=3.20 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from importlib-metadata>=6.8.0->litellm) (3.21.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jinja2<4.0.0,>=3.1.2->litellm) (3.0.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (24.2.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (2024.10.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (0.35.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (0.22.3)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (1.9.0)\n",
"Requirement already satisfied: jiter<1,>=0.4.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (0.6.1)\n",
"Requirement already satisfied: tqdm>4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (4.67.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.11 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (4.12.2)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from pydantic<3.0.0,>=2.0.0->litellm) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.27.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from pydantic<3.0.0,>=2.0.0->litellm) (2.27.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.31.0->litellm) (3.4.0)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.31.0->litellm) (2.0.7)\n",
"Requirement already satisfied: regex>=2022.1.18 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from tiktoken>=0.7.0->litellm) (2024.11.6)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (2.4.4)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.3.1)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.5.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (6.1.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (0.2.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.18.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from tokenizers->litellm) (0.26.5)\n",
"Requirement already satisfied: filelock in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (3.16.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (2024.10.0)\n",
"Requirement already satisfied: packaging>=20.9 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (24.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (6.0.2)\n"
]
}
],
"source": [
"!pip install litellm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jHEu-TjZ29PJ"
},
"source": [
"## Set Env Variables"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "QWd9rTysxsWO"
},
"outputs": [],
"source": [
"import litellm\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"ARIZE_SPACE_KEY\"] = getpass(\"Enter your Arize space key: \")\n",
"os.environ[\"ARIZE_API_KEY\"] = getpass(\"Enter your Arize API key: \")\n",
"os.environ['OPENAI_API_KEY']= getpass(\"Enter your OpenAI API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run a completion call and see the traces in Arize"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello! Nice to meet you, OpenAI. How can I assist you today?\n"
]
}
],
"source": [
"# set arize as a callback, litellm will send the data to arize\n",
"litellm.callbacks = [\"arize\"]\n",
" \n",
"# openai call\n",
"response = litellm.completion(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Hi 👋 - i'm openai\"}\n",
" ]\n",
")\n",
"print(response.choices[0].message.content)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -29,6 +29,8 @@ services:
POSTGRES_DB: litellm POSTGRES_DB: litellm
POSTGRES_USER: llmproxy POSTGRES_USER: llmproxy
POSTGRES_PASSWORD: dbpassword9090 POSTGRES_PASSWORD: dbpassword9090
ports:
- "5432:5432"
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -d litellm -U llmproxy"] test: ["CMD-SHELL", "pg_isready -d litellm -U llmproxy"]
interval: 1s interval: 1s

View file

@ -1,5 +1,5 @@
# Local Debugging # Local Debugging
There's 2 ways to do local debugging - `litellm.set_verbose=True` and by passing in a custom function `completion(...logger_fn=<your_local_function>)`. Warning: Make sure to not use `set_verbose` in production. It logs API keys, which might end up in log files. There's 2 ways to do local debugging - `litellm._turn_on_debug()` and by passing in a custom function `completion(...logger_fn=<your_local_function>)`. Warning: Make sure to not use `_turn_on_debug()` in production. It logs API keys, which might end up in log files.
## Set Verbose ## Set Verbose
@ -8,7 +8,7 @@ This is good for getting print statements for everything litellm is doing.
import litellm import litellm
from litellm import completion from litellm import completion
litellm.set_verbose=True # 👈 this is the 1-line change you need to make litellm._turn_on_debug() # 👈 this is the 1-line change you need to make
## set ENV variables ## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key" os.environ["OPENAI_API_KEY"] = "openai key"

View file

@ -28,7 +28,7 @@ import litellm
import os import os
os.environ["ARIZE_SPACE_KEY"] = "" os.environ["ARIZE_SPACE_KEY"] = ""
os.environ["ARIZE_API_KEY"] = "" # defaults to litellm-completion os.environ["ARIZE_API_KEY"] = ""
# LLM API Keys # LLM API Keys
os.environ['OPENAI_API_KEY']="" os.environ['OPENAI_API_KEY']=""

View file

@ -12,6 +12,9 @@ Supports **ALL** Assembly AI Endpoints
[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference) [**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference)
<iframe width="840" height="500" src="https://www.loom.com/embed/aac3f4d74592448992254bfa79b9f62d?sid=267cd0ab-d92b-42fa-b97a-9f385ef8930c" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
## Quick Start ## Quick Start
Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts) Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts)
@ -35,6 +38,8 @@ litellm
Let's call the Assembly AI `/v2/transcripts` endpoint Let's call the Assembly AI `/v2/transcripts` endpoint
```python ```python
import assemblyai as aai
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key> LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai
@ -53,3 +58,28 @@ print(transcript)
print(transcript.id) print(transcript.id)
``` ```
## Calling Assembly AI EU endpoints
If you want to send your request to the Assembly AI EU endpoint, you can do so by setting the `LITELLM_PROXY_BASE_URL` to `<your-proxy-base-url>/eu.assemblyai`
```python
import assemblyai as aai
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/eu.assemblyai" # <your-proxy-base-url>/eu.assemblyai
aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}"
aai.settings.base_url = LITELLM_PROXY_BASE_URL
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
```

View file

@ -987,6 +987,106 @@ curl http://0.0.0.0:4000/v1/chat/completions \
</TabItem> </TabItem>
</Tabs> </Tabs>
## [BETA] Citations API
Pass `citations: {"enabled": true}` to Anthropic to get citations on your document responses.
Note: This interface is in BETA. If you have feedback on how citations should be returned, please [tell us here](https://github.com/BerriAI/litellm/issues/7970#issuecomment-2644437943)
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
resp = completion(
model="claude-3-5-sonnet-20241022",
messages=[
{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "text",
"media_type": "text/plain",
"data": "The grass is green. The sky is blue.",
},
"title": "My Document",
"context": "This is a trustworthy document.",
"citations": {"enabled": True},
},
{
"type": "text",
"text": "What color is the grass and sky?",
},
],
}
],
)
citations = resp.choices[0].message.provider_specific_fields["citations"]
assert citations is not None
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "anthropic-claude",
"messages": [
{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "text",
"media_type": "text/plain",
"data": "The grass is green. The sky is blue.",
},
"title": "My Document",
"context": "This is a trustworthy document.",
"citations": {"enabled": True},
},
{
"type": "text",
"text": "What color is the grass and sky?",
},
],
}
]
}'
```
</TabItem>
</Tabs>
## Usage - passing 'user_id' to Anthropic ## Usage - passing 'user_id' to Anthropic
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param. LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.

View file

@ -688,7 +688,9 @@ response = litellm.completion(
|-----------------------|--------------------------------------------------------|--------------------------------| |-----------------------|--------------------------------------------------------|--------------------------------|
| gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` | | gemini-2.0-flash | `completion(model='gemini/gemini-2.0-flash', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-exp | `completion(model='gemini/gemini-2.0-flash-exp', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-lite-preview-02-05 | `completion(model='gemini/gemini-2.0-flash-lite-preview-02-05', messages)` | `os.environ['GEMINI_API_KEY']` |

View file

@ -37,7 +37,7 @@ guardrails:
- guardrail_name: aim-protected-app - guardrail_name: aim-protected-app
litellm_params: litellm_params:
guardrail: aim guardrail: aim
mode: pre_call mode: pre_call # 'during_call' is also available
api_key: os.environ/AIM_API_KEY api_key: os.environ/AIM_API_KEY
api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
``` ```

View file

@ -166,7 +166,7 @@ response = client.chat.completions.create(
{"role": "user", "content": "what color is red"} {"role": "user", "content": "what color is red"}
], ],
logit_bias={12481: 100}, logit_bias={12481: 100},
timeout=1 extra_body={"timeout": 1} # 👈 KEY CHANGE
) )
print(response) print(response)

View file

@ -163,10 +163,12 @@ scope: "litellm-proxy-admin ..."
```yaml ```yaml
general_settings: general_settings:
master_key: sk-1234 enable_jwt_auth: True
litellm_jwtauth: litellm_jwtauth:
user_id_jwt_field: "sub" user_id_jwt_field: "sub"
team_ids_jwt_field: "groups" team_ids_jwt_field: "groups"
user_id_upsert: true # add user_id to the db if they don't exist
enforce_team_based_model_access: true # don't allow users to access models unless the team has access
``` ```
This is assuming your token looks like this: This is assuming your token looks like this:
@ -370,4 +372,68 @@ Supported internal roles:
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'. - `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token. - `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch) ### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
## [BETA] Control Model Access with Scopes
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
### 1. Setup config.yaml with scope mappings.
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: gpt-3.5-turbo-testing
litellm_params:
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true # 👈 enforce scope-based access control
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
```
#### Scope Mapping Spec
- `scope`: The scope to be used for the JWT token.
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
### 2. Create a JWT with the correct scopes.
Expected Token:
```
{
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"]
}
```
### 3. Test the flow.
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer eyJhbGci...' \
-d '{
"model": "gpt-3.5-turbo-testing",
"messages": [
{
"role": "user",
"content": "Hey, how'\''s it going 1234?"
}
]
}'
```

View file

@ -360,7 +360,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-2-90b-instruct-v1:0", "meta.llama3-2-90b-instruct-v1:0",
] ]
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[ BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21" "cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21", "nova"
] ]
####### COMPLETION MODELS ################### ####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = [] open_ai_chat_completion_models: List = []
@ -863,6 +863,9 @@ from .llms.bedrock.common_utils import (
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import ( from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
AmazonAI21Config, AmazonAI21Config,
) )
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
AmazonInvokeNovaConfig,
)
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import ( from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
AmazonAnthropicConfig, AmazonAnthropicConfig,
) )

View file

@ -0,0 +1,36 @@
"""
Base class for Additional Logging Utils for CustomLoggers
- Health Check for the logging util
- Get Request / Response Payload for the logging util
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
class AdditionalLoggingUtils(ABC):
def __init__(self):
super().__init__()
@abstractmethod
async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Check if the service is healthy
"""
pass
@abstractmethod
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetime],
end_time_utc: Optional[datetime],
) -> Optional[dict]:
"""
Get the request and response payload for a given `request_id`
"""
return None
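# Minimal concrete sketch (illustrative only, not part of this module): a logger
# that fills in the two abstract hooks above. The class name and the in-memory
# store are hypothetical.
class _InMemoryLoggingUtils(AdditionalLoggingUtils):
    def __init__(self):
        super().__init__()
        self._payloads: dict = {}  # request_id -> payload

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        # report healthy; a real integration would ping its backend here
        return IntegrationHealthCheckStatus(status="healthy", error_message=None)

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        # look up a previously stored payload by request id
        return self._payloads.get(request_id)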

View file

@ -1,19 +0,0 @@
"""
Base class for health check integrations
"""
from abc import ABC, abstractmethod
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
class HealthCheckIntegration(ABC):
def __init__(self):
super().__init__()
@abstractmethod
async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Check if the service is healthy
"""
pass

View file

@ -38,14 +38,14 @@ from litellm.types.integrations.datadog import *
from litellm.types.services import ServiceLoggerPayload from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from ..base_health_check import HealthCheckIntegration from ..additional_logging_utils import AdditionalLoggingUtils
DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept
class DataDogLogger( class DataDogLogger(
CustomBatchLogger, CustomBatchLogger,
HealthCheckIntegration, AdditionalLoggingUtils,
): ):
# Class variables or attributes # Class variables or attributes
def __init__( def __init__(
@ -543,3 +543,13 @@ class DataDogLogger(
status="unhealthy", status="unhealthy",
error_message=str(e), error_message=str(e),
) )
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetimeObj],
end_time_utc: Optional[datetimeObj],
) -> Optional[dict]:
raise NotImplementedError(
"Datdog Integration for getting request/response payloads not implemented as yet"
)

View file

@ -1,12 +1,16 @@
import asyncio import asyncio
import json
import os import os
import uuid import uuid
from datetime import datetime from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import * from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
@ -20,7 +24,7 @@ GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20 GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
class GCSBucketLogger(GCSBucketBase): class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
def __init__(self, bucket_name: Optional[str] = None) -> None: def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user from litellm.proxy.proxy_server import premium_user
@ -39,6 +43,7 @@ class GCSBucketLogger(GCSBucketBase):
batch_size=self.batch_size, batch_size=self.batch_size,
flush_interval=self.flush_interval, flush_interval=self.flush_interval,
) )
AdditionalLoggingUtils.__init__(self)
if premium_user is not True: if premium_user is not True:
raise ValueError( raise ValueError(
@ -150,11 +155,16 @@ class GCSBucketLogger(GCSBucketBase):
""" """
Get the object name to use for the current payload Get the object name to use for the current payload
""" """
current_date = datetime.now().strftime("%Y-%m-%d") current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
if logging_payload.get("error_str", None) is not None: if logging_payload.get("error_str", None) is not None:
object_name = f"{current_date}/failure-{uuid.uuid4().hex}" object_name = self._generate_failure_object_name(
request_date_str=current_date,
)
else: else:
object_name = f"{current_date}/{response_obj.get('id', '')}" object_name = self._generate_success_object_name(
request_date_str=current_date,
response_id=response_obj.get("id", ""),
)
# used for testing # used for testing
_litellm_params = kwargs.get("litellm_params", None) or {} _litellm_params = kwargs.get("litellm_params", None) or {}
@ -163,3 +173,65 @@ class GCSBucketLogger(GCSBucketBase):
object_name = _metadata["gcs_log_id"] object_name = _metadata["gcs_log_id"]
return object_name return object_name
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetime],
end_time_utc: Optional[datetime],
) -> Optional[dict]:
"""
Get the request and response payload for a given `request_id`
Tries current day, next day, and previous day until it finds the payload
"""
if start_time_utc is None:
raise ValueError(
"start_time_utc is required for getting a payload from GCS Bucket"
)
# Try current day, next day, and previous day
dates_to_try = [
start_time_utc,
start_time_utc + timedelta(days=1),
start_time_utc - timedelta(days=1),
]
date_str = None
for date in dates_to_try:
try:
date_str = self._get_object_date_from_datetime(datetime_obj=date)
object_name = self._generate_success_object_name(
request_date_str=date_str,
response_id=request_id,
)
encoded_object_name = quote(object_name, safe="")
response = await self.download_gcs_object(encoded_object_name)
if response is not None:
loaded_response = json.loads(response)
return loaded_response
except Exception as e:
verbose_logger.debug(
f"Failed to fetch payload for date {date_str}: {str(e)}"
)
continue
return None
def _generate_success_object_name(
self,
request_date_str: str,
response_id: str,
) -> str:
return f"{request_date_str}/{response_id}"
def _generate_failure_object_name(
self,
request_date_str: str,
) -> str:
return f"{request_date_str}/failure-{uuid.uuid4().hex}"
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
return datetime_obj.strftime("%Y-%m-%d")
async def async_health_check(self) -> IntegrationHealthCheckStatus:
raise NotImplementedError("GCS Bucket does not support health check")
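# Usage sketch (illustrative; the request id and timestamp below are hypothetical):
#   payload = await gcs_logger.get_request_response_payload(
#       request_id="chatcmpl-123",
#       start_time_utc=datetime.now(timezone.utc),
#       end_time_utc=None,
#   )
# The helper looks up "<YYYY-MM-DD>/<request_id>" in the bucket, retrying the
# next and previous day to cover requests logged near a UTC date boundary.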

View file

@ -3,7 +3,8 @@
import copy import copy
import os import os
import traceback import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from packaging.version import Version from packaging.version import Version
@ -13,9 +14,16 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
from litellm.llms.custom_httpx.http_handler import _get_httpx_client from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.secret_managers.main import str_to_bool from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.langfuse import * from litellm.types.integrations.langfuse import *
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.utils import ( from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
RerankResponse,
StandardLoggingPayload, StandardLoggingPayload,
StandardLoggingPromptManagementMetadata, StandardLoggingPromptManagementMetadata,
TextCompletionResponse,
TranscriptionResponse,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -150,19 +158,29 @@ class LangFuseLogger:
return metadata return metadata
def _old_log_event( # noqa: PLR0915 def log_event_on_langfuse(
self, self,
kwargs, kwargs: dict,
response_obj, response_obj: Union[
start_time, None,
end_time, dict,
user_id, EmbeddingResponse,
print_verbose, ModelResponse,
level="DEFAULT", TextCompletionResponse,
status_message=None, ImageResponse,
TranscriptionResponse,
RerankResponse,
HttpxBinaryResponseContent,
],
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
user_id: Optional[str] = None,
level: str = "DEFAULT",
status_message: Optional[str] = None,
) -> dict: ) -> dict:
# Method definition """
Logs a success or error event on Langfuse
"""
try: try:
verbose_logger.debug( verbose_logger.debug(
f"Langfuse Logging - Enters logging function for model {kwargs}" f"Langfuse Logging - Enters logging function for model {kwargs}"
@ -198,66 +216,13 @@ class LangFuseLogger:
# if casting value to str fails don't block logging # if casting value to str fails don't block logging
pass pass
# end of processing langfuse ######################## input, output = self._get_langfuse_input_output_content(
if ( kwargs=kwargs,
level == "ERROR" response_obj=response_obj,
and status_message is not None prompt=prompt,
and isinstance(status_message, str) level=level,
): status_message=status_message,
input = prompt )
output = status_message
elif response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.HttpxBinaryResponseContent
):
input = prompt
output = "speech-output"
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]
elif response_obj is not None and isinstance(
response_obj, litellm.TranscriptionResponse
):
input = prompt
output = response_obj["text"]
elif response_obj is not None and isinstance(
response_obj, litellm.RerankResponse
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "_arealtime"
and response_obj is not None
and isinstance(response_obj, list)
):
input = kwargs.get("input")
output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
and response_obj is not None
and isinstance(response_obj, dict)
):
input = prompt
output = response_obj.get("response", "")
verbose_logger.debug( verbose_logger.debug(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}" f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
) )
@ -265,31 +230,30 @@ class LangFuseLogger:
generation_id = None generation_id = None
if self._is_langfuse_v2(): if self._is_langfuse_v2():
trace_id, generation_id = self._log_langfuse_v2( trace_id, generation_id = self._log_langfuse_v2(
user_id, user_id=user_id,
metadata, metadata=metadata,
litellm_params, litellm_params=litellm_params,
output, output=output,
start_time, start_time=start_time,
end_time, end_time=end_time,
kwargs, kwargs=kwargs,
optional_params, optional_params=optional_params,
input, input=input,
response_obj, response_obj=response_obj,
level, level=level,
print_verbose, litellm_call_id=litellm_call_id,
litellm_call_id,
) )
elif response_obj is not None: elif response_obj is not None:
self._log_langfuse_v1( self._log_langfuse_v1(
user_id, user_id=user_id,
metadata, metadata=metadata,
output, output=output,
start_time, start_time=start_time,
end_time, end_time=end_time,
kwargs, kwargs=kwargs,
optional_params, optional_params=optional_params,
input, input=input,
response_obj, response_obj=response_obj,
) )
verbose_logger.debug( verbose_logger.debug(
f"Langfuse Layer Logging - final response object: {response_obj}" f"Langfuse Layer Logging - final response object: {response_obj}"
@ -303,11 +267,108 @@ class LangFuseLogger:
) )
return {"trace_id": None, "generation_id": None} return {"trace_id": None, "generation_id": None}
def _get_langfuse_input_output_content(
self,
kwargs: dict,
response_obj: Union[
None,
dict,
EmbeddingResponse,
ModelResponse,
TextCompletionResponse,
ImageResponse,
TranscriptionResponse,
RerankResponse,
HttpxBinaryResponseContent,
],
prompt: dict,
level: str,
status_message: Optional[str],
) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]:
"""
Get the input and output content for Langfuse logging
Args:
kwargs: The keyword arguments passed to the function
response_obj: The response object returned by the function
prompt: The prompt used to generate the response
level: The level of the log message
status_message: The status message of the log message
Returns:
input: The input content for Langfuse logging
output: The output content for Langfuse logging
"""
input = None
output: Optional[Union[str, dict, List[Any]]] = None
if (
level == "ERROR"
and status_message is not None
and isinstance(status_message, str)
):
input = prompt
output = status_message
elif response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = self._get_chat_content_for_langfuse(response_obj)
elif response_obj is not None and isinstance(
response_obj, litellm.HttpxBinaryResponseContent
):
input = prompt
output = "speech-output"
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = self._get_text_completion_content_for_langfuse(response_obj)
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj.get("data", None)
elif response_obj is not None and isinstance(
response_obj, litellm.TranscriptionResponse
):
input = prompt
output = response_obj.get("text", None)
elif response_obj is not None and isinstance(
response_obj, litellm.RerankResponse
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "_arealtime"
and response_obj is not None
and isinstance(response_obj, list)
):
input = kwargs.get("input")
output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
and response_obj is not None
and isinstance(response_obj, dict)
):
input = prompt
output = response_obj.get("response", "")
return input, output
async def _async_log_event( async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose self, kwargs, response_obj, start_time, end_time, user_id
): ):
""" """
TODO: support async calls when langfuse is truly async Langfuse SDK uses a background thread to log events
This approach does not impact latency and runs in the background
""" """
def _is_langfuse_v2(self): def _is_langfuse_v2(self):
@ -361,19 +422,18 @@ class LangFuseLogger:
def _log_langfuse_v2( # noqa: PLR0915 def _log_langfuse_v2( # noqa: PLR0915
self, self,
user_id, user_id: Optional[str],
metadata, metadata: dict,
litellm_params, litellm_params: dict,
output, output: Optional[Union[str, dict, list]],
start_time, start_time: Optional[datetime],
end_time, end_time: Optional[datetime],
kwargs, kwargs: dict,
optional_params, optional_params: dict,
input, input: Optional[dict],
response_obj, response_obj,
level, level: str,
print_verbose, litellm_call_id: Optional[str],
litellm_call_id,
) -> tuple: ) -> tuple:
verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2") verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
@ -657,6 +717,31 @@ class LangFuseLogger:
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None return None, None
@staticmethod
def _get_chat_content_for_langfuse(
response_obj: ModelResponse,
):
"""
Get the chat content for Langfuse logging
"""
if response_obj.choices and len(response_obj.choices) > 0:
output = response_obj["choices"][0]["message"].json()
return output
else:
return None
@staticmethod
def _get_text_completion_content_for_langfuse(
response_obj: TextCompletionResponse,
):
"""
Get the text completion content for Langfuse logging
"""
if response_obj.choices and len(response_obj.choices) > 0:
return response_obj.choices[0].text
else:
return None
@staticmethod @staticmethod
def _get_langfuse_tags( def _get_langfuse_tags(
standard_logging_object: Optional[StandardLoggingPayload], standard_logging_object: Optional[StandardLoggingPayload],

View file

@ -247,13 +247,12 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
standard_callback_dynamic_params=standard_callback_dynamic_params, standard_callback_dynamic_params=standard_callback_dynamic_params,
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
) )
langfuse_logger_to_use._old_log_event( langfuse_logger_to_use.log_event_on_langfuse(
kwargs=kwargs, kwargs=kwargs,
response_obj=response_obj, response_obj=response_obj,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
user_id=kwargs.get("user", None), user_id=kwargs.get("user", None),
print_verbose=None,
) )
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
@ -271,12 +270,11 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
) )
if standard_logging_object is None: if standard_logging_object is None:
return return
langfuse_logger_to_use._old_log_event( langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
response_obj=None, response_obj=None,
user_id=kwargs.get("user", None), user_id=kwargs.get("user", None),
print_verbose=None,
status_message=standard_logging_object["error_str"], status_message=standard_logging_object["error_str"],
level="ERROR", level="ERROR",
kwargs=kwargs, kwargs=kwargs,

View file

@ -118,6 +118,7 @@ class PagerDutyAlerting(SlackAlerting):
user_api_key_user_id=_meta.get("user_api_key_user_id"), user_api_key_user_id=_meta.get("user_api_key_user_id"),
user_api_key_team_alias=_meta.get("user_api_key_team_alias"), user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
user_api_key_user_email=_meta.get("user_api_key_user_email"),
) )
) )
@ -195,6 +196,7 @@ class PagerDutyAlerting(SlackAlerting):
user_api_key_user_id=user_api_key_dict.user_id, user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_alias=user_api_key_dict.team_alias, user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id, user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
) )
) )

View file

@ -423,6 +423,7 @@ class PrometheusLogger(CustomLogger):
team=user_api_team, team=user_api_team,
team_alias=user_api_team_alias, team_alias=user_api_team_alias,
user=user_id, user=user_id,
user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
status_code="200", status_code="200",
model=model, model=model,
litellm_model_name=model, litellm_model_name=model,
@ -806,6 +807,7 @@ class PrometheusLogger(CustomLogger):
enum_values = UserAPIKeyLabelValues( enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id, end_user=user_api_key_dict.end_user_id,
user=user_api_key_dict.user_id, user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
hashed_api_key=user_api_key_dict.api_key, hashed_api_key=user_api_key_dict.api_key,
api_key_alias=user_api_key_dict.key_alias, api_key_alias=user_api_key_dict.key_alias,
team=user_api_key_dict.team_id, team=user_api_key_dict.team_id,
@ -853,6 +855,7 @@ class PrometheusLogger(CustomLogger):
team=user_api_key_dict.team_id, team=user_api_key_dict.team_id,
team_alias=user_api_key_dict.team_alias, team_alias=user_api_key_dict.team_alias,
user=user_api_key_dict.user_id, user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
status_code="200", status_code="200",
) )
_labels = prometheus_label_factory( _labels = prometheus_label_factory(

View file

@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
dynamic_async_failure_callbacks: Optional[ dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]] List[Union[str, Callable, CustomLogger]]
] = None, ] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None, kwargs: Optional[Dict] = None,
): ):
_input: Optional[str] = messages # save original value of messages _input: Optional[str] = messages # save original value of messages
@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"input": _input, "input": _input,
"litellm_params": litellm_params, "litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
} }
def process_dynamic_callbacks(self): def process_dynamic_callbacks(self):
@ -1247,13 +1249,12 @@ class Logging(LiteLLMLoggingBaseClass):
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
) )
if langfuse_logger_to_use is not None: if langfuse_logger_to_use is not None:
_response = langfuse_logger_to_use._old_log_event( _response = langfuse_logger_to_use.log_event_on_langfuse(
kwargs=kwargs, kwargs=kwargs,
response_obj=result, response_obj=result,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
user_id=kwargs.get("user", None), user_id=kwargs.get("user", None),
print_verbose=print_verbose,
) )
if _response is not None and isinstance(_response, dict): if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None) _trace_id = _response.get("trace_id", None)
@ -1957,12 +1958,11 @@ class Logging(LiteLLMLoggingBaseClass):
standard_callback_dynamic_params=self.standard_callback_dynamic_params, standard_callback_dynamic_params=self.standard_callback_dynamic_params,
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
) )
_response = langfuse_logger_to_use._old_log_event( _response = langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
response_obj=None, response_obj=None,
user_id=kwargs.get("user", None), user_id=kwargs.get("user", None),
print_verbose=print_verbose,
status_message=str(exception), status_message=str(exception),
level="ERROR", level="ERROR",
kwargs=self.model_call_details, kwargs=self.model_call_details,
@ -2854,6 +2854,7 @@ class StandardLoggingPayloadSetup:
metadata: Optional[Dict[str, Any]], metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None, prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata: ) -> StandardLoggingMetadata:
""" """
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@ -2868,6 +2869,7 @@ class StandardLoggingPayloadSetup:
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
""" """
prompt_management_metadata: Optional[ prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata StandardLoggingPromptManagementMetadata
] = None ] = None
@ -2892,11 +2894,13 @@ class StandardLoggingPayloadSetup:
user_api_key_org_id=None, user_api_key_org_id=None,
user_api_key_user_id=None, user_api_key_user_id=None,
user_api_key_team_alias=None, user_api_key_team_alias=None,
user_api_key_user_email=None,
spend_logs_metadata=None, spend_logs_metadata=None,
requester_ip_address=None, requester_ip_address=None,
requester_metadata=None, requester_metadata=None,
user_api_key_end_user_id=None, user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata, prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
) )
if isinstance(metadata, dict): if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys # Filter the metadata dictionary to include only the specified keys
@ -3195,6 +3199,7 @@ def get_standard_logging_object_payload(
metadata=metadata, metadata=metadata,
litellm_params=litellm_params, litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None), prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
) )
_request_body = proxy_server_request.get("body", {}) _request_body = proxy_server_request.get("body", {})
@ -3324,12 +3329,14 @@ def get_standard_logging_metadata(
user_api_key_team_id=None, user_api_key_team_id=None,
user_api_key_org_id=None, user_api_key_org_id=None,
user_api_key_user_id=None, user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None, user_api_key_team_alias=None,
spend_logs_metadata=None, spend_logs_metadata=None,
requester_ip_address=None, requester_ip_address=None,
requester_metadata=None, requester_metadata=None,
user_api_key_end_user_id=None, user_api_key_end_user_id=None,
prompt_management_metadata=None, prompt_management_metadata=None,
applied_guardrails=None,
) )
if isinstance(metadata, dict): if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys # Filter the metadata dictionary to include only the specified keys

View file

@ -3,7 +3,8 @@ import json
import time import time
import traceback import traceback
import uuid import uuid
from typing import Dict, Iterable, List, Literal, Optional, Union import re
from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
# if there is a JSONDecodeError, return the original tool_calls # if there is a JSONDecodeError, return the original tool_calls
return tool_calls return tool_calls
def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
if not message_text:
return None, None
reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
if reasoning_match:
return reasoning_match.group(1), reasoning_match.group(2)
return None, message_text
class LiteLLMResponseObjectHandler: class LiteLLMResponseObjectHandler:
@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
for field in choice["message"].keys(): for field in choice["message"].keys():
if field not in message_keys: if field not in message_keys:
provider_specific_fields[field] = choice["message"][field] provider_specific_fields[field] = choice["message"][field]
# Handle reasoning models that display `reasoning_content` within `content`
reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
if reasoning_content:
provider_specific_fields["reasoning_content"] = reasoning_content
message = Message( message = Message(
content=choice["message"].get("content", None), content=content,
role=choice["message"]["role"] or "assistant", role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None), function_call=choice["message"].get("function_call", None),
tool_calls=tool_calls, tool_calls=tool_calls,
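A quick, self-contained illustration of what `_parse_content_for_reasoning` does with a `<think>`-tagged completion (the helper is mirrored here verbatim so the snippet runs on its own):
import re
from typing import Optional, Tuple

def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    # mirrors the helper above: split "<think>...</think>rest" into (reasoning, content)
    if not message_text:
        return None, None
    reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
    if reasoning_match:
        return reasoning_match.group(1), reasoning_match.group(2)
    return None, message_text

reasoning, content = _parse_content_for_reasoning("<think>add 2 and 2</think>The answer is 4.")
assert reasoning == "add 2 and 2"
assert content == "The answer is 4."
assert _parse_content_for_reasoning("plain text") == (None, "plain text")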

View file

@ -1,7 +1,8 @@
from typing import Callable, List, Union from typing import Callable, List, Set, Union
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -85,6 +86,21 @@ class LoggingCallbackManager:
callback=callback, parent_list=litellm._async_failure_callback callback=callback, parent_list=litellm._async_failure_callback
) )
def remove_callback_from_list_by_object(
self, callback_list, obj
):
"""
Remove callbacks that are methods of a particular object (e.g., router cleanup)
"""
if not isinstance(callback_list, list): # Not list -> do nothing
return
remove_list = [c for c in callback_list if hasattr(c, "__self__") and c.__self__ == obj]
for c in remove_list:
callback_list.remove(c)
def _add_string_callback_to_list( def _add_string_callback_to_list(
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]] self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
): ):
@ -205,3 +221,36 @@ class LoggingCallbackManager:
litellm._async_success_callback = [] litellm._async_success_callback = []
litellm._async_failure_callback = [] litellm._async_failure_callback = []
litellm.callbacks = [] litellm.callbacks = []
def _get_all_callbacks(self) -> List[Union[CustomLogger, Callable, str]]:
"""
Get all callbacks from litellm.callbacks, litellm.success_callback, litellm.failure_callback, litellm._async_success_callback, litellm._async_failure_callback
"""
return (
litellm.callbacks
+ litellm.success_callback
+ litellm.failure_callback
+ litellm._async_success_callback
+ litellm._async_failure_callback
)
def get_active_additional_logging_utils_from_custom_logger(
self,
) -> Set[AdditionalLoggingUtils]:
"""
Get all custom loggers that are instances of the given class type
Args:
class_type: The class type to match against (e.g., AdditionalLoggingUtils)
Returns:
Set[CustomLogger]: Set of custom loggers that are instances of the given class type
"""
all_callbacks = self._get_all_callbacks()
matched_callbacks: Set[AdditionalLoggingUtils] = set()
for callback in all_callbacks:
if isinstance(callback, CustomLogger) and isinstance(
callback, AdditionalLoggingUtils
):
matched_callbacks.add(callback)
return matched_callbacks
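A small sketch of the filtering rule used by `remove_callback_from_list_by_object`, applied to a stand-in list rather than the real litellm callback lists (the router class is hypothetical):
class FakeRouter:
    def success_handler(self, *args, **kwargs):  # bound method registered as a callback
        pass

router = FakeRouter()
callback_list = [router.success_handler, "langfuse", print]

# same filtering rule as remove_callback_from_list_by_object above
remove_list = [c for c in callback_list if hasattr(c, "__self__") and c.__self__ == router]
for c in remove_list:
    callback_list.remove(c)

assert callback_list == ["langfuse", print]  # only the router's bound method was removed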

View file

@ -1421,6 +1421,8 @@ def anthropic_messages_pt( # noqa: PLR0915
) )
user_content.append(_content_element) user_content.append(_content_element)
elif m.get("type", "") == "document":
user_content.append(cast(AnthropicMessagesDocumentParam, m))
elif isinstance(user_message_types_block["content"], str): elif isinstance(user_message_types_block["content"], str):
_anthropic_content_text_element: AnthropicMessagesTextParam = { _anthropic_content_text_element: AnthropicMessagesTextParam = {
"type": "text", "type": "text",

View file

@ -0,0 +1,81 @@
from typing import Any, Dict, Optional, Set
class SensitiveDataMasker:
def __init__(
self,
sensitive_patterns: Optional[Set[str]] = None,
visible_prefix: int = 4,
visible_suffix: int = 4,
mask_char: str = "*",
):
self.sensitive_patterns = sensitive_patterns or {
"password",
"secret",
"key",
"token",
"auth",
"credential",
"access",
"private",
"certificate",
}
self.visible_prefix = visible_prefix
self.visible_suffix = visible_suffix
self.mask_char = mask_char
def _mask_value(self, value: str) -> str:
if not value or len(str(value)) < (self.visible_prefix + self.visible_suffix):
return value
value_str = str(value)
masked_length = len(value_str) - (self.visible_prefix + self.visible_suffix)
return f"{value_str[:self.visible_prefix]}{self.mask_char * masked_length}{value_str[-self.visible_suffix:]}"
def is_sensitive_key(self, key: str) -> bool:
key_lower = str(key).lower()
result = any(pattern in key_lower for pattern in self.sensitive_patterns)
return result
def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> Dict[str, Any]:
if depth >= max_depth:
return data
masked_data: Dict[str, Any] = {}
for k, v in data.items():
try:
if isinstance(v, dict):
masked_data[k] = self.mask_dict(v, depth + 1, max_depth)
elif hasattr(v, "__dict__") and not isinstance(v, type):
masked_data[k] = self.mask_dict(vars(v), depth + 1, max_depth)
elif self.is_sensitive_key(k):
str_value = str(v) if v is not None else ""
masked_data[k] = self._mask_value(str_value)
else:
masked_data[k] = (
v if isinstance(v, (int, float, bool, str)) else str(v)
)
except Exception:
masked_data[k] = "<unable to serialize>"
return masked_data
# Usage example:
"""
masker = SensitiveDataMasker()
data = {
"api_key": "sk-1234567890abcdef",
"redis_password": "very_secret_pass",
"port": 6379
}
masked = masker.mask_dict(data)
# Result: {
# "api_key": "sk-1****cdef",
# "redis_password": "very****pass",
# "port": 6379
# }
"""

View file

@ -809,7 +809,10 @@ class CustomStreamWrapper:
if self.sent_first_chunk is False: if self.sent_first_chunk is False:
completion_obj["role"] = "assistant" completion_obj["role"] = "assistant"
self.sent_first_chunk = True self.sent_first_chunk = True
if response_obj.get("provider_specific_fields") is not None:
completion_obj["provider_specific_fields"] = response_obj[
"provider_specific_fields"
]
model_response.choices[0].delta = Delta(**completion_obj) model_response.choices[0].delta = Delta(**completion_obj)
_index: Optional[int] = completion_obj.get("index") _index: Optional[int] = completion_obj.get("index")
if _index is not None: if _index is not None:
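A hedged sketch of reading the newly forwarded `provider_specific_fields` off streaming deltas; the model name and field-access pattern are illustrative, and the attribute may be absent on chunks that carry no provider-specific data:
import litellm

# Illustrative only: any Anthropic chat model that emits citation/thinking fields would do.
stream = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    provider_fields = getattr(delta, "provider_specific_fields", None)
    if provider_fields:
        print(provider_fields)  # e.g. {"citation": {...}} when the provider returns one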

View file

@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
import copy import copy
import json import json
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import httpx # type: ignore import httpx # type: ignore
@ -506,6 +506,29 @@ class ModelResponseIterator:
return usage_block return usage_block
def _content_block_delta_helper(self, chunk: dict):
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
provider_specific_fields = {}
content_block = ContentBlockDelta(**chunk) # type: ignore
self.content_blocks.append(content_block)
if "text" in content_block["delta"]:
text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": content_block["delta"]["partial_json"],
},
"index": self.tool_index,
}
elif "citation" in content_block["delta"]:
provider_specific_fields["citation"] = content_block["delta"]["citation"]
return text, tool_use, provider_specific_fields
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try: try:
type_chunk = chunk.get("type", "") or "" type_chunk = chunk.get("type", "") or ""
@ -515,6 +538,7 @@ class ModelResponseIterator:
is_finished = False is_finished = False
finish_reason = "" finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields: Dict[str, Any] = {}
index = int(chunk.get("index", 0)) index = int(chunk.get("index", 0))
if type_chunk == "content_block_delta": if type_chunk == "content_block_delta":
@ -522,20 +546,9 @@ class ModelResponseIterator:
Anthropic content chunk Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
""" """
content_block = ContentBlockDelta(**chunk) # type: ignore text, tool_use, provider_specific_fields = (
self.content_blocks.append(content_block) self._content_block_delta_helper(chunk=chunk)
if "text" in content_block["delta"]: )
text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": content_block["delta"]["partial_json"],
},
"index": self.tool_index,
}
elif type_chunk == "content_block_start": elif type_chunk == "content_block_start":
""" """
event: content_block_start event: content_block_start
@ -628,6 +641,9 @@ class ModelResponseIterator:
finish_reason=finish_reason, finish_reason=finish_reason,
usage=usage, usage=usage,
index=index, index=index,
provider_specific_fields=(
provider_specific_fields if provider_specific_fields else None
),
) )
return returned_chunk return returned_chunk

View file

@ -70,7 +70,7 @@ class AnthropicConfig(BaseConfig):
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
system: Optional[str] = None, system: Optional[str] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@ -628,6 +628,7 @@ class AnthropicConfig(BaseConfig):
) )
else: else:
text_content = "" text_content = ""
citations: List[Any] = []
tool_calls: List[ChatCompletionToolCallChunk] = [] tool_calls: List[ChatCompletionToolCallChunk] = []
for idx, content in enumerate(completion_response["content"]): for idx, content in enumerate(completion_response["content"]):
if content["type"] == "text": if content["type"] == "text":
@ -645,10 +646,14 @@ class AnthropicConfig(BaseConfig):
index=idx, index=idx,
) )
) )
## CITATIONS
if content.get("citations", None) is not None:
citations.append(content["citations"])
_message = litellm.Message( _message = litellm.Message(
tool_calls=tool_calls, tool_calls=tool_calls,
content=text_content or None, content=text_content or None,
provider_specific_fields={"citations": citations},
) )
## HANDLE JSON MODE - anthropic returns single function call ## HANDLE JSON MODE - anthropic returns single function call
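Correspondingly, on a non-streaming response the collected citations ride along on the message's `provider_specific_fields` (a sketch; whether the list is populated depends on the request actually enabling Anthropic's citations feature):
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[{"role": "user", "content": "Answer with citations."}],
)
message = response.choices[0].message
citations = (message.provider_specific_fields or {}).get("citations", [])
print(citations)  # one entry per text content block that carried citations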

View file

@ -72,7 +72,7 @@ class AnthropicTextConfig(BaseConfig):
top_k: Optional[int] = None, top_k: Optional[int] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -5,10 +5,11 @@ import time
from typing import Any, Callable, Dict, List, Literal, Optional, Union from typing import Any, Callable, Dict, List, Literal, Optional, Union
import httpx # type: ignore import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
import litellm import litellm
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
@ -98,14 +99,6 @@ class AzureOpenAIAssistantsAPIConfig:
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
# "azure_endpoint": api_base,
# "azure_deployment": model,
# "http_client": litellm.client_session,
# "max_retries": max_retries,
# "timeout": timeout,
# }
azure_endpoint = azure_client_params.get("azure_endpoint", None) azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None: if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192 # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
@ -312,6 +305,7 @@ class AzureChatCompletion(BaseLLM):
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default - call chat.completions.create by default
""" """
start_time = time.time()
try: try:
raw_response = await azure_client.chat.completions.with_raw_response.create( raw_response = await azure_client.chat.completions.with_raw_response.create(
**data, timeout=timeout **data, timeout=timeout
@ -320,6 +314,11 @@ class AzureChatCompletion(BaseLLM):
headers = dict(raw_response.headers) headers = dict(raw_response.headers)
response = raw_response.parse() response = raw_response.parse()
return headers, response return headers, response
except APITimeoutError as e:
end_time = time.time()
time_delta = round(end_time - start_time, 2)
e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
raise e
except Exception as e: except Exception as e:
raise e raise e
@ -353,7 +352,9 @@ class AzureChatCompletion(BaseLLM):
status_code=422, message="Missing model or messages" status_code=422, message="Missing model or messages"
) )
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", None)
if max_retries is None:
max_retries = DEFAULT_MAX_RETRIES
json_mode: Optional[bool] = optional_params.pop("json_mode", False) json_mode: Optional[bool] = optional_params.pop("json_mode", False)
### CHECK IF CLOUDFLARE AI GATEWAY ### ### CHECK IF CLOUDFLARE AI GATEWAY ###
@ -415,6 +416,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout, timeout=timeout,
client=client, client=client,
max_retries=max_retries,
) )
else: else:
return self.acompletion( return self.acompletion(
@ -430,6 +432,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout, timeout=timeout,
client=client, client=client,
logging_obj=logging_obj, logging_obj=logging_obj,
max_retries=max_retries,
convert_tool_call_to_json_mode=json_mode, convert_tool_call_to_json_mode=json_mode,
) )
elif "stream" in optional_params and optional_params["stream"] is True: elif "stream" in optional_params and optional_params["stream"] is True:
@ -445,6 +448,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout, timeout=timeout,
client=client, client=client,
max_retries=max_retries,
) )
else: else:
## LOGGING ## LOGGING
@ -553,6 +557,7 @@ class AzureChatCompletion(BaseLLM):
dynamic_params: bool, dynamic_params: bool,
model_response: ModelResponse, model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
max_retries: int,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None, azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None, convert_tool_call_to_json_mode: Optional[bool] = None,
@ -560,12 +565,6 @@ class AzureChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@ -649,6 +648,7 @@ class AzureChatCompletion(BaseLLM):
) )
raise AzureOpenAIError(status_code=500, message=str(e)) raise AzureOpenAIError(status_code=500, message=str(e))
except Exception as e: except Exception as e:
message = getattr(e, "message", str(e))
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=data["messages"], input=data["messages"],
@ -659,7 +659,7 @@ class AzureChatCompletion(BaseLLM):
if hasattr(e, "status_code"): if hasattr(e, "status_code"):
raise e raise e
else: else:
raise AzureOpenAIError(status_code=500, message=str(e)) raise AzureOpenAIError(status_code=500, message=message)
def streaming( def streaming(
self, self,
@ -671,15 +671,11 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
max_retries: int,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None, azure_ad_token_provider: Optional[Callable] = None,
client=None, client=None,
): ):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@ -742,6 +738,7 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
max_retries: int,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None, azure_ad_token_provider: Optional[Callable] = None,
client=None, client=None,
@ -753,7 +750,7 @@ class AzureChatCompletion(BaseLLM):
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.aclient_session, "http_client": litellm.aclient_session,
"max_retries": data.pop("max_retries", 2), "max_retries": max_retries,
"timeout": timeout, "timeout": timeout,
} }
azure_client_params = select_azure_base_url_or_endpoint( azure_client_params = select_azure_base_url_or_endpoint(
@ -807,10 +804,11 @@ class AzureChatCompletion(BaseLLM):
status_code = getattr(e, "status_code", 500) status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None) error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None) error_response = getattr(e, "response", None)
message = getattr(e, "message", str(e))
if error_headers is None and error_response: if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None) error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError( raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers status_code=status_code, message=message, headers=error_headers
) )
async def aembedding( async def aembedding(

View file

@ -113,6 +113,17 @@ class AzureOpenAIConfig(BaseConfig):
return False return False
def _is_response_format_supported_api_version(
self, api_version_year: str, api_version_month: str
) -> bool:
"""
- check if api_version is supported for response_format
"""
is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8
return is_supported
def map_openai_params( def map_openai_params(
self, self,
non_default_params: dict, non_default_params: dict,
@ -171,13 +182,20 @@ class AzureOpenAIConfig(BaseConfig):
_is_response_format_supported_model = ( _is_response_format_supported_model = (
self._is_response_format_supported_model(model) self._is_response_format_supported_model(model)
) )
should_convert_response_format_to_tool = (
api_version_year <= "2024" and api_version_month < "08" is_response_format_supported_api_version = (
) or not _is_response_format_supported_model self._is_response_format_supported_api_version(
api_version_year, api_version_month
)
)
is_response_format_supported = (
is_response_format_supported_api_version
and _is_response_format_supported_model
)
optional_params = self._add_response_format_to_tools( optional_params = self._add_response_format_to_tools(
optional_params=optional_params, optional_params=optional_params,
value=value, value=value,
should_convert_response_format_to_tool=should_convert_response_format_to_tool, is_response_format_supported=is_response_format_supported,
) )
elif param == "tools" and isinstance(value, list): elif param == "tools" and isinstance(value, list):
optional_params.setdefault("tools", []) optional_params.setdefault("tools", [])
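To make the new version gate concrete, a standalone mirror of `_is_response_format_supported_api_version` with two sample evaluations (behaviour follows the check shown above):
def _is_response_format_supported_api_version(api_version_year: str, api_version_month: str) -> bool:
    # same check as the method above
    return int(api_version_year) <= 2024 and int(api_version_month) >= 8

assert _is_response_format_supported_api_version("2024", "08") is True   # supported natively
assert _is_response_format_supported_api_version("2024", "06") is False  # falls back to tool-call conversion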

View file

@ -131,6 +131,7 @@ class AzureTextCompletion(BaseLLM):
timeout=timeout, timeout=timeout,
client=client, client=client,
logging_obj=logging_obj, logging_obj=logging_obj,
max_retries=max_retries,
) )
elif "stream" in optional_params and optional_params["stream"] is True: elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming( return self.streaming(
@ -236,17 +237,12 @@ class AzureTextCompletion(BaseLLM):
timeout: Any, timeout: Any,
model_response: ModelResponse, model_response: ModelResponse,
logging_obj: Any, logging_obj: Any,
max_retries: int,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI client=None, # this is the AsyncAzureOpenAI
): ):
response = None response = None
try: try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,

View file

@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
def get_api_base(api_base: Optional[str] = None) -> Optional[str]: def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
pass pass
@staticmethod
@abstractmethod
def get_base_model(model: str) -> Optional[str]:
"""
Returns the base model name from the given model name.
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
This function will return `anthropic.claude-3-opus-20240229-v1:0`
"""
pass
def _dict_to_response_format_helper( def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None response_format: dict, ref_template: Optional[str] = None
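A hedged sketch of the contract `get_base_model` implementations are expected to satisfy, using the Bedrock examples from the docstring; this is an illustrative reimplementation only, the concrete BedrockModelInfo version lives elsewhere in this PR:
def get_base_model(model: str) -> str:
    # strip routing prefixes like "invoke/" or "converse/" and keep the provider model id
    for prefix in ("bedrock/", "invoke/", "converse/"):
        if model.startswith(prefix):
            model = model[len(prefix):]
    return model

assert get_base_model("invoke/anthropic.claude-3-opus-20240229-v1:0") == "anthropic.claude-3-opus-20240229-v1:0"
assert get_base_model("converse/anthropic.claude-3-opus-20240229-v1:0") == "anthropic.claude-3-opus-20240229-v1:0"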

View file

@ -20,6 +20,7 @@ from pydantic import BaseModel
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
AllMessageValues, AllMessageValues,
ChatCompletionToolChoiceFunctionParam, ChatCompletionToolChoiceFunctionParam,
@ -27,9 +28,6 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam, ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk, ChatCompletionToolParamFunctionChunk,
) )
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import ModelResponse from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper from litellm.utils import CustomStreamWrapper
@ -163,7 +161,7 @@ class BaseConfig(ABC):
self, self,
optional_params: dict, optional_params: dict,
value: dict, value: dict,
should_convert_response_format_to_tool: bool, is_response_format_supported: bool,
) -> dict: ) -> dict:
""" """
Follow similar approach to anthropic - translate to a single tool call. Follow similar approach to anthropic - translate to a single tool call.
@ -183,7 +181,8 @@ class BaseConfig(ABC):
elif "json_schema" in value: elif "json_schema" in value:
json_schema = value["json_schema"]["schema"] json_schema = value["json_schema"]["schema"]
if json_schema and should_convert_response_format_to_tool: if json_schema and not is_response_format_supported:
_tool_choice = ChatCompletionToolChoiceObjectParam( _tool_choice = ChatCompletionToolChoiceObjectParam(
type="function", type="function",
function=ChatCompletionToolChoiceFunctionParam( function=ChatCompletionToolChoiceFunctionParam(

View file

@ -52,6 +52,7 @@ class BaseAWSLLM:
"aws_role_name", "aws_role_name",
"aws_web_identity_token", "aws_web_identity_token",
"aws_sts_endpoint", "aws_sts_endpoint",
"aws_bedrock_runtime_endpoint",
] ]
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:

View file

@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
from litellm.types.utils import ModelResponse, Usage from litellm.types.utils import ModelResponse, Usage
from litellm.utils import add_dummy_tool, has_tool_call_blocks from litellm.utils import add_dummy_tool, has_tool_call_blocks
from ..common_utils import ( from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
AmazonBedrockGlobalConfig,
BedrockError,
get_bedrock_tool_name,
)
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
class AmazonConverseConfig(BaseConfig): class AmazonConverseConfig(BaseConfig):
@ -63,7 +56,7 @@ class AmazonConverseConfig(BaseConfig):
topP: Optional[int] = None, topP: Optional[int] = None,
topK: Optional[int] = None, topK: Optional[int] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
] ]
## Filter out 'cross-region' from model name ## Filter out 'cross-region' from model name
base_model = self._get_base_model(model) base_model = BedrockModelInfo.get_base_model(model)
if ( if (
base_model.startswith("anthropic") base_model.startswith("anthropic")
@ -341,9 +334,9 @@ class AmazonConverseConfig(BaseConfig):
if "top_k" in inference_params: if "top_k" in inference_params:
inference_params["topK"] = inference_params.pop("top_k") inference_params["topK"] = inference_params.pop("top_k")
return InferenceConfig(**inference_params) return InferenceConfig(**inference_params)
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict: def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
base_model = self._get_base_model(model) base_model = BedrockModelInfo.get_base_model(model)
val_top_k = None val_top_k = None
if "topK" in inference_params: if "topK" in inference_params:
@ -352,11 +345,11 @@ class AmazonConverseConfig(BaseConfig):
val_top_k = inference_params.pop("top_k") val_top_k = inference_params.pop("top_k")
if val_top_k: if val_top_k:
if (base_model.startswith("anthropic")): if base_model.startswith("anthropic"):
return {"top_k": val_top_k} return {"top_k": val_top_k}
if base_model.startswith("amazon.nova"): if base_model.startswith("amazon.nova"):
return {'inferenceConfig': {"topK": val_top_k}} return {"inferenceConfig": {"topK": val_top_k}}
return {} return {}
def _transform_request_helper( def _transform_request_helper(
@ -393,15 +386,25 @@ class AmazonConverseConfig(BaseConfig):
) + ["top_k"] ) + ["top_k"]
supported_tool_call_params = ["tools", "tool_choice"] supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"] supported_guardrail_params = ["guardrailConfig"]
total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params total_supported_params = (
supported_converse_params
+ supported_tool_call_params
+ supported_guardrail_params
)
inference_params.pop("json_mode", None) # used for handling json_schema inference_params.pop("json_mode", None) # used for handling json_schema
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params' # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params} additional_request_params = {
inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params} k: v for k, v in inference_params.items() if k not in total_supported_params
}
inference_params = {
k: v for k, v in inference_params.items() if k in total_supported_params
}
# Only set the topK value for models that support it # Only set the topK value for models that support it
additional_request_params.update(self._handle_top_k_value(model, inference_params)) additional_request_params.update(
self._handle_top_k_value(model, inference_params)
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", []) inference_params.pop("tools", [])
@ -679,41 +682,6 @@ class AmazonConverseConfig(BaseConfig):
return model_response return model_response
def _supported_cross_region_inference_region(self) -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
def _get_base_model(self, model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]
elif (
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
def get_error_class( def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException: ) -> BaseLLMException:
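As a quick reference for the `top_k` handling above, a standalone mirror of the placement logic after `BedrockModelInfo.get_base_model()` has stripped any cross-region or route prefix; the model IDs are examples and the real method pulls the value out of `inference_params` first:
def place_top_k(base_model: str, val_top_k: int) -> dict:
    # mirrors _handle_top_k_value above
    if base_model.startswith("anthropic"):
        return {"top_k": val_top_k}
    if base_model.startswith("amazon.nova"):
        return {"inferenceConfig": {"topK": val_top_k}}
    return {}

assert place_top_k("anthropic.claude-3-5-sonnet-20240620-v1:0", 10) == {"top_k": 10}
assert place_top_k("amazon.nova-pro-v1:0", 10) == {"inferenceConfig": {"topK": 10}}
assert place_top_k("meta.llama3-1-70b-instruct-v1:0", 10) == {}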

View file

@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
parse_xml_params, parse_xml_params,
prompt_factory, prompt_factory,
) )
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
HTTPHandler, HTTPHandler,
@ -103,7 +106,7 @@ class AmazonCohereChatConfig:
stop_sequences: Optional[str] = None, stop_sequences: Optional[str] = None,
raw_prompting: Optional[bool] = None, raw_prompting: Optional[bool] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@ -177,6 +180,7 @@ async def make_call(
logging_obj: Logging, logging_obj: Logging,
fake_stream: bool = False, fake_stream: bool = False,
json_mode: Optional[bool] = False, json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
): ):
try: try:
if client is None: if client is None:
@ -214,6 +218,14 @@ async def make_call(
completion_stream: Any = MockResponseIterator( completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode model_response=model_response, json_mode=json_mode
) )
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=False,
)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else: else:
decoder = AWSEventStreamDecoder(model=model) decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes( completion_stream = decoder.aiter_bytes(
@ -248,6 +260,7 @@ def make_sync_call(
logging_obj: Logging, logging_obj: Logging,
fake_stream: bool = False, fake_stream: bool = False,
json_mode: Optional[bool] = False, json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
): ):
try: try:
if client is None: if client is None:
@ -283,6 +296,12 @@ def make_sync_call(
completion_stream: Any = MockResponseIterator( completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode model_response=model_response, json_mode=json_mode
) )
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=True,
)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
else: else:
decoder = AWSEventStreamDecoder(model=model) decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True is_finished = True
finish_reason = "stop" finish_reason = "stop"
######## bedrock.anthropic mappings ############### ######## /bedrock/converse mappings ###############
elif ( elif (
"contentBlockIndex" in chunk_data "contentBlockIndex" in chunk_data
or "stopReason" in chunk_data or "stopReason" in chunk_data
@ -1331,6 +1350,11 @@ class AWSEventStreamDecoder:
or "trace" in chunk_data or "trace" in chunk_data
): ):
return self.converse_chunk_parser(chunk_data=chunk_data) return self.converse_chunk_parser(chunk_data=chunk_data)
######### /bedrock/invoke nova mappings ###############
elif "contentBlockDelta" in chunk_data:
# when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
_chunk_data = chunk_data.get("contentBlockDelta", None)
return self.converse_chunk_parser(chunk_data=_chunk_data)
######## bedrock.mistral mappings ############### ######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data: elif "outputs" in chunk_data:
if ( if (
@ -1429,6 +1453,27 @@ class AWSEventStreamDecoder:
return chunk.decode() # type: ignore[no-any-return] return chunk.decode() # type: ignore[no-any-return]
class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
def __init__(
self,
model: str,
sync_stream: bool,
) -> None:
"""
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `_chunk_parser` method
"""
super().__init__(model=model)
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=sync_stream,
)
def _chunk_parser(self, chunk_data: dict) -> GChunk:
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
class MockResponseIterator: # for returning ai21 streaming responses class MockResponseIterator: # for returning ai21 streaming responses
def __init__(self, model_response, json_mode: Optional[bool] = False): def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response self.model_response = model_response

View file

@ -46,7 +46,7 @@ class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
presencePenalty: Optional[dict] = None, presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None, countPenalty: Optional[dict] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -28,7 +28,7 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None, temperature: Optional[float] = None,
return_likelihood: Optional[str] = None, return_likelihood: Optional[str] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -28,7 +28,7 @@ class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None, temperature: Optional[float] = None,
topP: Optional[int] = None, topP: Optional[int] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -33,7 +33,7 @@ class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
top_k: Optional[float] = None, top_k: Optional[float] = None,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -0,0 +1,70 @@
"""
Handles transforming requests for `bedrock/invoke/{nova}` models
Inherits from `AmazonConverseConfig`
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""
from typing import List
import litellm
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
"""
Config for sending `nova` requests to `/bedrock/invoke/`
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_transformed_nova_request = super().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
_bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
**_transformed_nova_request
)
self._remove_empty_system_messages(_bedrock_invoke_nova_request)
bedrock_invoke_nova_request = self._filter_allowed_fields(
_bedrock_invoke_nova_request
)
return bedrock_invoke_nova_request
def _filter_allowed_fields(
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
) -> dict:
"""
Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` request type.
"""
allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
return {
k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
}
def _remove_empty_system_messages(
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
) -> None:
"""
In-place remove empty `system` messages from the request.
/bedrock/invoke/ does not allow empty `system` messages.
"""
_system_message = bedrock_invoke_nova_request.get("system", None)
if isinstance(_system_message, list) and len(_system_message) == 0:
bedrock_invoke_nova_request.pop("system", None)
return
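A hedged usage sketch for the new config: routing a Nova model through the invoke API. The model string format is an assumption based on the `invoke/` handling added in this PR, and AWS credentials/region are expected to come from the environment:
import litellm

# Assumed model naming: "bedrock/invoke/<nova model id>" routes through AmazonInvokeConfig,
# which now delegates request/response transformation to AmazonInvokeNovaConfig.
response = litellm.completion(
    model="bedrock/invoke/us.amazon.nova-pro-v1:0",
    messages=[{"role": "user", "content": "Hello from the invoke API"}],
)
print(response.choices[0].message.content)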

View file

@ -33,7 +33,7 @@ class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None, temperature: Optional[float] = None,
topP: Optional[int] = None, topP: Optional[int] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -34,7 +34,7 @@ class AmazonAnthropicConfig:
top_p: Optional[int] = None, top_p: Optional[int] = None,
anthropic_version: Optional[str] = None, anthropic_version: Optional[str] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals().copy()
for key, value in locals_.items(): for key, value in locals_.items():
if key != "self" and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)

View file

@ -1,61 +1,34 @@
import types from typing import TYPE_CHECKING, Any, List, Optional
from typing import List, Optional
import httpx
import litellm
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaude3Config: class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
""" """
Reference: Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
https://docs.anthropic.com/claude/docs/models-overview#model-comparison https://docs.anthropic.com/claude/docs/models-overview#model-comparison
Supported Params for the Amazon / Anthropic Claude 3 models: Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` Required (integer) max tokens. Default is 4096
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
""" """
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default anthropic_version: str = "bedrock-2023-05-31"
anthropic_version: Optional[str] = "bedrock-2023-05-31"
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__( def get_supported_openai_params(self, model: str):
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [ return [
"max_tokens", "max_tokens",
"max_completion_tokens", "max_completion_tokens",
@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
"extra_headers", "extra_headers",
] ]
def map_openai_params(self, non_default_params: dict, optional_params: dict): def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items(): for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens": if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value optional_params["max_tokens"] = value
@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
return optional_params return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_anthropic_request = litellm.AnthropicConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
_anthropic_request.pop("model", None)
if "anthropic_version" not in _anthropic_request:
_anthropic_request["anthropic_version"] = self.anthropic_version
return _anthropic_request
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return litellm.AnthropicConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)

View file

@ -2,22 +2,18 @@ import copy
import json import json
import time import time
import urllib.parse import urllib.parse
import uuid
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
import httpx import httpx
import litellm import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import ( from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt, cohere_message_pt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt, custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory, prompt_factory,
) )
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
@ -91,7 +87,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
optional_params=optional_params, optional_params=optional_params,
) )
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.get(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com ) # https://bedrock-runtime.{region_name}.amazonaws.com
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
@ -129,15 +125,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
## CREDENTIALS ## ## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
extra_headers = optional_params.pop("extra_headers", None) extra_headers = optional_params.get("extra_headers", None)
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.get("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None) aws_session_token = optional_params.get("aws_session_token", None)
aws_role_name = optional_params.pop("aws_role_name", None) aws_role_name = optional_params.get("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None) aws_session_name = optional_params.get("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None) aws_profile_name = optional_params.get("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
aws_region_name = self._get_aws_region_name(optional_params) aws_region_name = self._get_aws_region_name(optional_params)
credentials: Credentials = self.get_credentials( credentials: Credentials = self.get_credentials(
@ -171,7 +167,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
return dict(request.headers) return dict(request.headers)
def transform_request( # noqa: PLR0915 def transform_request(
self, self,
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
@ -194,7 +190,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
for k, v in inference_params.items() for k, v in inference_params.items()
if k not in self.aws_authentication_params if k not in self.aws_authentication_params
} }
json_schemas: dict = {}
request_data: dict = {} request_data: dict = {}
if provider == "cohere": if provider == "cohere":
if model.startswith("cohere.command-r"): if model.startswith("cohere.command-r"):
@ -223,57 +218,21 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
) )
request_data = {"prompt": prompt, **inference_params} request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic": elif provider == "anthropic":
if model.startswith("anthropic.claude-3"): return litellm.AmazonAnthropicClaude3Config().transform_request(
# Separate system prompt from rest of message model=model,
system_prompt_idx: list[int] = [] messages=messages,
system_messages: list[str] = [] optional_params=optional_params,
for idx, message in enumerate(messages): litellm_params=litellm_params,
if message["role"] == "system" and isinstance( headers=headers,
message["content"], str )
): elif provider == "nova":
system_messages.append(message["content"]) return litellm.AmazonInvokeNovaConfig().transform_request(
system_prompt_idx.append(idx) model=model,
if len(system_prompt_idx) > 0: messages=messages,
inference_params["system"] = "\n".join(system_messages) optional_params=optional_params,
messages = [ litellm_params=litellm_params,
i for j, i in enumerate(messages) if j not in system_prompt_idx headers=headers,
] )
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
) # type: ignore
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## Handle Tool Calling
if "tools" in inference_params:
_is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
inference_params["system"] = (
inference_params.get("system", "\n")
+ tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
inference_params.pop("tools")
request_data = {"messages": messages, **inference_params}
else:
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
request_data = {"prompt": prompt, **inference_params}
elif provider == "ai21": elif provider == "ai21":
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config() config = litellm.AmazonAI21Config.get_config()
@ -347,6 +306,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
raise BedrockError( raise BedrockError(
message=raw_response.text, status_code=raw_response.status_code message=raw_response.text, status_code=raw_response.status_code
) )
verbose_logger.debug(
"bedrock invoke response % s",
json.dumps(completion_response, indent=4, default=str),
)
provider = self.get_bedrock_invoke_provider(model) provider = self.get_bedrock_invoke_provider(model)
outputText: Optional[str] = None outputText: Optional[str] = None
try: try:
@ -359,66 +322,31 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
completion_response["generations"][0]["finish_reason"] completion_response["generations"][0]["finish_reason"]
) )
elif provider == "anthropic": elif provider == "anthropic":
if model.startswith("anthropic.claude-3"): return litellm.AmazonAnthropicClaude3Config().transform_response(
json_schemas: dict = {} model=model,
_is_function_call = False raw_response=raw_response,
## Handle Tool Calling model_response=model_response,
if "tools" in optional_params: logging_obj=logging_obj,
_is_function_call = True request_data=request_data,
for tool in optional_params["tools"]: messages=messages,
json_schemas[tool["function"]["name"]] = tool[ optional_params=optional_params,
"function" litellm_params=litellm_params,
].get("parameters", None) encoding=encoding,
outputText = completion_response.get("content")[0].get("text", None) api_key=api_key,
if outputText is not None and contains_tag( json_mode=json_mode,
"invoke", outputText )
): # OUTPUT PARSE FUNCTION CALL elif provider == "nova":
function_name = extract_between_tags("tool_name", outputText)[0] return litellm.AmazonInvokeNovaConfig().transform_response(
function_arguments_str = extract_between_tags( model=model,
"invoke", outputText raw_response=raw_response,
)[0].strip() model_response=model_response,
function_arguments_str = ( logging_obj=logging_obj,
f"<invoke>{function_arguments_str}</invoke>" request_data=request_data,
) messages=messages,
function_arguments = parse_xml_params( optional_params=optional_params,
function_arguments_str, litellm_params=litellm_params,
json_schema=json_schemas.get( encoding=encoding,
function_name, None )
), # check if we have a json schema for this function name)
)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
model_response.choices[0].finish_reason = map_finish_reason(
completion_response.get("stop_reason", "")
)
_usage = litellm.Usage(
prompt_tokens=completion_response["usage"]["input_tokens"],
completion_tokens=completion_response["usage"]["output_tokens"],
total_tokens=completion_response["usage"]["input_tokens"]
+ completion_response["usage"]["output_tokens"],
)
setattr(model_response, "usage", _usage)
else:
outputText = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response[
"stop_reason"
]
elif provider == "ai21": elif provider == "ai21":
outputText = ( outputText = (
completion_response.get("completions")[0].get("data").get("text") completion_response.get("completions")[0].get("data").get("text")
@@ -536,6 +464,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            ),
            model=model,
            custom_llm_provider="bedrock",
@@ -569,6 +498,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            ),
            model=model,
            custom_llm_provider="bedrock",
@@ -594,10 +524,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
        """
        Helper function to get the bedrock provider from the model
        handles 2 scenarios:
        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        handles 4 scenarios:
        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
        """
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
        _split_model = model.split(".")[0]
        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
@@ -606,6 +541,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
        if provider is not None:
            return provider
# check if provider == "nova"
if "nova" in model:
return "nova"
        return None
    @staticmethod
@@ -640,16 +579,16 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
        else:
            modelId = model
modelId = modelId.replace("invoke/", "", 1)
if provider == "llama" and "llama/" in modelId: if provider == "llama" and "llama/" in modelId:
modelId = self._get_model_id_for_llama_like_model(modelId) modelId = self._get_model_id_for_llama_like_model(modelId)
return modelId return modelId
    def _get_aws_region_name(self, optional_params: dict) -> str:
        """
        Get the AWS region name from the environment variables
        """
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_region_name = optional_params.get("aws_region_name", None)
        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
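The provider lookup above strips an optional "invoke/" prefix, checks the first dot-separated segment against the known invoke providers, falls back to the path-style prefix used by imported-model ARNs, and finally special-cases Nova. A rough standalone sketch of that behavior (the helper and the KNOWN_PROVIDERS set below are illustrative stand-ins, not the shipped AmazonInvokeConfig.get_bedrock_invoke_provider, which checks litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):

    # Illustrative stand-in for the provider lookup; KNOWN_PROVIDERS is an assumption.
    from typing import Optional

    KNOWN_PROVIDERS = {"anthropic", "llama", "nova", "ai21", "cohere", "mistral", "amazon", "meta"}

    def get_invoke_provider(model: str) -> Optional[str]:
        if model.startswith("invoke/"):
            model = model.replace("invoke/", "", 1)
        if model.split(".")[0] in KNOWN_PROVIDERS:
            return model.split(".")[0]
        if "/" in model and model.split("/")[0] in KNOWN_PROVIDERS:
            return model.split("/")[0]  # e.g. llama/arn:aws:bedrock:...:imported-model/...
        if "nova" in model:
            return "nova"
        return None

    assert get_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
    assert get_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"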

View file

@@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
"""
import os
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret
@@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
            response_tool_name
        ]
    return response_tool_name
class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
@staticmethod
def get_base_model(model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
if model.startswith("invoke/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if (
potential_region
in BedrockModelInfo._supported_cross_region_inference_region()
):
return model.split(".", 1)[1]
elif (
alt_potential_region in BedrockModelInfo.all_global_regions
and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
@staticmethod
def _supported_cross_region_inference_region() -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
@staticmethod
def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
"""
Get the bedrock route for the given model.
"""
base_model = BedrockModelInfo.get_base_model(model)
if "invoke/" in model:
return "invoke"
elif "converse_like" in model:
return "converse_like"
elif "converse/" in model:
return "converse"
elif base_model in litellm.bedrock_converse_models:
return "converse"
return "invoke"

View file

@@ -27,7 +27,7 @@ class AmazonTitanG1Config:
    def __init__(
        self,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
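The locals() -> locals().copy() change repeated across these config classes takes a snapshot of the constructor arguments before iterating, so the loop cannot be disturbed if anything later lands in the frame's locals. A minimal illustration with a hypothetical config class (not a litellm class):

    # Hypothetical config class showing the locals().copy() pattern used above.
    from typing import Optional

    class ExampleConfig:
        max_tokens: Optional[int] = None
        temperature: Optional[float] = None

        def __init__(self, max_tokens: Optional[int] = None, temperature: Optional[float] = None) -> None:
            locals_ = locals().copy()  # snapshot of the passed-in arguments
            for key, value in locals_.items():
                if key != "self" and value is not None:
                    setattr(self.__class__, key, value)

    ExampleConfig(max_tokens=256)
    print(ExampleConfig.max_tokens)  # 256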

View file

@@ -33,7 +33,7 @@ class AmazonTitanV2Config:
    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -49,7 +49,7 @@ class AmazonStabilityConfig:
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -45,7 +45,7 @@ class ClarifaiConfig(BaseConfig):
        temperature: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -44,7 +44,7 @@ class CloudflareChatConfig(BaseConfig):
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -104,7 +104,7 @@ class CohereChatConfig(BaseConfig):
        tool_results: Optional[list] = None,
        seed: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -86,7 +86,7 @@ class CohereTextConfig(BaseConfig):
        return_likelihoods: Optional[str] = None,
        logit_bias: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -1,5 +1,6 @@
import asyncio
import os
import time
from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union
import httpx
@@ -179,6 +180,7 @@ class AsyncHTTPHandler:
        stream: bool = False,
        logging_obj: Optional[LiteLLMLoggingObject] = None,
    ):
start_time = time.time()
        try:
            if timeout is None:
                timeout = self.timeout
@@ -207,6 +209,8 @@ class AsyncHTTPHandler:
            finally:
                await new_client.aclose()
        except httpx.TimeoutException as e:
end_time = time.time()
time_delta = round(end_time - start_time, 3)
            headers = {}
            error_response = getattr(e, "response", None)
            if error_response is not None:
@@ -214,7 +218,7 @@ class AsyncHTTPHandler:
                    headers["response_headers-{}".format(key)] = value
            raise litellm.Timeout(
                message=f"Connection timed out after {timeout} seconds.",
                message=f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds",
                model="default-model-name",
                llm_provider="litellm-httpx-handler",
                headers=headers,
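The handler change above records a start time before the request and reports the measured elapsed time alongside the configured timeout when httpx raises a TimeoutException. A standalone sketch of the same pattern (plain httpx, placeholder URL, not the litellm handler itself):

    # Standalone sketch of the timeout reporting pattern; the URL is a placeholder.
    import time
    import httpx

    async def post_with_timing(url: str, timeout: float = 0.001):
        start_time = time.time()
        try:
            async with httpx.AsyncClient() as client:
                return await client.post(url, timeout=timeout)
        except httpx.TimeoutException:
            time_delta = round(time.time() - start_time, 3)
            raise RuntimeError(
                f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds"
            )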

View file

@@ -37,7 +37,7 @@ class DatabricksConfig(OpenAILikeChatConfig):
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -16,7 +16,7 @@ class DatabricksEmbeddingConfig:
    )
    def __init__(self, instruction: Optional[str] = None) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -145,7 +145,7 @@ class AlephAlphaConfig:
        contextual_control_threshold: Optional[int] = None,
        control_log_additive: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -63,7 +63,7 @@ class PalmConfig:
        top_p: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -57,7 +57,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
        candidate_count: Optional[int] = None,
        stop_sequences: Optional[list] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -77,7 +77,7 @@ class HuggingfaceChatConfig(BaseConfig):
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -20,6 +20,15 @@ from .common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError("api_base is required for Infinity rerank")
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/rerank"):
api_base = f"{api_base}/rerank"
return api_base
    def validate_environment(
        self,
        headers: dict,
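The new get_complete_url() normalizes the user-supplied api_base so the request always targets the /rerank endpoint exactly once. A quick sketch of the expected behavior (standalone function, placeholder host and port, not the InfinityRerankConfig method itself):

    # Standalone sketch of the api_base normalization; host/port are placeholders.
    from typing import Optional

    def complete_rerank_url(api_base: Optional[str]) -> str:
        if api_base is None:
            raise ValueError("api_base is required for Infinity rerank")
        api_base = api_base.rstrip("/")
        if not api_base.endswith("/rerank"):
            api_base = f"{api_base}/rerank"
        return api_base

    assert complete_rerank_url("http://localhost:7997/") == "http://localhost:7997/rerank"
    assert complete_rerank_url("http://localhost:7997/rerank") == "http://localhost:7997/rerank"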

View file

@@ -21,7 +21,7 @@ class JinaAIEmbeddingConfig:
    def __init__(
        self,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -18,7 +18,7 @@ class LmStudioEmbeddingConfig:
    def __init__(
        self,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -33,7 +33,7 @@ class MaritalkConfig(OpenAIGPTConfig):
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -78,7 +78,7 @@ class NLPCloudConfig(BaseConfig):
        num_beams: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -32,7 +32,7 @@ class NvidiaNimEmbeddingConfig:
        input_type: Optional[str] = None,
        truncate: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
@@ -58,7 +58,7 @@ class NvidiaNimEmbeddingConfig:
    def get_supported_openai_params(
        self,
    ):
        return ["encoding_format", "user"]
        return ["encoding_format", "user", "dimensions"]
    def map_openai_params(
        self,
@@ -73,6 +73,8 @@ class NvidiaNimEmbeddingConfig:
                optional_params["extra_body"].update({"input_type": v})
            elif k == "truncate":
                optional_params["extra_body"].update({"truncate": v})
else:
optional_params[k] = v
        if kwargs is not None:
            # pass kwargs in extra_body
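With "dimensions" now listed as a supported OpenAI param, the mapping above keeps input_type and truncate inside extra_body while any other supported key, including dimensions, is passed through at the top level. A simplified stand-in for that mapping (not the shipped map_openai_params method):

    # Simplified stand-in showing where the new "dimensions" parameter lands.
    def map_params(non_default_params: dict) -> dict:
        optional_params: dict = {"extra_body": {}}
        for k, v in non_default_params.items():
            if k == "input_type":
                optional_params["extra_body"].update({"input_type": v})
            elif k == "truncate":
                optional_params["extra_body"].update({"truncate": v})
            else:
                optional_params[k] = v  # e.g. encoding_format, user, dimensions
        return optional_params

    print(map_params({"dimensions": 512, "input_type": "query"}))
    # {'extra_body': {'input_type': 'query'}, 'dimensions': 512}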

View file

@@ -117,7 +117,7 @@ class OllamaConfig(BaseConfig):
        system: Optional[str] = None,
        template: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -105,7 +105,7 @@ class OllamaChatConfig(OpenAIGPTConfig):
        system: Optional[str] = None,
        template: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
            or "https://api.openai.com/v1"
        )
@staticmethod
def get_base_model(model: str) -> str:
return model
    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],

View file

@@ -54,7 +54,7 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
        if model is None:
            return True
        supported_stream_models = ["o1-mini", "o1-preview"]
        supported_stream_models = ["o1-mini", "o1-preview", "o3-mini"]
        for supported_model in supported_stream_models:
            if supported_model in model:
                return False

View file

@@ -27,6 +27,7 @@ from typing_extensions import overload
import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
def _set_dynamic_params_on_client(
self,
client: Union[OpenAI, AsyncOpenAI],
organization: Optional[str] = None,
max_retries: Optional[int] = None,
):
if organization is not None:
client.organization = organization
if max_retries is not None:
client.max_retries = max_retries
    def _get_openai_client(
        self,
        is_async: bool,
@@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = 2,
        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
        organization: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
args = locals()
        if client is None:
            if not isinstance(max_retries, int):
                raise OpenAIError(
@@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
                    organization=organization,
                )
            else:
                _new_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
@@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
            return _new_client
        else:
self._set_dynamic_params_on_client(
client=client,
organization=organization,
max_retries=max_retries,
)
            return client
    @track_llm_api_timing()
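The new _set_dynamic_params_on_client() means a cached client passed back into _get_openai_client() now picks up per-call organization and max_retries instead of silently keeping its original values. A rough sketch of that behavior, assuming the openai SDK client exposes these attributes as the diff relies on (placeholder key and org, no request is made):

    # Sketch of applying per-call settings to a reused client; key/org are placeholders.
    from openai import OpenAI

    def set_dynamic_params(client: OpenAI, organization=None, max_retries=None) -> OpenAI:
        if organization is not None:
            client.organization = organization
        if max_retries is not None:
            client.max_retries = max_retries
        return client

    client = OpenAI(api_key="sk-placeholder", max_retries=2)
    set_dynamic_params(client, organization="org-placeholder", max_retries=5)
    print(client.max_retries)  # 5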

View file

@@ -58,7 +58,7 @@ class PetalsConfig(BaseConfig):
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -59,7 +59,7 @@ class PredibaseConfig(BaseConfig):
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -73,7 +73,7 @@ class ReplicateConfig(BaseConfig):
        seed: Optional[int] = None,
        debug: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -47,7 +47,7 @@ class SagemakerConfig(BaseConfig):
        temperature: Optional[float] = None,
        return_full_text: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
        return (
            api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
        )
@staticmethod
def get_base_model(model: str) -> str:
return model

View file

@@ -179,7 +179,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -17,7 +17,7 @@ class VertexAIAi21Config:
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -21,7 +21,7 @@ class VertexAILlama3Config:
        self,
        max_tokens: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key == "max_tokens" and value is None:
                value = self.max_tokens

View file

@@ -48,7 +48,7 @@ class VertexAITextEmbeddingConfig(BaseModel):
        ] = None,
        title: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -108,7 +108,7 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

View file

@@ -68,6 +68,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_content_from_model_response,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.realtime_api.main import _realtime_health_check
from litellm.secret_managers.main import get_secret_str
@@ -1222,6 +1223,8 @@ def completion( # type: ignore # noqa: PLR0915
        if extra_headers is not None:
            optional_params["extra_headers"] = extra_headers
if max_retries is not None:
optional_params["max_retries"] = max_retries
        if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
@@ -2626,11 +2629,8 @@
                        aws_bedrock_client.meta.region_name
                    )
            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
            if base_model in litellm.bedrock_converse_models or model.startswith(
                "converse/"
            ):
            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
            if bedrock_route == "converse":
model = model.replace("converse/", "") model = model.replace("converse/", "")
response = bedrock_converse_chat_completion.completion( response = bedrock_converse_chat_completion.completion(
model=model, model=model,
@@ -2649,7 +2649,7 @@
                    client=client,
                    api_base=api_base,
                )
            elif "converse_like" in model:
            elif bedrock_route == "converse_like":
                model = model.replace("converse_like/", "")
                response = base_llm_http_handler.completion(
                    model=model,
@@ -3947,6 +3947,7 @@ async def atext_completion(
            ),
            model=model,
            custom_llm_provider=custom_llm_provider,
stream_options=kwargs.get('stream_options'),
        )
    else:
        ## OpenAI / Azure Text Completion Returns here

View file

@@ -1069,6 +1069,21 @@
        "supports_prompt_caching": true,
        "supports_tool_choice": true
    },
"azure/o1-2024-12-17": {
"max_tokens": 100000,
"max_input_tokens": 200000,
"max_output_tokens": 100000,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000060,
"cache_read_input_token_cost": 0.0000075,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_prompt_caching": true,
"supports_tool_choice": true
},
"azure/o1-preview": { "azure/o1-preview": {
"max_tokens": 32768, "max_tokens": 32768,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@@ -1397,6 +1412,19 @@
        "deprecation_date": "2025-03-31",
        "supports_tool_choice": true
    },
"azure/gpt-3.5-turbo-0125": {
"max_tokens": 4096,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"deprecation_date": "2025-03-31",
"supports_tool_choice": true
},
"azure/gpt-35-turbo-16k": { "azure/gpt-35-turbo-16k": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 16385, "max_input_tokens": 16385,
@@ -1418,6 +1446,17 @@
        "supports_function_calling": true,
        "supports_tool_choice": true
    },
"azure/gpt-3.5-turbo": {
"max_tokens": 4096,
"max_input_tokens": 4097,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true
},
"azure/gpt-3.5-turbo-instruct-0914": { "azure/gpt-3.5-turbo-instruct-0914": {
"max_tokens": 4097, "max_tokens": 4097,
"max_input_tokens": 4097, "max_input_tokens": 4097,
@@ -2174,11 +2213,11 @@
        "max_tokens": 8192,
        "max_input_tokens": 65536,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.00000014,
        "input_cost_per_token": 0.00000027,
        "input_cost_per_token_cache_hit": 0.000000014,
        "input_cost_per_token_cache_hit": 0.00000007,
        "cache_read_input_token_cost": 0.000000014,
        "cache_read_input_token_cost": 0.00000007,
        "cache_creation_input_token_cost": 0.0,
        "output_cost_per_token": 0.00000028,
        "output_cost_per_token": 0.0000011,
        "litellm_provider": "deepseek",
        "mode": "chat",
        "supports_function_calling": true,
@@ -3650,9 +3689,34 @@
        "supports_vision": true,
        "supports_response_schema": true,
        "supports_audio_output": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing",
        "supports_tool_choice": true
    },
"gemini/gemini-2.0-flash": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000007,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000004,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10000,
"tpm": 10000000,
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
"gemini-2.0-flash-001": { "gemini-2.0-flash-001": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1048576, "max_input_tokens": 1048576,
@@ -3663,9 +3727,9 @@
        "max_audio_length_hours": 8.4,
        "max_audio_per_prompt": 1,
        "max_pdf_size_mb": 30,
        "input_cost_per_audio_token": 0.001,
        "input_cost_per_audio_token": 0.000001,
        "input_cost_per_token": 0.00015,
        "input_cost_per_token": 0.00000015,
        "output_cost_per_token": 0.0006,
        "output_cost_per_token": 0.0000006,
        "litellm_provider": "vertex_ai-language-models",
        "mode": "chat",
        "supports_system_messages": true,
@@ -3674,7 +3738,7 @@
        "supports_response_schema": true,
        "supports_audio_output": true,
        "supports_tool_choice": true,
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
    },
    "gemini-2.0-flash-thinking-exp": {
        "max_tokens": 8192,
@@ -3744,6 +3808,31 @@
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
        "supports_tool_choice": true
    },
"gemini/gemini-2.0-flash-001": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000007,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000004,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10000,
"tpm": 10000000,
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
"gemini/gemini-2.0-flash-exp": { "gemini/gemini-2.0-flash-exp": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1048576, "max_input_tokens": 1048576,
@@ -3780,6 +3869,31 @@
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
        "supports_tool_choice": true
    },
"gemini/gemini-2.0-flash-lite-preview-02-05": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.000000075,
"input_cost_per_token": 0.000000075,
"output_cost_per_token": 0.0000003,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 60000,
"tpm": 10000000,
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash-lite"
},
"gemini/gemini-2.0-flash-thinking-exp": { "gemini/gemini-2.0-flash-thinking-exp": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1048576, "max_input_tokens": 1048576,
@@ -6026,7 +6140,8 @@
        "litellm_provider": "bedrock_converse",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_prompt_caching": true
        "supports_prompt_caching": true,
"supports_response_schema": true
    },
    "us.amazon.nova-micro-v1:0": {
        "max_tokens": 4096,
@@ -6051,7 +6166,8 @@
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true,
        "supports_prompt_caching": true
        "supports_prompt_caching": true,
"supports_response_schema": true
    },
    "us.amazon.nova-lite-v1:0": {
        "max_tokens": 4096,
@@ -6064,7 +6180,8 @@
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true,
        "supports_prompt_caching": true
        "supports_prompt_caching": true,
"supports_response_schema": true
    },
    "amazon.nova-pro-v1:0": {
        "max_tokens": 4096,
@@ -6077,7 +6194,8 @@
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true,
        "supports_prompt_caching": true
        "supports_prompt_caching": true,
"supports_response_schema": true
    },
    "us.amazon.nova-pro-v1:0": {
        "max_tokens": 4096,
@@ -6090,7 +6208,8 @@
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true,
        "supports_prompt_caching": true
        "supports_prompt_caching": true,
"supports_response_schema": true
    },
    "anthropic.claude-3-sonnet-20240229-v1:0": {
        "max_tokens": 4096,
@@ -6101,6 +6220,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6113,6 +6233,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6140,6 +6261,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6153,6 +6275,7 @@
        "mode": "chat",
        "supports_assistant_prefill": true,
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_prompt_caching": true, "supports_prompt_caching": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6165,6 +6288,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6177,6 +6301,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6189,6 +6314,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6216,6 +6342,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6230,6 +6357,7 @@
        "supports_assistant_prefill": true,
        "supports_function_calling": true,
        "supports_prompt_caching": true,
"supports_response_schema": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
"us.anthropic.claude-3-opus-20240229-v1:0": { "us.anthropic.claude-3-opus-20240229-v1:0": {
@@ -6241,6 +6369,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6253,6 +6382,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6265,6 +6395,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6292,6 +6423,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -6318,6 +6450,7 @@
        "litellm_provider": "bedrock",
        "mode": "chat",
        "supports_function_calling": true,
"supports_response_schema": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true "supports_tool_choice": true
}, },
@@ -8935,4 +9068,4 @@
        "output_cost_per_second": 0.00,
        "litellm_provider": "assemblyai"
    }
}

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff