Merge branch 'main' into litellm_fix_track_cost_callback

This commit is contained in:
Ishaan Jaff 2024-11-12 10:29:40 -08:00
commit 7973665219
250 changed files with 17396 additions and 4808 deletions

View file

@ -47,7 +47,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.52.0
pip install openai==1.54.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@ -103,7 +103,7 @@ jobs:
command: |
pwd
ls
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "not test_python_38.py and not router and not assistants"
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "not test_python_38.py and not router and not assistants and not langfuse and not caching and not cache"
no_output_timeout: 120m
- run:
name: Rename the coverage files
@ -119,6 +119,204 @@ jobs:
paths:
- local_testing_coverage.xml
- local_testing_coverage
langfuse_logging_unit_tests:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Show git commit hash
command: |
echo "Git commit hash: $CIRCLE_SHA1"
- restore_cache:
keys:
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install "pytest-cov==5.0.0"
pip install mypy
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
pip install pyarrow
pip install "boto3==1.34.34"
pip install "aioboto3==12.3.0"
pip install langchain
pip install lunary==0.2.5
pip install "azure-identity==1.16.1"
pip install "langfuse==2.45.0"
pip install "logfire==0.29.0"
pip install numpydoc
pip install traceloop-sdk==0.21.1
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
pip install "respx==0.21.1"
pip install fastapi
pip install "gunicorn==21.2.0"
pip install "anyio==4.2.0"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "apscheduler==3.10.4"
pip install "PyGithub==1.59.1"
pip install argon2-cffi
pip install "pytest-mock==3.12.0"
pip install python-multipart
pip install google-cloud-aiplatform
pip install prometheus-client==0.20.0
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
- save_cache:
paths:
- ./venv
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Run prisma ./docker/entrypoint.sh
command: |
set +e
chmod +x docker/entrypoint.sh
./docker/entrypoint.sh
set -e
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "langfuse"
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml langfuse_coverage.xml
mv .coverage langfuse_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- langfuse_coverage.xml
- langfuse_coverage
caching_unit_tests:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Show git commit hash
command: |
echo "Git commit hash: $CIRCLE_SHA1"
- restore_cache:
keys:
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install "pytest-cov==5.0.0"
pip install mypy
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
pip install pyarrow
pip install "boto3==1.34.34"
pip install "aioboto3==12.3.0"
pip install langchain
pip install lunary==0.2.5
pip install "azure-identity==1.16.1"
pip install "langfuse==2.45.0"
pip install "logfire==0.29.0"
pip install numpydoc
pip install traceloop-sdk==0.21.1
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
pip install "respx==0.21.1"
pip install fastapi
pip install "gunicorn==21.2.0"
pip install "anyio==4.2.0"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "apscheduler==3.10.4"
pip install "PyGithub==1.59.1"
pip install argon2-cffi
pip install "pytest-mock==3.12.0"
pip install python-multipart
pip install google-cloud-aiplatform
pip install prometheus-client==0.20.0
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
- save_cache:
paths:
- ./venv
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Run prisma ./docker/entrypoint.sh
command: |
set +e
chmod +x docker/entrypoint.sh
./docker/entrypoint.sh
set -e
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "caching or cache"
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml caching_coverage.xml
mv .coverage caching_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- caching_coverage.xml
- caching_coverage
auth_ui_unit_tests:
docker:
- image: cimg/python:3.11
@ -215,6 +413,105 @@ jobs:
paths:
- litellm_router_coverage.xml
- litellm_router_coverage
litellm_proxy_unit_testing: # Runs all tests with the "proxy", "key", "jwt" filenames
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Show git commit hash
command: |
echo "Git commit hash: $CIRCLE_SHA1"
- restore_cache:
keys:
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install "pytest-cov==5.0.0"
pip install mypy
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
pip install pyarrow
pip install "boto3==1.34.34"
pip install "aioboto3==12.3.0"
pip install langchain
pip install lunary==0.2.5
pip install "azure-identity==1.16.1"
pip install "langfuse==2.45.0"
pip install "logfire==0.29.0"
pip install numpydoc
pip install traceloop-sdk==0.21.1
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
pip install "respx==0.21.1"
pip install fastapi
pip install "gunicorn==21.2.0"
pip install "anyio==4.2.0"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "apscheduler==3.10.4"
pip install "PyGithub==1.59.1"
pip install argon2-cffi
pip install "pytest-mock==3.12.0"
pip install python-multipart
pip install google-cloud-aiplatform
pip install prometheus-client==0.20.0
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
- save_cache:
paths:
- ./venv
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
- run:
name: Run prisma ./docker/entrypoint.sh
command: |
set +e
chmod +x docker/entrypoint.sh
./docker/entrypoint.sh
set -e
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest tests/proxy_unit_tests --cov=litellm --cov-report=xml -vv -x -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml litellm_proxy_unit_tests_coverage.xml
mv .coverage litellm_proxy_unit_tests_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- litellm_proxy_unit_tests_coverage.xml
- litellm_proxy_unit_tests_coverage
litellm_assistants_api_testing: # Runs all tests with the "assistants" keyword
docker:
- image: cimg/python:3.11
@ -328,7 +625,7 @@ jobs:
paths:
- llm_translation_coverage.xml
- llm_translation_coverage
logging_testing:
image_gen_testing:
docker:
- image: cimg/python:3.11
auth:
@ -349,6 +646,51 @@ jobs:
pip install "pytest-asyncio==0.21.1"
pip install "respx==0.21.1"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/image_gen_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml image_gen_coverage.xml
mv .coverage image_gen_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- image_gen_coverage.xml
- image_gen_coverage
logging_testing:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-cov==5.0.0"
pip install "pytest-asyncio==0.21.1"
pip install pytest-mock
pip install "respx==0.21.1"
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
@ -392,7 +734,7 @@ jobs:
pip install click
pip install "boto3==1.34.34"
pip install jinja2
pip install tokenizers
pip install "tokenizers==0.20.0"
pip install jsonschema
- run:
name: Run tests
@ -425,6 +767,7 @@ jobs:
- run: python ./tests/documentation_tests/test_general_setting_keys.py
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: helm lint ./deploy/charts/litellm-helm
@ -520,7 +863,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.52.0"
pip install "openai==1.54.0 "
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image
@ -577,7 +920,7 @@ jobs:
command: |
pwd
ls
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests
no_output_timeout: 120m
# Store test results
@ -637,7 +980,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.52.0"
pip install "openai==1.54.0 "
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@ -664,6 +1007,7 @@ jobs:
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
-e COHERE_API_KEY=$COHERE_API_KEY \
-e GCS_FLUSH_INTERVAL="1" \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
-v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
@ -729,7 +1073,7 @@ jobs:
pip install "pytest-asyncio==0.21.1"
pip install "google-cloud-aiplatform==1.43.0"
pip install aiohttp
pip install "openai==1.52.0"
pip install "openai==1.54.0 "
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"
@ -814,7 +1158,7 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage
coverage xml
- codecov/upload:
file: ./coverage.xml
@ -924,7 +1268,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install "openai==1.52.0"
pip install "openai==1.54.0 "
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"
@ -986,6 +1330,41 @@ jobs:
- store_test_results:
path: test-results
test_bad_database_url:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
working_directory: ~/project
steps:
- checkout
- run:
name: Build Docker image
command: |
docker build -t myapp . -f ./docker/Dockerfile.non_root
- run:
name: Run Docker container with bad DATABASE_URL
command: |
docker run --name my-app \
-p 4000:4000 \
-e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \
myapp:latest \
--port 4000 > docker_output.log 2>&1 || true
- run:
name: Display Docker logs
command: cat docker_output.log
- run:
name: Check for expected error
command: |
if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \
grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \
grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then
echo "Expected error found. Test passed."
else
echo "Expected error not found. Test failed."
cat docker_output.log
exit 1
fi
workflows:
version: 2
build_and_test:
@ -996,6 +1375,24 @@ workflows:
only:
- main
- /litellm_.*/
- langfuse_logging_unit_tests:
filters:
branches:
only:
- main
- /litellm_.*/
- caching_unit_tests:
filters:
branches:
only:
- main
- /litellm_.*/
- litellm_proxy_unit_testing:
filters:
branches:
only:
- main
- /litellm_.*/
- litellm_assistants_api_testing:
filters:
branches:
@ -1050,6 +1447,12 @@ workflows:
only:
- main
- /litellm_.*/
- image_gen_testing:
filters:
branches:
only:
- main
- /litellm_.*/
- logging_testing:
filters:
branches:
@ -1059,8 +1462,12 @@ workflows:
- upload-coverage:
requires:
- llm_translation_testing
- image_gen_testing
- logging_testing
- litellm_router_testing
- caching_unit_tests
- litellm_proxy_unit_testing
- langfuse_logging_unit_tests
- local_testing
- litellm_assistants_api_testing
- auth_ui_unit_tests
@ -1082,18 +1489,29 @@ workflows:
only:
- main
- /litellm_.*/
- test_bad_database_url:
filters:
branches:
only:
- main
- /litellm_.*/
- publish_to_pypi:
requires:
- local_testing
- build_and_test
- load_testing
- test_bad_database_url
- llm_translation_testing
- image_gen_testing
- logging_testing
- litellm_router_testing
- caching_unit_tests
- langfuse_logging_unit_tests
- litellm_assistants_api_testing
- auth_ui_unit_tests
- db_migration_disable_update_check
- e2e_ui_testing
- litellm_proxy_unit_testing
- installing_litellm_on_python
- proxy_logging_guardrails_model_info_tests
- proxy_pass_through_endpoint_tests

View file

@ -1,5 +1,5 @@
# used by CI/CD testing
openai==1.52.0
openai==1.54.0
python-dotenv
tiktoken
importlib_metadata

View file

@ -1,35 +0,0 @@
name: Lint
# If a pull-request is pushed then cancel all previously running jobs related
# to that pull-request
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
on:
push:
branches:
- main
pull_request:
branches:
- main
- master
env:
PY_COLORS: 1
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4.1.1
with:
fetch-depth: 0
- name: Install poetry
run: pipx install poetry
- uses: actions/setup-python@v5.0.0
with:
python-version: '3.12'
- run: poetry install
- name: Run flake8 (https://flake8.pycqa.org/en/latest/)
run: poetry run flake8

View file

@ -4,7 +4,6 @@ python script to pre-create all views required by LiteLLM Proxy Server
import asyncio
import os
from update_unassigned_teams import apply_db_fixes
# Enter your DATABASE_URL here
@ -205,7 +204,6 @@ async def check_view_exists(): # noqa: PLR0915
print("Last30dTopEndUsersSpend Created!") # noqa
await apply_db_fixes(db=db)
return

View file

@ -1,7 +1,14 @@
from prisma import Prisma
from litellm._logging import verbose_logger
async def apply_db_fixes(db: Prisma):
"""
Do Not Run this in production, only use it as a one-time fix
"""
verbose_logger.warning(
"DO NOT run this in Production....Running update_unassigned_teams"
)
try:
sql_query = """
UPDATE "LiteLLM_SpendLogs"

View file

@ -18,3 +18,4 @@
npm-debug.log*
yarn-debug.log*
yarn-error.log*
yarn.lock

View file

@ -0,0 +1,41 @@
# Benchmarks
Benchmarks for LiteLLM Gateway (Proxy Server)
Locust Settings:
- 2500 Users
- 100 user Ramp Up
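The load profile above could be reproduced with a minimal locustfile along these lines (a sketch; the host, model name, and API key are illustrative placeholders, not values from this repo):
```python
# locustfile.py - sketch of the 2500-user load profile described above
from locust import HttpUser, task, between

class ProxyUser(HttpUser):
    wait_time = between(0.5, 1)  # small think time between requests

    @task
    def chat_completion(self):
        # OpenAI-style chat completion routed through the LiteLLM proxy
        self.client.post(
            "/chat/completions",
            json={
                "model": "fake-openai-endpoint",  # placeholder model name
                "messages": [{"role": "user", "content": "hello"}],
            },
            headers={"Authorization": "Bearer sk-1234"},  # placeholder key
        )
```
Run it with something like `locust -f locustfile.py --users 2500 --spawn-rate 100 --host http://localhost:4000` to approximate the settings listed above.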
## Basic Benchmarks
Overhead when using a Deployed Proxy vs Direct to LLM
- Latency overhead added by LiteLLM Proxy: 107ms
| Metric | Direct to Fake Endpoint | Basic Litellm Proxy |
|--------|------------------------|---------------------|
| RPS | 1196 | 1133.2 |
| Median Latency (ms) | 33 | 140 |
## Logging Callbacks
### [GCS Bucket Logging](https://docs.litellm.ai/docs/proxy/bucket)
Using a GCS Bucket has **no impact on latency or RPS compared to the Basic LiteLLM Proxy**
| Metric | Basic Litellm Proxy | LiteLLM Proxy with GCS Bucket Logging |
|--------|------------------------|---------------------|
| RPS | 1133.2 | 1137.3 |
| Median Latency (ms) | 140 | 138 |
### [LangSmith logging](https://docs.litellm.ai/docs/proxy/logging)
Using LangSmith has **no impact on latency or RPS compared to the Basic LiteLLM Proxy**
| Metric | Basic Litellm Proxy | LiteLLM Proxy with LangSmith |
|--------|------------------------|---------------------|
| RPS | 1133.2 | 1135 |
| Median Latency (ms) | 140 | 132 |

View file

@ -0,0 +1,109 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Predicted Outputs
| Property | Details |
|-------|-------|
| Description | Use this when most of the output of the LLM is known ahead of time. For instance, if you are asking the model to rewrite some text or code with only minor changes, you can reduce your latency significantly by using Predicted Outputs, passing in the existing content as your prediction. |
| Supported providers | `openai` |
| Link to OpenAI doc on Predicted Outputs | [Predicted Outputs ↗](https://platform.openai.com/docs/guides/latency-optimization#use-predicted-outputs) |
| Supported from LiteLLM Version | `v1.51.4` |
## Using Predicted Outputs
<Tabs>
<TabItem label="LiteLLM Python SDK" value="Python">
In this example we want to refactor a piece of C# code, and convert the Username property to Email instead:
```python
import litellm
import os

os.environ["OPENAI_API_KEY"] = "your-api-key"
code = """
/// <summary>
/// Represents a user with a first name, last name, and username.
/// </summary>
public class User
{
/// <summary>
/// Gets or sets the user's first name.
/// </summary>
public string FirstName { get; set; }
/// <summary>
/// Gets or sets the user's last name.
/// </summary>
public string LastName { get; set; }
/// <summary>
/// Gets or sets the user's username.
/// </summary>
public string Username { get; set; }
}
"""
completion = litellm.completion(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
},
{"role": "user", "content": code},
],
prediction={"type": "content", "content": code},
)
print(completion)
```
</TabItem>
<TabItem label="LiteLLM Proxy Server" value="proxy">
1. Define models on config.yaml
```yaml
model_list:
- model_name: gpt-4o-mini # OpenAI gpt-4o-mini
litellm_params:
model: openai/gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY
```
2. Run proxy server
```bash
litellm --config config.yaml
```
3. Test it using the OpenAI Python SDK
```python
from openai import OpenAI
client = OpenAI(
api_key="LITELLM_PROXY_KEY", # sk-1234
base_url="LITELLM_PROXY_BASE" # http://0.0.0.0:4000
)
# `code` here is the same C# snippet defined in the SDK example above
completion = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
},
{"role": "user", "content": code},
],
prediction={"type": "content", "content": code},
)
print(completion)
```
</TabItem>
</Tabs>

View file

@ -35,7 +35,7 @@ OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="http:/0.0.0.0:4317"
OTEL_ENDPOINT="http://0.0.0.0:4318"
```
</TabItem>
@ -44,14 +44,24 @@ OTEL_ENDPOINT="http:/0.0.0.0:4317"
```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="http:/0.0.0.0:4317"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```
</TabItem>
<TabItem value="laminar" label="Log to Laminar">
```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="https://api.lmnr.ai:8443"
OTEL_HEADERS="authorization=Bearer <project-api-key>"
```
</TabItem>
</Tabs>
Use just 2 lines of code, to instantly log your LLM responses **across all providers** with OpenTelemetry:
Use just 1 line of code to instantly log your LLM responses **across all providers** with OpenTelemetry:
```python
litellm.callbacks = ["otel"]

View file

@ -864,3 +864,96 @@ Human: How do I boil water?
Assistant:
```
## Usage - PDF
Pass base64 encoded PDF files to Anthropic models using the `image_url` field.
<Tabs>
<TabItem value="sdk" label="SDK">
### **using base64**
```python
from litellm import completion, supports_pdf_input
import base64
import requests
# URL of the file
url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
# Download the file
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
## check if model supports pdf input - (2024/11/11) only claude-3-5-haiku-20241022 supports it
supports_pdf_input("anthropic/claude-3-5-haiku-20241022") # True
response = completion(
model="anthropic/claude-3-5-haiku-20241022",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."},
{
"type": "image_url",
"image_url": f"data:application/pdf;base64,{encoded_file}", # 👈 PDF
},
],
}
],
max_tokens=300,
)
print(response.choices[0])
```
</TabItem>
<TabItem value="proxy" lable="PROXY">
1. Add model to config
```yaml
- model_name: claude-3-5-haiku-20241022
litellm_params:
model: anthropic/claude-3-5-haiku-20241022
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start Proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
"model": "claude-3-5-haiku-20241022",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "You are a very professional document summarization specialist. Please summarize the given document"
},
{
"type": "image_url",
"image_url": "data:application/pdf;base64,{encoded_file}" # 👈 PDF
}
]
}
],
"max_tokens": 300
}'
```
</TabItem>
</Tabs>

View file

@ -1082,5 +1082,6 @@ print(f"response: {response}")
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |

View file

@ -9,7 +9,7 @@ LiteLLM Supports Logging to the following Cloud Buckets
- (Enterprise) ✨ [Google Cloud Storage Buckets](#logging-proxy-inputoutput-to-google-cloud-storage-buckets)
- (Free OSS) [Amazon s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
## Logging Proxy Input/Output to Google Cloud Storage Buckets
## Google Cloud Storage Buckets
Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?hl=en)
@ -20,6 +20,14 @@ Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?
:::
| Property | Details |
|----------|---------|
| Description | Log LLM Input/Output to cloud storage buckets |
| Load Test Benchmarks | [Benchmarks](https://docs.litellm.ai/docs/benchmarks) |
| Google Docs on Cloud Storage | [Google Cloud Storage](https://cloud.google.com/storage?hl=en) |
### Usage
1. Add `gcs_bucket` to LiteLLM Config.yaml
@ -85,7 +93,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
6. Save the JSON file and add the path to `GCS_PATH_SERVICE_ACCOUNT`
## Logging Proxy Input/Output - s3 Buckets
## s3 Buckets
We will use the `--config` to set

View file

@ -692,9 +692,13 @@ general_settings:
allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
key_management_system: google_kms # either google_kms or azure_kms
master_key: string
# Database Settings
database_url: string
database_connection_pool_limit: 0 # default 100
database_connection_timeout: 0 # default 60s
allow_requests_on_db_unavailable: boolean # if true, requests will still be processed even when the DB is unavailable to verify the Virtual Key
custom_auth: string
max_parallel_requests: 0 # the max parallel requests allowed per deployment
global_max_parallel_requests: 0 # the max parallel requests allowed on the proxy all up
@ -766,6 +770,7 @@ general_settings:
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if the DB is unreachable. **Only use this if running LiteLLM in your VPC.** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key |
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
@ -929,6 +934,8 @@ router_settings:
| EMAIL_SUPPORT_CONTACT | Support contact email address
| GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket
| GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds**
| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `GCS_BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048**
| GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
| GENERIC_CLIENT_ID | Client ID for generic OAuth providers
| GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers

View file

@ -107,7 +107,7 @@ class StandardLoggingModelInformation(TypedDict):
model_map_value: Optional[ModelInfo]
```
## Logging Proxy Input/Output - Langfuse
## Langfuse
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`. This will log all successful LLM calls to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment
@ -463,7 +463,7 @@ You will see `raw_request` in your Langfuse Metadata. This is the RAW CURL comma
<Image img={require('../../img/debug_langfuse.png')} />
## Logging Proxy Input/Output in OpenTelemetry format
## OpenTelemetry format
:::info
@ -1216,7 +1216,7 @@ litellm_settings:
Start the LiteLLM Proxy and make a test request to verify the logs reached your callback API
## Logging LLM IO to Langsmith
## Langsmith
1. Set `success_callback: ["langsmith"]` on litellm config.yaml
@ -1261,7 +1261,7 @@ Expect to see your log on Langfuse
<Image img={require('../../img/langsmith_new.png')} />
## Logging LLM IO to Arize AI
## Arize AI
1. Set `success_callback: ["arize"]` on litellm config.yaml
@ -1309,7 +1309,7 @@ Expect to see your log on Langfuse
<Image img={require('../../img/langsmith_new.png')} />
## Logging LLM IO to Langtrace
## Langtrace
1. Set `success_callback: ["langtrace"]` on litellm config.yaml
@ -1351,7 +1351,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
'
```
## Logging LLM IO to Galileo
## Galileo
[BETA]
@ -1466,7 +1466,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
<Image img={require('../../img/openmeter_img_2.png')} />
## Logging Proxy Input/Output - DataDog
## DataDog
LiteLLM supports logging to the following Datadog integrations:
- `datadog` [Datadog Logs](https://docs.datadoghq.com/logs/)
@ -1543,7 +1543,7 @@ Expected output on Datadog
<Image img={require('../../img/dd_small1.png')} />
## Logging Proxy Input/Output - DynamoDB
## DynamoDB
We will use the `--config` to set
@ -1669,7 +1669,7 @@ Your logs should be available on DynamoDB
}
```
## Logging Proxy Input/Output - Sentry
## Sentry
If api calls fail (llm/database) you can log those to Sentry:
@ -1711,7 +1711,7 @@ Test Request
litellm --test
```
## Logging Proxy Input/Output Athina
## Athina
[Athina](https://athina.ai/) allows you to log LLM Input/Output for monitoring, analytics, and observability.

View file

@ -20,6 +20,10 @@ general_settings:
proxy_batch_write_at: 60 # Batch write spend updates every 60s
database_connection_pool_limit: 10 # limit the number of database connections to = MAX Number of DB Connections/Number of instances of litellm proxy (Around 10-20 is good number)
# OPTIONAL Best Practices
disable_spend_logs: True # turn off writing each transaction to the db. We recommend doing this if you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
allow_requests_on_db_unavailable: True # Only USE when running LiteLLM on your VPC. Allow requests to still be processed even if the DB is unavailable. We recommend doing this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.
litellm_settings:
request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000 seconds if not set
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
@ -86,7 +90,29 @@ Set `export LITELLM_MODE="PRODUCTION"`
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
## 5. Set LiteLLM Salt Key
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
This will allow LiteLLM to continue processing requests even if the DB is unavailable, giving you more graceful handling of DB downtime.
**WARNING: Only do this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.**
```yaml
general_settings:
allow_requests_on_db_unavailable: True
```
## 6. Disable spend_logs if you're not using the LiteLLM UI
By default LiteLLM will write every request to the `LiteLLM_SpendLogs` table. This is used for viewing Usage on the LiteLLM UI.
If you're not viewing Usage on the LiteLLM UI (most users use Prometheus when this is disabled), you can disable spend_logs by setting `disable_spend_logs` to `True`.
```yaml
general_settings:
disable_spend_logs: True
```
## 7. Set LiteLLM Salt Key
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.

View file

@ -56,7 +56,7 @@ Possible values for `budget_duration`
| `budget_duration="1m"` | every 1 min |
| `budget_duration="1h"` | every 1 hour |
| `budget_duration="1d"` | every 1 day |
| `budget_duration="1mo"` | every 1 month |
| `budget_duration="30d"` | every 1 month |
### 2. Create a key for the `team`

View file

@ -30,11 +30,11 @@ This config would send langfuse logs to 2 different langfuse projects, based on
```yaml
litellm_settings:
default_team_settings:
- team_id: my-secret-project
- team_id: "dbe2f686-a686-4896-864a-4c3924458709"
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
- team_id: ishaans-secret-project
- team_id: "06ed1e01-3fa7-4b9e-95bc-f2e59b74f3a8"
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_2 # Project 2
langfuse_secret: os.environ/LANGFUSE_SECRET_2 # Project 2
@ -46,7 +46,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"team_id": "ishaans-secret-project"}'
-d '{"team_id": "06ed1e01-3fa7-4b9e-95bc-f2e59b74f3a8"}'
```
All requests made with these keys will log data to their team-specific logging. -->
@ -281,6 +281,51 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
}'
```
</TabItem>
<TabItem label="Langsmith" value="langsmith">
1. Create Virtual Key to log to a specific Langsmith Project
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"logging": [{
"callback_name": "langsmith", # "otel", "gcs_bucket"
"callback_type": "success", # "success", "failure", "success_and_failure"
"callback_vars": {
"langsmith_api_key": "os.environ/LANGSMITH_API_KEY", # API Key for Langsmith logging
"langsmith_project": "pr-brief-resemblance-72", # project name on langsmith
"langsmith_base_url": "https://api.smith.langchain.com"
}
}]
}
}'
```
2. Test it - `/chat/completions` request
Use the virtual key from the previous step to make a `/chat/completions` request
You should see your logs on your Langsmith project on a successful request
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-Fxq5XSyWKeXDKfPdqXZhPg" \
-d '{
"model": "fake-openai-endpoint",
"messages": [
{"role": "user", "content": "Hello, Claude"}
],
"user": "hello",
}'
```
</TabItem>
</Tabs>

View file

@ -205,6 +205,7 @@ const sidebars = {
"completion/prompt_caching",
"completion/audio",
"completion/vision",
"completion/predict_outputs",
"completion/prefix",
"completion/drop_params",
"completion/prompt_formatting",
@ -265,6 +266,7 @@ const sidebars = {
type: "category",
label: "Load Testing",
items: [
"benchmarks",
"load_test",
"load_test_advanced",
"load_test_sdk",

View file

@ -137,6 +137,8 @@ safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
### DEFAULT WATSONX API VERSION ###
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
### COHERE EMBEDDINGS DEFAULT TYPE ###
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
### GUARDRAILS ###
@ -282,7 +284,9 @@ priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
request_timeout: float = 6000 # time in seconds
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None # per model endpoint
max_fallbacks: Optional[int] = None
@ -371,6 +375,7 @@ open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
mistral_chat_models: List = []
text_completion_codestral_models: List = []
anthropic_models: List = []
empower_models: List = []
openrouter_models: List = []
@ -397,6 +402,19 @@ deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
gemini_models: List = []
xai_models: List = []
deepseek_models: List = []
azure_ai_models: List = []
voyage_models: List = []
databricks_models: List = []
cloudflare_models: List = []
codestral_models: List = []
friendliai_models: List = []
palm_models: List = []
groq_models: List = []
azure_models: List = []
anyscale_models: List = []
cerebras_models: List = []
def add_known_models():
@ -473,6 +491,34 @@ def add_known_models():
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
if "-to-" not in key:
fireworks_ai_embedding_models.append(key)
elif value.get("litellm_provider") == "text-completion-codestral":
text_completion_codestral_models.append(key)
elif value.get("litellm_provider") == "xai":
xai_models.append(key)
elif value.get("litellm_provider") == "deepseek":
deepseek_models.append(key)
elif value.get("litellm_provider") == "azure_ai":
azure_ai_models.append(key)
elif value.get("litellm_provider") == "voyage":
voyage_models.append(key)
elif value.get("litellm_provider") == "databricks":
databricks_models.append(key)
elif value.get("litellm_provider") == "cloudflare":
cloudflare_models.append(key)
elif value.get("litellm_provider") == "codestral":
codestral_models.append(key)
elif value.get("litellm_provider") == "friendliai":
friendliai_models.append(key)
elif value.get("litellm_provider") == "palm":
palm_models.append(key)
elif value.get("litellm_provider") == "groq":
groq_models.append(key)
elif value.get("litellm_provider") == "azure":
azure_models.append(key)
elif value.get("litellm_provider") == "anyscale":
anyscale_models.append(key)
elif value.get("litellm_provider") == "cerebras":
cerebras_models.append(key)
add_known_models()
@ -527,7 +573,11 @@ openai_text_completion_compatible_providers: List = (
"hosted_vllm",
]
)
_openai_like_providers: List = [
"predibase",
"databricks",
"watsonx",
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
# well supported replicate llms
replicate_models: List = [
# llama replicate supported LLMs
@ -714,6 +764,20 @@ model_list = (
+ vertex_language_models
+ watsonx_models
+ gemini_models
+ text_completion_codestral_models
+ xai_models
+ deepseek_models
+ azure_ai_models
+ voyage_models
+ databricks_models
+ cloudflare_models
+ codestral_models
+ friendliai_models
+ palm_models
+ groq_models
+ azure_models
+ anyscale_models
+ cerebras_models
)
@ -770,6 +834,7 @@ class LlmProviders(str, Enum):
FIREWORKS_AI = "fireworks_ai"
FRIENDLIAI = "friendliai"
WATSONX = "watsonx"
WATSONX_TEXT = "watsonx_text"
TRITON = "triton"
PREDIBASE = "predibase"
DATABRICKS = "databricks"
@ -786,6 +851,7 @@ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
models_by_provider: dict = {
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,
"text-completion-openai": open_ai_text_completion_models,
"cohere": cohere_models + cohere_chat_models,
"cohere_chat": cohere_chat_models,
"anthropic": anthropic_models,
@ -809,6 +875,23 @@ models_by_provider: dict = {
"watsonx": watsonx_models,
"gemini": gemini_models,
"fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
"aleph_alpha": aleph_alpha_models,
"text-completion-codestral": text_completion_codestral_models,
"xai": xai_models,
"deepseek": deepseek_models,
"mistral": mistral_chat_models,
"azure_ai": azure_ai_models,
"voyage": voyage_models,
"databricks": databricks_models,
"cloudflare": cloudflare_models,
"codestral": codestral_models,
"nlp_cloud": nlp_cloud_models,
"friendliai": friendliai_models,
"palm": palm_models,
"groq": groq_models,
"azure": azure_models,
"anyscale": anyscale_models,
"cerebras": cerebras_models,
}
# mapping for those models which have larger equivalents
@ -881,7 +964,6 @@ from .utils import (
supports_system_messages,
get_litellm_params,
acreate,
get_model_list,
get_max_tokens,
get_model_info,
register_prompt_template,
@ -976,10 +1058,11 @@ from .llms.bedrock.common_utils import (
AmazonAnthropicClaude3Config,
AmazonCohereConfig,
AmazonLlamaConfig,
AmazonStabilityConfig,
AmazonMistralConfig,
AmazonBedrockGlobalConfig,
)
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config,
@ -1037,10 +1120,12 @@ from .llms.AzureOpenAI.azure import (
from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx import IBMWatsonXAIConfig
from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .main import * # type: ignore
from .integrations import *
from .exceptions import (

View file

@ -241,7 +241,8 @@ class ServiceLogging(CustomLogger):
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
payload=payload,
error=error,
)
elif callback == "datadog":
await self.init_datadog_logger_if_none()

View file

@ -8,6 +8,7 @@ Has 4 methods:
- async_get_cache
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
@ -18,7 +19,7 @@ else:
Span = Any
class BaseCache:
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl
@ -37,6 +38,10 @@ class BaseCache:
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass
def get_cache(self, key, **kwargs):
raise NotImplementedError
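With `BaseCache` now an abstract base class, every cache backend has to provide `async_set_cache_pipeline`. A minimal hypothetical subclass might look like this (an illustration, not part of this diff):
```python
import asyncio

class DictCache(BaseCache):  # hypothetical in-memory backend, for illustration only
    def __init__(self):
        super().__init__()
        self._store = {}

    def set_cache(self, key, value, **kwargs):
        self._store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # bulk write, mirroring the gather-based implementations added below
        await asyncio.gather(
            *(self.async_set_cache(k, v, **kwargs) for k, v in cache_list)
        )
```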

View file

@ -233,19 +233,18 @@ class Cache:
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, *args, **kwargs) -> str:
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key = ""
verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
@ -521,7 +520,7 @@ class Cache:
return cached_response
return cached_result
def get_cache(self, *args, **kwargs):
def get_cache(self, **kwargs):
"""
Retrieves the cached result for the given arguments.
@ -533,13 +532,13 @@ class Cache:
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
@ -553,29 +552,28 @@ class Cache:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(self, *args, **kwargs):
async def async_get_cache(self, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(
cache_key, *args, **kwargs
)
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
@ -583,7 +581,7 @@ class Cache:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, *args, **kwargs):
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
@ -591,7 +589,7 @@ class Cache:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
@ -613,7 +611,7 @@ class Cache:
except Exception as e:
raise e
def add_cache(self, result, *args, **kwargs):
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
@ -625,41 +623,42 @@ class Cache:
None
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache(self, result, *args, **kwargs):
async def async_add_cache(self, result, **kwargs):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, *args, **kwargs)
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache_pipeline(self, result, *args, **kwargs):
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
# set default ttl if not set
@ -668,29 +667,27 @@ class Cache:
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
*args,
**kwargs,
)
cache_list.append((cache_key, cached_data))
async_set_cache_pipeline = getattr(
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
else:
tasks = []
for val in cache_list:
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# if async_set_cache_pipeline:
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# else:
# tasks = []
# for val in cache_list:
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
# await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
def should_use_cache(self, *args, **kwargs):
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
@ -708,10 +705,8 @@ class Cache:
return True
return False
async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
)
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):

View file

@ -137,7 +137,7 @@ class LLMCachingHandler:
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
verbose_logger.debug("Checking Cache")
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
@ -145,7 +145,7 @@ class LLMCachingHandler:
)
if cached_result is not None and not isinstance(cached_result, list):
print_verbose("Cache Hit!")
verbose_logger.debug("Cache Hit!")
cache_hit = True
end_time = datetime.datetime.now()
model, _, _, _ = litellm.get_llm_provider(
@ -215,6 +215,7 @@ class LLMCachingHandler:
final_embedding_cached_response=final_embedding_cached_response,
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
)
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
return CachingHandlerResponse(
cached_result=cached_result,
final_embedding_cached_response=final_embedding_cached_response,
@ -233,12 +234,19 @@ class LLMCachingHandler:
from litellm.utils import CustomStreamWrapper
args = args or ()
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
cached_result = litellm.cache.get_cache(**new_kwargs)
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
@ -475,14 +483,21 @@ class LLMCachingHandler:
if litellm.cache is None:
return None
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if call_type == CallTypes.aembedding.value and isinstance(
kwargs["input"], list
new_kwargs["input"], list
):
tasks = []
for idx, i in enumerate(kwargs["input"]):
for idx, i in enumerate(new_kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
**{**new_kwargs, "input": i}
)
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
cached_result = await asyncio.gather(*tasks)
@ -493,9 +508,9 @@ class LLMCachingHandler:
cached_result = None
else:
if litellm.cache._supports_async() is True:
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
cached_result = litellm.cache.get_cache(*args, **kwargs)
cached_result = litellm.cache.get_cache(**new_kwargs)
return cached_result
def _convert_cached_result_to_model_response(
@ -580,6 +595,7 @@ class LLMCachingHandler:
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif (
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
) and isinstance(cached_result, dict):
@ -603,6 +619,13 @@ class LLMCachingHandler:
response_type="audio_transcription",
hidden_params=hidden_params,
)
if (
hasattr(cached_result, "_hidden_params")
and cached_result._hidden_params is not None
and isinstance(cached_result._hidden_params, dict)
):
cached_result._hidden_params["cache_hit"] = True
return cached_result
def _convert_cached_stream_response(
@ -658,12 +681,19 @@ class LLMCachingHandler:
Raises:
None
"""
kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
original_function,
args,
)
)
if litellm.cache is None:
return
# [OPTIONAL] ADD TO CACHE
if self._should_store_result_in_cache(
original_function=original_function, kwargs=kwargs
original_function=original_function, kwargs=new_kwargs
):
if (
isinstance(result, litellm.ModelResponse)
@ -673,29 +703,29 @@ class LLMCachingHandler:
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(kwargs["input"], list)
and isinstance(new_kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
asyncio.create_task(
litellm.cache.async_add_cache_pipeline(result, **kwargs)
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
)
elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread(
target=litellm.cache.add_cache,
args=(result,),
kwargs=kwargs,
kwargs=new_kwargs,
).start()
else:
asyncio.create_task(
litellm.cache.async_add_cache(
result.model_dump_json(), **kwargs
result.model_dump_json(), **new_kwargs
)
)
else:
asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
def sync_set_cache(
self,
@ -706,16 +736,20 @@ class LLMCachingHandler:
"""
Sync internal method to add the result to the cache
"""
kwargs.update(
convert_args_to_kwargs(result, self.original_function, kwargs, args)
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
if litellm.cache is None:
return
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=kwargs
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **kwargs)
litellm.cache.add_cache(result, **new_kwargs)
return
@ -865,9 +899,7 @@ class LLMCachingHandler:
def convert_args_to_kwargs(
result: Any,
original_function: Callable,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
# Get the signature of the original function

View file

@ -24,7 +24,6 @@ class DiskCache(BaseCache):
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
print_verbose("DiskCache: set_cache")
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
@ -33,10 +32,10 @@ class DiskCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)

View file

@ -70,7 +70,7 @@ class DualCache(BaseCache):
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 5
or 10
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
@ -314,7 +314,8 @@ class DualCache(BaseCache):
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
async def async_batch_set_cache(
# async_batch_set_cache
async def async_set_cache_pipeline(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""

View file

@ -9,6 +9,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any
@ -422,3 +423,9 @@ class QdrantSemanticCache(BaseCache):
async def _collection_info(self):
return self.collection_info
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
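
The same pipeline shim is added for the Redis semantic and S3 caches below; a self-contained sketch of the gather-based fan-out it uses (names are illustrative):

import asyncio

async def async_set_cache(key, value, **kwargs):
    # stand-in for a single per-key write (e.g. one HTTP upsert)
    await asyncio.sleep(0)
    print("stored", key)

async def async_set_cache_pipeline(cache_list, **kwargs):
    # schedule one coroutine per (key, value) pair and await them concurrently
    tasks = [async_set_cache(k, v, **kwargs) for k, v in cache_list]
    await asyncio.gather(*tasks)

asyncio.run(async_set_cache_pipeline([("k1", "v1"), ("k2", "v2")]))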

View file

@ -404,7 +404,7 @@ class RedisCache(BaseCache):
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
return results
return None
except Exception as e:
## LOGGING ##
end_time = time.time()

View file

@ -9,6 +9,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any
@ -331,3 +332,9 @@ class RedisSemanticCache(BaseCache):
async def _index_info(self):
return await self.index.ainfo()
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

View file

@ -10,6 +10,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any, Optional
@ -153,3 +154,9 @@ class S3Cache(BaseCache):
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

View file

@ -28,6 +28,9 @@ from litellm.llms.azure_ai.cost_calculator import (
from litellm.llms.AzureOpenAI.cost_calculation import (
cost_per_token as azure_openai_cost_per_token,
)
from litellm.llms.bedrock.image.cost_calculator import (
cost_calculator as bedrock_image_cost_calculator,
)
from litellm.llms.cohere.cost_calculator import (
cost_per_query as cohere_rerank_cost_per_query,
)
@ -521,12 +524,13 @@ def completion_cost( # noqa: PLR0915
custom_llm_provider=None,
region_name=None, # used for bedrock pricing
### IMAGE GEN ###
size=None,
size: Optional[str] = None,
quality=None,
n=None, # number of images
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
custom_cost_per_second: Optional[float] = None,
optional_params: Optional[dict] = None,
) -> float:
"""
Calculate the cost of a given completion call for GPT-3.5-turbo, llama2, or any litellm-supported LLM.
@ -667,7 +671,17 @@ def completion_cost( # noqa: PLR0915
# https://cloud.google.com/vertex-ai/generative-ai/pricing
# Vertex Charges Flat $0.20 per image
return 0.020
elif custom_llm_provider == "bedrock":
if isinstance(completion_response, ImageResponse):
return bedrock_image_cost_calculator(
model=model,
size=size,
image_response=completion_response,
optional_params=optional_params,
)
raise TypeError(
"completion_response must be of type ImageResponse for bedrock image cost calculation"
)
if size is None:
size = "1024-x-1024" # openai default
# fix size to match naming convention
@ -677,9 +691,9 @@ def completion_cost( # noqa: PLR0915
image_gen_model_name_with_quality = image_gen_model_name
if quality is not None:
image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
size = size.split("-x-")
height = int(size[0]) # if it's 1024-x-1024 vs. 1024x1024
width = int(size[1])
size_parts = size.split("-x-")
height = int(size_parts[0]) # if it's 1024-x-1024 vs. 1024x1024
width = int(size_parts[1])
verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
verbose_logger.debug(
f"image_gen_model_name_with_quality: {image_gen_model_name_with_quality}"
@ -844,6 +858,7 @@ def response_cost_calculator(
model=model,
call_type=call_type,
custom_llm_provider=custom_llm_provider,
optional_params=optional_params,
)
else:
if custom_pricing is True: # override defaults if custom pricing is set

View file

@ -423,7 +423,7 @@ class SlackAlerting(CustomBatchLogger):
latency_cache_keys = [(key, 0) for key in latency_keys]
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
await self.internal_usage_cache.async_batch_set_cache(
await self.internal_usage_cache.async_set_cache_pipeline(
cache_list=combined_metrics_cache_keys
)

View file

@ -23,7 +23,7 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.utils import get_formatted_prompt
from litellm.utils import get_formatted_prompt, print_verbose
global_braintrust_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
@ -229,6 +229,9 @@ class BraintrustLogger(CustomLogger):
request_data["metrics"] = metrics
try:
print_verbose(
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
)
global_braintrust_sync_http_handler.post(
url=f"{self.api_base}/project_logs/{project_id}/insert",
json={"events": [request_data]},

View file

@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
self,
flush_lock: Optional[asyncio.Lock] = None,
batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
**kwargs,
) -> None:
"""
@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
"""
self.log_queue: List = []
self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
self.last_flush_time = time.time()
self.flush_lock = flush_lock

View file

@ -1,3 +1,4 @@
import asyncio
import json
import os
import uuid
@ -10,10 +11,12 @@ from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import (
StandardCallbackDynamicParams,
StandardLoggingMetadata,
@ -26,10 +29,9 @@ else:
VertexBase = Any
class GCSLoggingConfig(TypedDict):
bucket_name: str
vertex_instance: VertexBase
path_service_account: str
IAM_AUTH_KEY = "IAM_AUTH"
GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
class GCSBucketLogger(GCSBucketBase):
@ -38,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):
super().__init__(bucket_name=bucket_name)
self.vertex_instances: Dict[str, VertexBase] = {}
# Init Batch logging settings
self.log_queue: List[GCSLogQueueItem] = []
self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
self.flush_interval = int(
os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
)
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
super().__init__(
flush_lock=self.flush_lock,
batch_size=self.batch_size,
flush_interval=self.flush_interval,
)
if premium_user is not True:
raise ValueError(
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
@ -57,54 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
kwargs,
response_obj,
)
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if logging_payload is None:
raise ValueError("standard_logging_object not found in kwargs")
json_logged_payload = json.dumps(logging_payload, default=str)
# Get the current date
current_date = datetime.now().strftime("%Y-%m-%d")
# Modify the object_name to include the date-based folder
object_name = f"{current_date}/{response_obj['id']}"
try:
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
# Add to logging queue - this will be flushed periodically
self.log_queue.append(
GCSLogQueueItem(
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
)
except httpx.HTTPStatusError as e:
raise Exception(f"GCS Bucket logging error: {e.response.text}")
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
except Exception as e:
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
from litellm.proxy.proxy_server import premium_user
if premium_user is not True:
raise ValueError(
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
)
try:
verbose_logger.debug(
"GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
@ -112,51 +98,138 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if logging_payload is None:
raise ValueError("standard_logging_object not found in kwargs")
_litellm_params = kwargs.get("litellm_params") or {}
metadata = _litellm_params.get("metadata") or {}
json_logged_payload = json.dumps(logging_payload, default=str)
# Get the current date
current_date = datetime.now().strftime("%Y-%m-%d")
# Modify the object_name to include the date-based folder
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
if "gcs_log_id" in metadata:
object_name = metadata["gcs_log_id"]
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
# Add to logging queue - this will be flushed periodically
self.log_queue.append(
GCSLogQueueItem(
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
)
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
except Exception as e:
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
async def async_send_batch(self):
"""
Process queued logs in batch - sends logs to GCS Bucket
GCS Bucket does not have a batch endpoint for uploading logs.
Instead, we:
- collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
- during async_send_batch, make one POST request per log to the GCS bucket
"""
if not self.log_queue:
return
try:
for log_item in self.log_queue:
logging_payload = log_item["payload"]
kwargs = log_item["kwargs"]
response_obj = log_item.get("response_obj", None) or {}
gcs_logging_config: GCSLoggingConfig = (
await self.get_gcs_logging_config(kwargs)
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
object_name = self._get_object_name(
kwargs, logging_payload, response_obj
)
await self._log_json_data_on_gcs(
headers=headers,
bucket_name=bucket_name,
object_name=object_name,
logging_payload=logging_payload,
)
# Clear the queue after processing
self.log_queue.clear()
except Exception as e:
verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
def _get_object_name(
self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
) -> str:
"""
Get the object name to use for the current payload
"""
current_date = datetime.now().strftime("%Y-%m-%d")
if logging_payload.get("error_str", None) is not None:
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
else:
object_name = f"{current_date}/{response_obj.get('id', '')}"
# used for testing
_litellm_params = kwargs.get("litellm_params", None) or {}
_metadata = _litellm_params.get("metadata", None) or {}
if "gcs_log_id" in _metadata:
object_name = _metadata["gcs_log_id"]
return object_name
def _handle_folders_in_bucket_name(
self,
bucket_name: str,
object_name: str,
) -> Tuple[str, str]:
"""
Handles when the user passes a bucket name with a folder postfix
Example:
- Bucket name: "my-bucket/my-folder/dev"
- Object name: "my-object"
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
"""
if "/" in bucket_name:
bucket_name, prefix = bucket_name.split("/", 1)
object_name = f"{prefix}/{object_name}"
return bucket_name, object_name
return bucket_name, object_name
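
A standalone sketch of the split documented above (not the method itself, just the rule it applies):

from typing import Tuple

def split_bucket_and_object(bucket_name: str, object_name: str) -> Tuple[str, str]:
    # "my-bucket/my-folder/dev" + "my-object" -> ("my-bucket", "my-folder/dev/my-object")
    if "/" in bucket_name:
        bucket_name, prefix = bucket_name.split("/", 1)
        object_name = f"{prefix}/{object_name}"
    return bucket_name, object_name

print(split_bucket_and_object("my-bucket/my-folder/dev", "my-object"))
# ('my-bucket', 'my-folder/dev/my-object')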
async def _log_json_data_on_gcs(
self,
headers: Dict[str, str],
bucket_name: str,
object_name: str,
logging_payload: StandardLoggingPayload,
):
"""
Helper function to POST the JSON log payload to the specified GCS bucket.
"""
json_logged_payload = json.dumps(logging_payload, default=str)
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
async def get_gcs_logging_config(
self, kwargs: Optional[Dict[str, Any]] = {}
) -> GCSLoggingConfig:
@ -173,7 +246,7 @@ class GCSBucketLogger(GCSBucketBase):
)
bucket_name: str
path_service_account: str
path_service_account: Optional[str]
if standard_callback_dynamic_params is not None:
verbose_logger.debug("Using dynamic GCS logging")
verbose_logger.debug(
@ -193,10 +266,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
if _path_service_account is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = _bucket_name
path_service_account = _path_service_account
vertex_instance = await self.get_or_create_vertex_instance(
@ -208,10 +277,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
if self.path_service_account_json is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = self.BUCKET_NAME
path_service_account = self.path_service_account_json
vertex_instance = await self.get_or_create_vertex_instance(
@ -224,7 +289,9 @@ class GCSBucketLogger(GCSBucketBase):
path_service_account=path_service_account,
)
async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase:
async def get_or_create_vertex_instance(
self, credentials: Optional[str]
) -> VertexBase:
"""
This function is used to get the Vertex instance for the GCS Bucket Logger.
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
@ -233,15 +300,27 @@ class GCSBucketLogger(GCSBucketBase):
VertexBase,
)
if credentials not in self.vertex_instances:
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
if _in_memory_key not in self.vertex_instances:
vertex_instance = VertexBase()
await vertex_instance._ensure_access_token_async(
credentials=credentials,
project_id=None,
custom_llm_provider="vertex_ai",
)
self.vertex_instances[credentials] = vertex_instance
return self.vertex_instances[credentials]
self.vertex_instances[_in_memory_key] = vertex_instance
return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
"""
Returns key to use for caching the Vertex instance in-memory.
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
- If a credentials string is provided, it is used as the key.
- If no credentials string is provided, "IAM_AUTH" is used as the key.
"""
return credentials or IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs):
"""
@ -258,6 +337,11 @@ class GCSBucketLogger(GCSBucketBase):
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
# Send the GET request to download the object
@ -293,6 +377,11 @@ class GCSBucketLogger(GCSBucketBase):
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
# Send the DELETE request to delete the object

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
@ -21,8 +21,8 @@ else:
VertexBase = Any
class GCSBucketBase(CustomLogger):
def __init__(self, bucket_name: Optional[str] = None) -> None:
class GCSBucketBase(CustomBatchLogger):
def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
@ -30,10 +30,11 @@ class GCSBucketBase(CustomLogger):
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
self.path_service_account_json: Optional[str] = _path_service_account
self.BUCKET_NAME: Optional[str] = _bucket_name
super().__init__(**kwargs)
async def construct_request_headers(
self,
service_account_json: str,
service_account_json: Optional[str],
vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
from litellm import vertex_chat_completion

View file

@ -1,10 +1,11 @@
#### What this does ####
# On success, logs events to Langfuse
import copy
import inspect
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, Optional
import types
from collections.abc import MutableMapping, MutableSequence, MutableSet
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
from packaging.version import Version
from pydantic import BaseModel
@ -14,7 +15,7 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.langfuse import *
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
@ -355,6 +356,73 @@ class LangFuseLogger:
)
)
def is_base_type(self, value: Any) -> bool:
# Check if the value is of a base type
base_types = (int, float, str, bool, list, dict, tuple)
return isinstance(value, base_types)
def _prepare_metadata(self, metadata: Optional[dict]) -> Any:
try:
if metadata is None:
return None
# Filter out function types from the metadata
sanitized_metadata = {k: v for k, v in metadata.items() if not callable(v)}
return copy.deepcopy(sanitized_metadata)
except Exception as e:
verbose_logger.debug(f"Langfuse Layer Error - {e}, metadata: {metadata}")
new_metadata: Dict[str, Any] = {}
# if metadata is not a MutableMapping, return an empty dict since we can't call items() on it
if not isinstance(metadata, MutableMapping):
verbose_logger.debug(
"Langfuse Layer Logging - metadata is not a MutableMapping, returning empty dict"
)
return new_metadata
for key, value in metadata.items():
try:
if isinstance(value, MutableMapping):
new_metadata[key] = self._prepare_metadata(cast(dict, value))
elif isinstance(value, MutableSequence):
# For lists or other mutable sequences
new_metadata[key] = list(
(
self._prepare_metadata(cast(dict, v))
if isinstance(v, MutableMapping)
else copy.deepcopy(v)
)
for v in value
)
elif isinstance(value, MutableSet):
# For sets specifically, create a new set by passing an iterable
new_metadata[key] = set(
(
self._prepare_metadata(cast(dict, v))
if isinstance(v, MutableMapping)
else copy.deepcopy(v)
)
for v in value
)
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump()
elif self.is_base_type(value):
new_metadata[key] = value
else:
verbose_logger.debug(
f"Langfuse Layer Error - Unsupported metadata type: {type(value)} for key: {key}"
)
continue
except (TypeError, copy.Error):
verbose_logger.debug(
f"Langfuse Layer Error - Couldn't copy metadata key: {key}, type of key: {type(key)}, type of value: {type(value)} - {traceback.format_exc()}"
)
return new_metadata
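
A simplified, self-contained sketch of the sanitization rules above, assuming pydantic v2 (`model_dump`); the sample payload is invented:

import copy
from collections.abc import MutableMapping

from pydantic import BaseModel

class RequestInfo(BaseModel):
    route: str = "/chat/completions"

def prepare_metadata(metadata):
    cleaned = {}
    for key, value in metadata.items():
        if isinstance(value, MutableMapping):
            cleaned[key] = prepare_metadata(value)  # recurse into nested dicts
        elif isinstance(value, (list, tuple)):
            cleaned[key] = [copy.deepcopy(v) for v in value]  # rebuild sequences
        elif isinstance(value, set):
            cleaned[key] = {copy.deepcopy(v) for v in value}  # rebuild sets
        elif isinstance(value, BaseModel):
            cleaned[key] = value.model_dump()  # pydantic model -> plain dict
        elif isinstance(value, (int, float, str, bool)) or value is None:
            cleaned[key] = value
        # anything else (callables, locks, clients, ...) is dropped, not deep-copied
    return cleaned

print(prepare_metadata({"tags": ["prod"], "request": RequestInfo(), "callback": print}))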
def _log_langfuse_v2( # noqa: PLR0915
self,
user_id,
@ -373,40 +441,19 @@ class LangFuseLogger:
) -> tuple:
import langfuse
print_verbose("Langfuse Layer Logging - logging to langfuse v2")
try:
tags = []
try:
optional_params.pop("metadata")
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except Exception:
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, dict)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = copy.deepcopy(value)
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump()
metadata = new_metadata
metadata = self._prepare_metadata(metadata)
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_completion_start_time = Version(
langfuse.version.__version__
) >= Version("2.7.3")
langfuse_version = Version(langfuse.version.__version__)
print_verbose("Langfuse Layer Logging - logging to langfuse v2 ")
supports_tags = langfuse_version >= Version("2.6.3")
supports_prompt = langfuse_version >= Version("2.7.3")
supports_costs = langfuse_version >= Version("2.7.3")
supports_completion_start_time = langfuse_version >= Version("2.7.3")
if supports_tags:
metadata_tags = metadata.pop("tags", [])
tags = metadata_tags
tags = metadata.pop("tags", []) if supports_tags else []
# Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references, which lead to infinite recursion

View file

@ -23,34 +23,8 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None
# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None
class LangsmithCredentialsObject(TypedDict):
LANGSMITH_API_KEY: str
LANGSMITH_PROJECT: str
LANGSMITH_BASE_URL: str
from litellm.types.integrations.langsmith import *
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
def is_serializable(value):
@ -93,15 +67,16 @@ class LangsmithLogger(CustomBatchLogger):
)
if _batch_size:
self.batch_size = int(_batch_size)
self.log_queue: List[LangsmithQueueObject] = []
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
def get_credentials_from_env(
self,
langsmith_api_key: Optional[str],
langsmith_project: Optional[str],
langsmith_base_url: Optional[str],
langsmith_api_key: Optional[str] = None,
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
@ -132,42 +107,19 @@ class LangsmithLogger(CustomBatchLogger):
LANGSMITH_PROJECT=_credentials_project,
)
def _prepare_log_data( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
def _prepare_log_data(
self,
kwargs,
response_obj,
start_time,
end_time,
credentials: LangsmithCredentialsObject,
):
import json
from datetime import datetime as dt
try:
_litellm_params = kwargs.get("litellm_params", {}) or {}
metadata = _litellm_params.get("metadata", {}) or {}
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
metadata = new_metadata
kwargs["user_api_key"] = metadata.get("user_api_key", None)
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
kwargs["user_api_key_team_alias"] = metadata.get(
"user_api_key_team_alias", None
)
project_name = metadata.get(
"project_name", self.default_credentials["LANGSMITH_PROJECT"]
"project_name", credentials["LANGSMITH_PROJECT"]
)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
@ -175,16 +127,10 @@ class LangsmithLogger(CustomBatchLogger):
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
# filter out kwargs to not include any dicts, langsmith throws an error when trying to log kwargs
# logged_kwargs = LangsmithInputs(**kwargs)
# kwargs = logged_kwargs.model_dump()
# new_kwargs = {}
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
@ -193,7 +139,6 @@ class LangsmithLogger(CustomBatchLogger):
if payload is None:
raise Exception("Error logging request payload. Payload=none.")
new_kwargs = payload
metadata = payload[
"metadata"
] # ensure logged metadata is json serializable
@ -201,12 +146,12 @@ class LangsmithLogger(CustomBatchLogger):
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": new_kwargs["response"],
"inputs": payload,
"outputs": payload["response"],
"session_name": project_name,
"start_time": new_kwargs["startTime"],
"end_time": new_kwargs["endTime"],
"tags": tags,
"start_time": payload["startTime"],
"end_time": payload["endTime"],
"tags": payload["request_tags"],
"extra": metadata,
}
@ -243,37 +188,6 @@ class LangsmithLogger(CustomBatchLogger):
except Exception:
raise
def _send_batch(self):
if not self.log_queue:
return
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/batch"
headers = {"x-api-key": langsmith_api_key}
try:
response = requests.post(
url=url,
json=self.log_queue,
headers=headers,
)
if response.status_code >= 300:
verbose_logger.error(
f"Langsmith Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
f"Batch of {len(self.log_queue)} runs successfully created"
)
self.log_queue.clear()
except Exception:
verbose_logger.exception("Langsmith Layer Error - Error sending batch.")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
@ -295,8 +209,20 @@ class LangsmithLogger(CustomBatchLogger):
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
)
@ -323,8 +249,20 @@ class LangsmithLogger(CustomBatchLogger):
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
@ -349,8 +287,20 @@ class LangsmithLogger(CustomBatchLogger):
return # Skip logging
verbose_logger.info("Langsmith Failure Event Logging!")
try:
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
@ -365,31 +315,58 @@ class LangsmithLogger(CustomBatchLogger):
async def async_send_batch(self):
"""
sends runs to /batch endpoint
Handles sending batches of runs to Langsmith
Sends runs from self.log_queue
self.log_queue contains LangsmithQueueObjects
Each LangsmithQueueObject has the following:
- "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url)
- "data" - data to log on to langsmith for the request
This function
- groups the queue objects by credentials
- loops through each unique credentials and sends batches to Langsmith
This was added to support key/team based logging on langsmith
"""
if not self.log_queue:
return
batch_groups = self._group_batches_by_credentials()
for batch_group in batch_groups.values():
await self._log_batch_on_langsmith(
credentials=batch_group.credentials,
queue_objects=batch_group.queue_objects,
)
async def _log_batch_on_langsmith(
self,
credentials: LangsmithCredentialsObject,
queue_objects: List[LangsmithQueueObject],
):
"""
Logs a batch of runs to Langsmith
sends runs to /batch endpoint for the given credentials
Args:
credentials: LangsmithCredentialsObject
queue_objects: List[LangsmithQueueObject]
Returns: None
Raises: Nothing; errors are only logged via verbose_logger.exception()
"""
if not self.log_queue:
return
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
langsmith_api_base = credentials["LANGSMITH_BASE_URL"]
langsmith_api_key = credentials["LANGSMITH_API_KEY"]
url = f"{langsmith_api_base}/runs/batch"
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
headers = {"x-api-key": langsmith_api_key}
elements_to_log = [queue_object["data"] for queue_object in queue_objects]
try:
response = await self.async_httpx_client.post(
url=url,
json={
"post": self.log_queue,
},
json={"post": elements_to_log},
headers=headers,
)
response.raise_for_status()
@ -411,6 +388,74 @@ class LangsmithLogger(CustomBatchLogger):
f"Langsmith Layer Error - {traceback.format_exc()}"
)
def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]:
"""Groups queue objects by credentials using a proper key structure"""
log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {}
for queue_object in self.log_queue:
credentials = queue_object["credentials"]
key = CredentialsKey(
api_key=credentials["LANGSMITH_API_KEY"],
project=credentials["LANGSMITH_PROJECT"],
base_url=credentials["LANGSMITH_BASE_URL"],
)
if key not in log_queue_by_credentials:
log_queue_by_credentials[key] = BatchGroup(
credentials=credentials, queue_objects=[]
)
log_queue_by_credentials[key].queue_objects.append(queue_object)
return log_queue_by_credentials
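
A self-contained sketch of the grouping step; `CredentialsKey` and `BatchGroup` here are local stand-ins for the types referenced above, and the queue contents are invented:

from typing import Any, Dict, List, NamedTuple

class CredentialsKey(NamedTuple):
    api_key: str
    project: str
    base_url: str

class BatchGroup(NamedTuple):
    credentials: Dict[str, str]
    queue_objects: List[Dict[str, Any]]

def group_by_credentials(log_queue):
    groups: Dict[CredentialsKey, BatchGroup] = {}
    for item in log_queue:
        creds = item["credentials"]
        key = CredentialsKey(
            creds["LANGSMITH_API_KEY"],
            creds["LANGSMITH_PROJECT"],
            creds["LANGSMITH_BASE_URL"],
        )
        if key not in groups:
            groups[key] = BatchGroup(credentials=creds, queue_objects=[])
        groups[key].queue_objects.append(item)
    return groups

queue = [
    {"credentials": {"LANGSMITH_API_KEY": "k1", "LANGSMITH_PROJECT": "p1", "LANGSMITH_BASE_URL": "https://api.smith.langchain.com"}, "data": {"name": "run-a"}},
    {"credentials": {"LANGSMITH_API_KEY": "k1", "LANGSMITH_PROJECT": "p1", "LANGSMITH_BASE_URL": "https://api.smith.langchain.com"}, "data": {"name": "run-b"}},
]
print({key: len(group.queue_objects) for key, group in group_by_credentials(queue).items()})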
def _get_credentials_to_use_for_request(
self, kwargs: Dict[str, Any]
) -> LangsmithCredentialsObject:
"""
Handles key/team based logging
If standard_callback_dynamic_params are provided, use those credentials.
Otherwise, use the default credentials.
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
if standard_callback_dynamic_params is not None:
credentials = self.get_credentials_from_env(
langsmith_api_key=standard_callback_dynamic_params.get(
"langsmith_api_key", None
),
langsmith_project=standard_callback_dynamic_params.get(
"langsmith_project", None
),
langsmith_base_url=standard_callback_dynamic_params.get(
"langsmith_base_url", None
),
)
else:
credentials = self.default_credentials
return credentials
def _send_batch(self):
"""Calls async_send_batch in an event loop"""
if not self.log_queue:
return
try:
# Try to get the existing event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If we're already in an event loop, create a task
asyncio.create_task(self.async_send_batch())
else:
# If no event loop is running, run the coroutine directly
loop.run_until_complete(self.async_send_batch())
except RuntimeError:
# If we can't get an event loop, create a new one
asyncio.run(self.async_send_batch())
def get_run_by_id(self, run_id):
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]

View file

@ -2,20 +2,23 @@ import os
from dataclasses import dataclass
from datetime import datetime
from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
ChatCompletionMessageToolCall,
EmbeddingResponse,
Function,
ImageResponse,
ModelResponse,
StandardLoggingPayload,
)
if TYPE_CHECKING:
from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
from opentelemetry.trace import Span as _Span
from litellm.proxy._types import (
@ -24,10 +27,12 @@ if TYPE_CHECKING:
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
SpanExporter = _SpanExporter
UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
else:
Span = Any
SpanExporter = Any
UserAPIKeyAuth = Any
ManagementEndpointLoggingPayload = Any
@ -44,7 +49,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
@dataclass
class OpenTelemetryConfig:
from opentelemetry.sdk.trace.export import SpanExporter
exporter: Union[str, SpanExporter] = "console"
endpoint: Optional[str] = None
@ -77,7 +81,7 @@ class OpenTelemetryConfig:
class OpenTelemetry(CustomLogger):
def __init__(
self,
config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
config: Optional[OpenTelemetryConfig] = None,
callback_name: Optional[str] = None,
**kwargs,
):
@ -85,6 +89,9 @@ class OpenTelemetry(CustomLogger):
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
if config is None:
config = OpenTelemetryConfig.from_env()
self.config = config
self.OTEL_EXPORTER = self.config.exporter
self.OTEL_ENDPOINT = self.config.endpoint
@ -281,21 +288,6 @@ class OpenTelemetry(CustomLogger):
# End Parent OTEL Span
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None:
parent_otel_span.set_status(Status(StatusCode.OK))
# End Parent OTEL Span
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -334,8 +326,8 @@ class OpenTelemetry(CustomLogger):
span.end(end_time=self._to_ns(end_time))
# if parent_otel_span is not None:
# parent_otel_span.end(end_time=self._to_ns(datetime.now()))
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -413,6 +405,28 @@ class OpenTelemetry(CustomLogger):
except Exception:
return ""
@staticmethod
def _tool_calls_kv_pair(
tool_calls: List[ChatCompletionMessageToolCall],
) -> Dict[str, Any]:
from litellm.proxy._types import SpanAttributes
kv_pairs: Dict[str, Any] = {}
for idx, tool_call in enumerate(tool_calls):
_function = tool_call.get("function")
if not _function:
continue
keys = Function.__annotations__.keys()
for key in keys:
_value = _function.get(key)
if _value:
kv_pairs[
f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.{key}"
] = _value
return kv_pairs
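
A standalone sketch of the flattening performed above; the `gen_ai.completion` prefix is only a placeholder for `SpanAttributes.LLM_COMPLETIONS`, whose actual value may differ:

from typing import Any, Dict, List

PREFIX = "gen_ai.completion"  # placeholder for SpanAttributes.LLM_COMPLETIONS

def tool_calls_kv_pair(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
    kv_pairs: Dict[str, Any] = {}
    for idx, tool_call in enumerate(tool_calls):
        function = tool_call.get("function") or {}
        for key in ("name", "arguments"):  # the Function fields of interest
            value = function.get(key)
            if value:
                kv_pairs[f"{PREFIX}.{idx}.function_call.{key}"] = value
    return kv_pairs

print(tool_calls_kv_pair([
    {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
    {"function": {"name": "get_time", "arguments": '{"tz": "UTC"}'}},
]))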
def set_attributes( # noqa: PLR0915
self, span: Span, kwargs, response_obj: Optional[Any]
):
@ -607,18 +621,13 @@ class OpenTelemetry(CustomLogger):
message = choice.get("message")
tool_calls = message.get("tool_calls")
if tool_calls:
self.safe_set_attribute(
span=span,
key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.name",
value=tool_calls[0].get("function").get("name"),
)
self.safe_set_attribute(
span=span,
key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.arguments",
value=tool_calls[0]
.get("function")
.get("arguments"),
)
kv_pairs = OpenTelemetry._tool_calls_kv_pair(tool_calls) # type: ignore
for key, value in kv_pairs.items():
self.safe_set_attribute(
span=span,
key=key,
value=value,
)
except Exception as e:
verbose_logger.exception(
@ -715,10 +724,10 @@ class OpenTelemetry(CustomLogger):
TraceContextTextMapPropagator,
)
verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
propagator = TraceContextTextMapPropagator()
_parent_context = propagator.extract(carrier={"traceparent": _traceparent})
verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
carrier = {"traceparent": _traceparent}
_parent_context = propagator.extract(carrier=carrier)
return _parent_context
def _get_span_context(self, kwargs):

View file

@ -9,6 +9,7 @@ import subprocess
import sys
import traceback
import uuid
from typing import List, Optional, Union
import dotenv
import requests # type: ignore
@ -51,7 +52,9 @@ class PrometheusServicesLogger:
for service in self.services:
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service, type_of_request="failed_requests"
service,
type_of_request="failed_requests",
additional_labels=["error_class", "function_name"],
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
@ -99,7 +102,12 @@ class PrometheusServicesLogger:
buckets=LATENCY_BUCKETS,
)
def create_counter(self, service: str, type_of_request: str):
def create_counter(
self,
service: str,
type_of_request: str,
additional_labels: Optional[List[str]] = None,
):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
@ -107,7 +115,7 @@ class PrometheusServicesLogger:
return self.Counter(
metric_name,
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
labelnames=[service] + (additional_labels or []),
)
def observe_histogram(
@ -125,10 +133,14 @@ class PrometheusServicesLogger:
counter,
labels: str,
amount: float,
additional_labels: Optional[List[str]] = [],
):
assert isinstance(counter, self.Counter)
counter.labels(labels).inc(amount)
if additional_labels:
counter.labels(labels, *additional_labels).inc(amount)
else:
counter.labels(labels).inc(amount)
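
A minimal sketch of the added error_class/function_name labels using prometheus_client directly (metric and label values are invented):

from prometheus_client import Counter

failed_requests = Counter(
    "litellm_postgres_failed_requests",
    "Total failed_requests for postgres service",
    labelnames=["postgres", "error_class", "function_name"],
)

# base service label plus the two additional labels introduced in this change
failed_requests.labels("postgres", "OperationalError", "get_data").inc(1)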
def service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
@ -187,16 +199,25 @@ class PrometheusServicesLogger:
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
async def async_service_failure_hook(
self,
payload: ServiceLoggerPayload,
error: Union[str, Exception],
):
if self.mock_testing:
self.mock_testing_failure_calls += 1
error_class = error.__class__.__name__
function_name = payload.call_type
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
# increment both failed and total requests
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
# log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
additional_labels=[error_class, function_name],
amount=1, # LOG ERROR COUNT TO PROMETHEUS
)

View file

@ -0,0 +1,11 @@
## Folder Contents
This folder contains general-purpose utilities that are used in multiple places in the codebase.
Core files:
- `streaming_handler.py`: The core streaming logic + streaming related helper utils
- `core_helpers.py`: code used in `types/` - e.g. `map_finish_reason`.
- `exception_mapping_utils.py`: utils for mapping exceptions to openai-compatible error types.
- `default_encoding.py`: code for loading the default encoding (tiktoken)
- `get_llm_provider_logic.py`: code for inferring the LLM provider from a given model name.

View file

@ -3,6 +3,8 @@
import os
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
if TYPE_CHECKING:
@ -99,3 +101,28 @@ def _get_parent_otel_span_from_kwargs(
"Error in _get_parent_otel_span_from_kwargs: " + str(e)
)
return None
def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> dict:
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
openai_headers = {}
processed_headers = {}
additional_headers = {}
for k, v in response_headers.items():
if k in OPENAI_RESPONSE_HEADERS: # return openai-compatible headers
openai_headers[k] = v
if k.startswith(
"llm_provider-"
): # return raw provider headers (incl. openai-compatible ones)
processed_headers[k] = v
else:
additional_headers["{}-{}".format("llm_provider", k)] = v
additional_headers = {
**openai_headers,
**processed_headers,
**additional_headers,
}
return additional_headers
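
A self-contained illustration of the header routing above; the `OPENAI_RESPONSE_HEADERS` subset and header values are assumed:

OPENAI_RESPONSE_HEADERS = ["x-ratelimit-remaining-requests", "x-ratelimit-remaining-tokens"]

def route_response_headers(response_headers):
    openai_headers, processed_headers, additional_headers = {}, {}, {}
    for k, v in response_headers.items():
        if k in OPENAI_RESPONSE_HEADERS:      # pass OpenAI-compatible headers through as-is
            openai_headers[k] = v
        if k.startswith("llm_provider-"):     # already-prefixed raw provider headers
            processed_headers[k] = v
        else:                                 # everything else gets the llm_provider- prefix
            additional_headers[f"llm_provider-{k}"] = v
    return {**openai_headers, **processed_headers, **additional_headers}

print(route_response_headers({
    "x-ratelimit-remaining-requests": "99",
    "server": "uvicorn",
}))
# {'x-ratelimit-remaining-requests': '99', 'llm_provider-x-ratelimit-remaining-requests': '99', 'llm_provider-server': 'uvicorn'}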

View file

@ -0,0 +1,21 @@
import os
import litellm
try:
# New and recommended way to access resources
from importlib import resources
filename = str(resources.files(litellm).joinpath("llms/tokenizers"))
except (ImportError, AttributeError):
# Old way to access resources, which setuptools deprecated some time ago
import pkg_resources # type: ignore
filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename
) # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
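
A quick standalone check of the encoding configured above (assumes tiktoken is installed):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode("hello world")
print(len(tokens), enc.decode(tokens))  # number of tokens, then the round-tripped text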

View file

@ -612,19 +612,7 @@ def exception_type( # type: ignore # noqa: PLR0915
url="https://api.replicate.com/v1/deployments",
),
)
elif custom_llm_provider == "watsonx":
if "token_quota_reached" in error_str:
exception_mapping_worked = True
raise RateLimitError(
message=f"WatsonxException: Rate Limit Errror - {error_str}",
llm_provider="watsonx",
model=model,
response=original_exception.response,
)
elif (
custom_llm_provider == "predibase"
or custom_llm_provider == "databricks"
):
elif custom_llm_provider in litellm._openai_like_providers:
if "authorization denied for" in error_str:
exception_mapping_worked = True
@ -646,6 +634,24 @@ def exception_type( # type: ignore # noqa: PLR0915
response=original_exception.response,
litellm_debug_info=extra_information,
)
elif "token_quota_reached" in error_str:
exception_mapping_worked = True
raise RateLimitError(
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
)
elif (
"The server received an invalid response from an upstream server."
in error_str
):
exception_mapping_worked = True
raise litellm.InternalServerError(
message=f"{custom_llm_provider}Exception - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
exception_mapping_worked = True

View file

@ -226,7 +226,7 @@ def get_llm_provider( # noqa: PLR0915
## openrouter
elif model in litellm.openrouter_models:
custom_llm_provider = "openrouter"
## openrouter
## maritalk
elif model in litellm.maritalk_models:
custom_llm_provider = "maritalk"
## vertex - text + chat + language (gemini) models

View file

@ -0,0 +1,288 @@
from typing import Literal, Optional
import litellm
from litellm.exceptions import BadRequestError
def get_supported_openai_params( # noqa: PLR0915
model: str,
custom_llm_provider: Optional[str] = None,
request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
) -> Optional[list]:
"""
Returns the supported openai params for a given model + provider
Example:
```
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
```
Returns:
- List if custom_llm_provider is mapped
- None if unmapped
"""
if not custom_llm_provider:
try:
custom_llm_provider = litellm.get_llm_provider(model=model)[1]
except BadRequestError:
return None
if custom_llm_provider == "bedrock":
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat":
return litellm.OllamaChatConfig().get_supported_openai_params()
elif custom_llm_provider == "anthropic":
return litellm.AnthropicConfig().get_supported_openai_params()
elif custom_llm_provider == "fireworks_ai":
if request_type == "embeddings":
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
model=model
)
else:
return litellm.FireworksAIConfig().get_supported_openai_params()
elif custom_llm_provider == "nvidia_nim":
if request_type == "chat_completion":
return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
elif request_type == "embeddings":
return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
elif custom_llm_provider == "cerebras":
return litellm.CerebrasConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "xai":
return litellm.XAIChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ai21_chat":
return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "volcengine":
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "groq":
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "hosted_vllm":
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepseek":
return [
# https://platform.deepseek.com/api-docs/api/create-chat-completion
"frequency_penalty",
"max_tokens",
"presence_penalty",
"response_format",
"stop",
"stream",
"temperature",
"top_p",
"logprobs",
"top_logprobs",
"tools",
"tool_choice",
]
elif custom_llm_provider == "cohere":
return [
"stream",
"temperature",
"max_tokens",
"logit_bias",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"extra_headers",
]
elif custom_llm_provider == "cohere_chat":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
elif custom_llm_provider == "maritalk":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"presence_penalty",
"stop",
]
elif custom_llm_provider == "openai":
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "azure":
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
model=model
)
else:
return litellm.AzureOpenAIConfig().get_supported_openai_params()
elif custom_llm_provider == "openrouter":
return [
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
"seed",
"max_tokens",
"logit_bias",
"logprobs",
"top_logprobs",
"response_format",
"stop",
"tools",
"tool_choice",
]
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistral and codestral APIs have the exact same params
if request_type == "chat_completion":
return litellm.MistralConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "text-completion-codestral":
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
elif custom_llm_provider == "replicate":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"seed",
"tools",
"tool_choice",
"functions",
"function_call",
]
elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"tools",
"tool_choice",
"response_format",
]
elif custom_llm_provider == "ai21":
return [
"stream",
"n",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"presence_penalty",
]
elif custom_llm_provider == "databricks":
if request_type == "chat_completion":
return litellm.DatabricksConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
elif custom_llm_provider == "vertex_ai":
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"):
return litellm.MistralConfig().get_supported_openai_params()
if model.startswith("codestral"):
return (
litellm.MistralTextCompletionConfig().get_supported_openai_params()
)
if model.startswith("claude"):
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "vertex_ai_beta":
if request_type == "chat_completion":
return litellm.VertexGeminiConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha":
return [
"max_tokens",
"stream",
"top_p",
"temperature",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
elif custom_llm_provider == "cloudflare":
return ["max_tokens", "stream"]
elif custom_llm_provider == "nlp_cloud":
return [
"max_tokens",
"stream",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
elif custom_llm_provider == "petals":
return ["max_tokens", "temperature", "top_p", "stream"]
elif custom_llm_provider == "deepinfra":
return litellm.DeepInfraConfig().get_supported_openai_params()
elif custom_llm_provider == "perplexity":
return [
"temperature",
"top_p",
"stream",
"max_tokens",
"presence_penalty",
"frequency_penalty",
]
elif custom_llm_provider == "anyscale":
return [
"temperature",
"top_p",
"stream",
"max_tokens",
"stop",
"frequency_penalty",
"presence_penalty",
]
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "custom_openai" or "text-completion-openai":
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
"extra_headers",
]
return None
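
Example call mirroring the docstring above; assumes `get_supported_openai_params` is in scope (imported from the module defined here) and litellm is installed:

# Prints the provider-mapped list of OpenAI params, or None if the provider is unmapped.
params = get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
print(params)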

View file

@ -2474,6 +2474,14 @@ class StandardLoggingPayloadSetup:
) -> Tuple[float, float, float]:
"""
Convert datetime objects to floats
Args:
start_time: Union[dt_object, float]
end_time: Union[dt_object, float]
completion_start_time: Union[dt_object, float]
Returns:
Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats.
"""
if isinstance(start_time, datetime.datetime):
@ -2534,13 +2542,10 @@ class StandardLoggingPayloadSetup:
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
clean_metadata = StandardLoggingMetadata(
**{ # type: ignore
key: metadata[key]
for key in StandardLoggingMetadata.__annotations__.keys()
if key in metadata
}
)
supported_keys = StandardLoggingMetadata.__annotations__.keys()
for key in supported_keys:
if key in metadata:
clean_metadata[key] = metadata[key] # type: ignore
if metadata.get("user_api_key") is not None:
if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
@ -2769,11 +2774,6 @@ def get_standard_logging_object_payload(
metadata=metadata
)
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = None
saved_cache_cost: float = 0.0
if cache_hit is True:
@ -2815,7 +2815,7 @@ def get_standard_logging_object_payload(
completionStartTime=completion_start_time_float,
model=kwargs.get("model", "") or "",
metadata=clean_metadata,
cache_key=cache_key,
cache_key=clean_hidden_params["cache_key"],
response_cost=response_cost,
total_tokens=usage.total_tokens,
prompt_tokens=usage.prompt_tokens,

View file

@ -14,11 +14,17 @@ from litellm.types.utils import (
Delta,
EmbeddingResponse,
Function,
HiddenParams,
ImageResponse,
)
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.types.utils import (
Message,
ModelResponse,
RerankResponse,
StreamingChoices,
TextChoices,
TextCompletionResponse,
TranscriptionResponse,
Usage,
)
@ -235,6 +241,77 @@ class LiteLLMResponseObjectHandler:
model_response_object = ImageResponse(**model_response_dict)
return model_response_object
@staticmethod
def convert_chat_to_text_completion(
response: ModelResponse,
text_completion_response: TextCompletionResponse,
custom_llm_provider: Optional[str] = None,
) -> TextCompletionResponse:
"""
Converts a chat completion response to a text completion response format.
Note: This is used for huggingface. For OpenAI / Azure text completion, the provider files directly return a TextCompletionResponse, which is then sent to the user.
Args:
response (ModelResponse): The chat completion response to convert
Returns:
TextCompletionResponse: The converted text completion response
Example:
chat_response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}])
text_response = convert_chat_to_text_completion(chat_response)
"""
transformed_logprobs = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
response=response,
custom_llm_provider=custom_llm_provider,
)
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
text_completion_response["model"] = response.get("model", None)
choices_list: List[TextChoices] = []
# Convert each choice to TextChoices
for choice in response["choices"]:
text_choices = TextChoices()
text_choices["text"] = choice["message"]["content"]
text_choices["index"] = choice["index"]
text_choices["logprobs"] = transformed_logprobs
text_choices["finish_reason"] = choice["finish_reason"]
choices_list.append(text_choices)
text_completion_response["choices"] = choices_list
text_completion_response["usage"] = response.get("usage", None)
text_completion_response._hidden_params = HiddenParams(
**response._hidden_params
)
return text_completion_response
@staticmethod
def _convert_provider_response_logprobs_to_text_completion_logprobs(
response: ModelResponse,
custom_llm_provider: Optional[str] = None,
) -> Optional[TextCompletionLogprobs]:
"""
Convert logprobs from provider to OpenAI.Completion() format
Only supported for HF TGI models
"""
transformed_logprobs: Optional[TextCompletionLogprobs] = None
if custom_llm_provider == "huggingface":
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.huggingface._transform_logprobs(
hf_response=raw_response
)
except Exception as e:
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
return transformed_logprobs
def convert_to_model_response_object( # noqa: PLR0915
response_object: Optional[dict] = None,

View file

@ -0,0 +1,50 @@
from typing import Optional
import litellm
class Rules:
"""
Fail calls based on the input or the LLM API output
Example usage:
import litellm
def my_custom_rule(input): # receives the model response
if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer
return False
return True
litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call
response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user",
"content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
"""
def __init__(self) -> None:
pass
def pre_call_rules(self, input: str, model: str):
for rule in litellm.pre_call_rules:
if callable(rule):
decision = rule(input)
if decision is False:
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
return True
def post_call_rules(self, input: Optional[str], model: str) -> bool:
if input is None:
return True
for rule in litellm.post_call_rules:
if callable(rule):
decision = rule(input)
if isinstance(decision, bool):
if decision is False:
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
elif isinstance(decision, dict):
decision_val = decision.get("decision", True)
decision_message = decision.get(
"message", "LLM Response failed post-call-rule check"
)
if decision_val is False:
raise litellm.APIResponseValidationError(message=decision_message, llm_provider="", model=model) # type: ignore
return True
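A minimal sketch of the dict-style decision handled above; the rule function and message are illustrative:
import litellm
def block_refusals(output: str):
    # returning a dict lets the rule attach a custom failure message
    if "i can't help with that" in output.lower():
        return {"decision": False, "message": "Model refused - failing the call"}
    return {"decision": True}
litellm.post_call_rules = [block_refusals]
# litellm runs Rules().post_call_rules(...) after each completion; a False decision
# raises litellm.APIResponseValidationError with the message above.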

View file

@ -243,6 +243,49 @@ class ChunkProcessor:
id=id,
)
def _usage_chunk_calculation_helper(self, usage_chunk: Usage) -> dict:
prompt_tokens = 0
completion_tokens = 0
## anthropic prompt caching information ##
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
if "cache_creation_input_tokens" in usage_chunk:
cache_creation_input_tokens = usage_chunk.get("cache_creation_input_tokens")
if "cache_read_input_tokens" in usage_chunk:
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
if hasattr(usage_chunk, "completion_tokens_details"):
if isinstance(usage_chunk.completion_tokens_details, dict):
completion_tokens_details = CompletionTokensDetails(
**usage_chunk.completion_tokens_details
)
elif isinstance(
usage_chunk.completion_tokens_details, CompletionTokensDetails
):
completion_tokens_details = usage_chunk.completion_tokens_details
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
**usage_chunk.prompt_tokens_details
)
elif isinstance(usage_chunk.prompt_tokens_details, PromptTokensDetails):
prompt_tokens_details = usage_chunk.prompt_tokens_details
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
"completion_tokens_details": completion_tokens_details,
"prompt_tokens_details": prompt_tokens_details,
}
def calculate_usage(
self,
chunks: List[Union[Dict[str, Any], ModelResponse]],
@ -269,37 +312,30 @@ class ChunkProcessor:
elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
usage_chunk = chunk._hidden_params.get("usage", None)
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
if "cache_creation_input_tokens" in usage_chunk:
cache_creation_input_tokens = usage_chunk.get(
usage_chunk_dict = self._usage_chunk_calculation_helper(usage_chunk)
if (
usage_chunk_dict["prompt_tokens"] is not None
and usage_chunk_dict["prompt_tokens"] > 0
):
prompt_tokens = usage_chunk_dict["prompt_tokens"]
if (
usage_chunk_dict["completion_tokens"] is not None
and usage_chunk_dict["completion_tokens"] > 0
):
completion_tokens = usage_chunk_dict["completion_tokens"]
if usage_chunk_dict["cache_creation_input_tokens"] is not None:
cache_creation_input_tokens = usage_chunk_dict[
"cache_creation_input_tokens"
)
if "cache_read_input_tokens" in usage_chunk:
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
if hasattr(usage_chunk, "completion_tokens_details"):
if isinstance(usage_chunk.completion_tokens_details, dict):
completion_tokens_details = CompletionTokensDetails(
**usage_chunk.completion_tokens_details
)
elif isinstance(
usage_chunk.completion_tokens_details, CompletionTokensDetails
):
completion_tokens_details = (
usage_chunk.completion_tokens_details
)
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
**usage_chunk.prompt_tokens_details
)
elif isinstance(
usage_chunk.prompt_tokens_details, PromptTokensDetails
):
prompt_tokens_details = usage_chunk.prompt_tokens_details
]
if usage_chunk_dict["cache_read_input_tokens"] is not None:
cache_read_input_tokens = usage_chunk_dict[
"cache_read_input_tokens"
]
if usage_chunk_dict["completion_tokens_details"] is not None:
completion_tokens_details = usage_chunk_dict[
"completion_tokens_details"
]
prompt_tokens_details = usage_chunk_dict["prompt_tokens_details"]
try:
returned_usage.prompt_tokens = prompt_tokens or token_counter(
model=model, messages=messages
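A hedged sketch of the normalized dict shape _usage_chunk_calculation_helper produces for each streamed usage chunk (values illustrative):
usage_chunk_dict = {
    "prompt_tokens": 12,
    "completion_tokens": 30,
    "cache_creation_input_tokens": None,  # Anthropic prompt-caching counters, when present
    "cache_read_input_tokens": 8,
    "completion_tokens_details": None,
    "prompt_tokens_details": None,
}
# calculate_usage() then only overwrites its running totals for keys that are non-None
# (and, for the two token counts, greater than zero), as shown above.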

File diff suppressed because it is too large

View file

@ -1,14 +0,0 @@
from litellm.types.utils import GenericStreamingChunk as GChunk
def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
"""
Checks if the provided chunk dictionary contains all required fields for GenericStreamingChunk.
:param chunk: The dictionary to check.
:return: True if all required fields are present, False otherwise.
"""
_all_fields = GChunk.__annotations__
decision = all(key in _all_fields for key in chunk)
return decision

View file

@ -3,7 +3,7 @@ Support for gpt model family
"""
import types
from typing import Optional, Union
from typing import List, Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
@ -94,6 +94,7 @@ class OpenAIGPTConfig:
"max_tokens",
"max_completion_tokens",
"modalities",
"prediction",
"n",
"presence_penalty",
"seed",
@ -162,3 +163,8 @@ class OpenAIGPTConfig:
model=model,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages

View file

@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
return True
return False
def o1_prompt_factory(self, messages: List[AllMessageValues]):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
Handles limitations of O-1 model family.
- modalities: image => drop param (if user opts in to dropping param)

View file

@ -15,6 +15,7 @@ from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
@ -24,6 +25,7 @@ from litellm.utils import (
CustomStreamWrapper,
Message,
ModelResponse,
ProviderConfigManager,
TextCompletionResponse,
Usage,
convert_to_model_response_object,
@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if (
litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
and messages is not None
):
messages = litellm.openAIO1Config.o1_prompt_factory(
messages=messages,
if messages is not None and custom_llm_provider is not None:
provider_config = ProviderConfigManager.get_provider_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
messages = provider_config._transform_messages(messages)
for _ in range(
2

View file

@ -71,11 +71,12 @@ def validate_environment(
prompt_caching_set = AnthropicConfig().is_cache_control_set(messages=messages)
computer_tool_used = AnthropicConfig().is_computer_tool_used(tools=tools)
pdf_used = AnthropicConfig().is_pdf_used(messages=messages)
headers = AnthropicConfig().get_anthropic_headers(
anthropic_version=anthropic_version,
computer_tool_used=computer_tool_used,
prompt_caching_set=prompt_caching_set,
pdf_used=pdf_used,
api_key=api_key,
)
@ -769,6 +770,7 @@ class ModelResponseIterator:
message=message,
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,

View file

@ -104,6 +104,7 @@ class AnthropicConfig:
anthropic_version: Optional[str] = None,
computer_tool_used: bool = False,
prompt_caching_set: bool = False,
pdf_used: bool = False,
) -> dict:
import json
@ -112,6 +113,8 @@ class AnthropicConfig:
betas.append("prompt-caching-2024-07-31")
if computer_tool_used:
betas.append("computer-use-2024-10-22")
if pdf_used:
betas.append("pdfs-2024-09-25")
headers = {
"anthropic-version": anthropic_version or "2023-06-01",
"x-api-key": api_key,
@ -365,6 +368,21 @@ class AnthropicConfig:
return True
return False
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
"""
Returns True if media is passed in the messages.
"""
for message in messages:
if (
"content" in message
and message["content"] is not None
and isinstance(message["content"], list)
):
for content in message["content"]:
if "type" in content:
return True
return False
def translate_system_message(
self, messages: List[AllMessageValues]
) -> List[AnthropicSystemMessageContent]:
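A hedged sketch of the resulting beta headers, assuming litellm exports AnthropicConfig at the top level and that non-empty betas are joined into an "anthropic-beta" header:
import litellm
headers = litellm.AnthropicConfig().get_anthropic_headers(
    api_key="sk-ant-placeholder",  # placeholder key
    prompt_caching_set=True,
    pdf_used=True,
)
# "anthropic-beta" is expected to list both "prompt-caching-2024-07-31" and "pdfs-2024-09-25";
# "anthropic-version" defaults to "2023-06-01".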

View file

@ -1,16 +1,28 @@
import hashlib
import json
import os
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.main import get_secret, get_secret_str
from .base import BaseLLM
if TYPE_CHECKING:
from botocore.credentials import Credentials
else:
Credentials = Any
class Boto3CredentialsInfo(BaseModel):
credentials: Credentials
aws_region_name: str
aws_bedrock_runtime_endpoint: Optional[str]
class AwsAuthError(Exception):
def __init__(self, status_code, message):
@ -311,3 +323,74 @@ class BaseAWSLLM(BaseLLM):
proxy_endpoint_url = endpoint_url
return endpoint_url, proxy_endpoint_url
def _get_boto_credentials_from_optional_params(
self, optional_params: dict
) -> Boto3CredentialsInfo:
"""
Get boto3 credentials from optional params
Args:
optional_params (dict): Optional parameters for the model call
Returns:
Credentials: Boto3 credentials object
"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret_str("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return Boto3CredentialsInfo(
credentials=credentials,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
)
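A hedged sketch of how the helper consumes optional_params (the profile name is hypothetical; the call itself is left commented since it needs boto3 and real credentials):
optional_params = {
    "aws_region_name": "eu-central-1",
    "aws_profile_name": "bedrock-dev",  # hypothetical named profile
    "temperature": 0.2,                 # non-AWS keys are left untouched
}
# creds_info = BaseAWSLLM()._get_boto_credentials_from_optional_params(optional_params)
# creds_info.aws_region_name == "eu-central-1"
# optional_params is now {"temperature": 0.2}
# With no region in params, AWS_REGION_NAME then AWS_REGION are checked, else "us-west-2".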

View file

@ -19,6 +19,7 @@ from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-5-haiku-20241022-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",

View file

@ -484,73 +484,6 @@ class AmazonMistralConfig:
}
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def add_custom_header(headers):
"""Closure to capture the headers and add them."""

View file

@ -0,0 +1,104 @@
import types
from typing import List, Optional
from openai.types.image import Image
from litellm.types.utils import ImageResponse
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
return ["size"]
@classmethod
def map_openai_params(
cls,
non_default_params: dict,
optional_params: dict,
):
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
return optional_params
@classmethod
def transform_response_dict_to_openai_response(
cls, model_response: ImageResponse, response_dict: dict
) -> ImageResponse:
image_list: List[Image] = []
for artifact in response_dict["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response
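A brief sketch of the size mapping above (litellm.AmazonStabilityConfig is how the rest of this diff references the class):
import litellm
optional_params = litellm.AmazonStabilityConfig.map_openai_params(
    non_default_params={"size": "1024x1024"},
    optional_params={},
)
print(optional_params)  # {'width': 1024, 'height': 1024}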

View file

@ -0,0 +1,94 @@
import types
from typing import List, Optional
from openai.types.image import Image
from litellm.types.llms.bedrock import (
AmazonStability3TextToImageRequest,
AmazonStability3TextToImageResponse,
)
from litellm.types.utils import ImageResponse
class AmazonStability3Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
"""
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
"""
No additional OpenAI params are mapped for stability 3
"""
return []
@classmethod
def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
"""
Returns True if the model is a Stability 3 model
Stability 3 models follow this pattern:
sd3-large
sd3-large-turbo
sd3-medium
sd3.5-large
sd3.5-large-turbo
"""
if model and ("sd3" in model or "sd3.5" in model):
return True
return False
@classmethod
def transform_request_body(
cls, prompt: str, optional_params: dict
) -> AmazonStability3TextToImageRequest:
"""
Transform the request body for the Stability 3 models
"""
data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)
return data
@classmethod
def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
"""
Map the OpenAI params to the Bedrock params
No OpenAI params are mapped for Stability 3, so directly return the optional_params
"""
return optional_params
@classmethod
def transform_response_dict_to_openai_response(
cls, model_response: ImageResponse, response_dict: dict
) -> ImageResponse:
"""
Transform the response dict to the OpenAI response
"""
stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
openai_images: List[Image] = []
for _img in stability_3_response.get("images", []):
openai_images.append(Image(b64_json=_img))
model_response.data = openai_images
return model_response
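A short sketch of the model-name detection that decides whether this config is used (model ids are illustrative):
import litellm
for m in ["sd3-large", "sd3.5-large-turbo", "stability.stable-diffusion-xl-v1"]:
    print(m, litellm.AmazonStability3Config._is_stability_3_model(model=m))
# expected: True, True, False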

View file

@ -0,0 +1,41 @@
from typing import Optional
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
model: str,
image_response: ImageResponse,
size: Optional[str] = None,
optional_params: Optional[dict] = None,
) -> float:
"""
Bedrock image generation cost calculator
Handles both Stability 1 and Stability 3 models
"""
if litellm.AmazonStability3Config()._is_stability_3_model(model=model):
pass
else:
# Stability 1 models
optional_params = optional_params or {}
# see model_prices_and_context_window.json for details on how steps is used
# Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
_steps = optional_params.get("steps", 50)
steps = "max-steps" if _steps > 50 else "50-steps"
# size is stored in model_prices_and_context_window.json as 1024-x-1024
# the incoming `size` param is formatted as 1024x1024
size = size or "1024-x-1024"
model = f"{size}/{steps}/{model}"
_model_info = litellm.get_model_info(
model=model,
custom_llm_provider="bedrock",
)
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
num_images: int = len(image_response.data)
return output_cost_per_image * num_images
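A self-contained sketch of the Stability 1 pricing lookup key assembled above (model id illustrative):
size, _steps, model = "512-x-512", 30, "stability.stable-diffusion-xl-v1"
steps = "max-steps" if _steps > 50 else "50-steps"
print(f"{size}/{steps}/{model}")  # 512-x-512/50-steps/stability.stable-diffusion-xl-v1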

View file

@ -0,0 +1,304 @@
import copy
import json
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from openai.types.image import Image
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ImageResponse
from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockImagePreparedRequest(BaseModel):
"""
Internal/Helper class for preparing the request for bedrock image generation
"""
endpoint_url: str
prepped: AWSPreparedRequest
body: bytes
data: dict
class BedrockImageGeneration(BaseAWSLLM):
"""
Bedrock Image Generation handler
"""
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: LitellmLogging,
timeout: Optional[Union[float, httpx.Timeout]],
aimg_generation: bool = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
):
prepared_request = self._prepare_request(
model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
logging_obj=logging_obj,
prompt=prompt,
)
if aimg_generation is True:
return self.async_image_generation(
prepared_request=prepared_request,
timeout=timeout,
model=model,
logging_obj=logging_obj,
prompt=prompt,
model_response=model_response,
)
client = _get_httpx_client()
try:
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model_response=model_response,
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
)
return model_response
async def async_image_generation(
self,
prepared_request: BedrockImagePreparedRequest,
timeout: Optional[Union[float, httpx.Timeout]],
model: str,
logging_obj: LitellmLogging,
prompt: str,
model_response: ImageResponse,
) -> ImageResponse:
"""
Asynchronous handler for bedrock image generation
Awaits the response from the bedrock image generation endpoint
"""
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK,
params={"timeout": timeout},
)
try:
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
### FORMAT RESPONSE TO OPENAI FORMAT ###
model_response = self._transform_response_dict_to_openai_response(
model=model,
logging_obj=logging_obj,
prompt=prompt,
response=response,
data=prepared_request.data,
model_response=model_response,
)
return model_response
def _prepare_request(
self,
model: str,
optional_params: dict,
api_base: Optional[str],
extra_headers: Optional[dict],
logging_obj: LitellmLogging,
prompt: str,
) -> BedrockImagePreparedRequest:
"""
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
Args:
model (str): The model to use for the image generation
optional_params (dict): The optional parameters for the image generation
api_base (Optional[str]): The base URL for the Bedrock API
extra_headers (Optional[dict]): The extra headers to include in the request
logging_obj (LitellmLogging): The logging object to use for logging
prompt (str): The prompt to use for the image generation
Returns:
BedrockImagePreparedRequest: The prepared request object
The BedrockImagePreparedRequest contains:
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
prepped (httpx.Request): The prepared request object
body (bytes): The request body
"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
### SET RUNTIME ENDPOINT ###
modelId = model
_, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
sigv4 = SigV4Auth(
boto3_credentials_info.credentials,
"bedrock",
boto3_credentials_info.aws_region_name,
)
data = self._get_request_body(
model=model, prompt=prompt, optional_params=optional_params
)
# Make POST Request
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=proxy_endpoint_url, data=body, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": proxy_endpoint_url,
"headers": prepped.headers,
},
)
return BedrockImagePreparedRequest(
endpoint_url=proxy_endpoint_url,
prepped=prepped,
body=body,
data=data,
)
def _get_request_body(
self,
model: str,
prompt: str,
optional_params: dict,
) -> dict:
"""
Get the request body for the Bedrock Image Generation API
Checks the model/provider and transforms the request body accordingly
Returns:
dict: The request body to use for the Bedrock Image Generation API
"""
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
if litellm.AmazonStability3Config._is_stability_3_model(model):
request_body = litellm.AmazonStability3Config.transform_request_body(
prompt=prompt, optional_params=optional_params
)
return dict(request_body)
else:
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {
"text_prompts": [{"text": prompt, "weight": 1}],
**inference_params,
}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
return data
def _transform_response_dict_to_openai_response(
self,
model_response: ImageResponse,
model: str,
logging_obj: LitellmLogging,
prompt: str,
response: httpx.Response,
data: dict,
) -> ImageResponse:
"""
Transforms the Image Generation response from Bedrock to OpenAI format
"""
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
verbose_logger.debug("raw model_response: %s", response.text)
response_dict = response.json()
if response_dict is None:
raise ValueError("Error in response object format, got None")
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
else litellm.AmazonStabilityConfig
)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,
response_dict=response_dict,
)
return model_response

View file

@ -1,127 +0,0 @@
"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import os
from typing import Any, List
from openai.types.image import Image
import litellm
from litellm.types.utils import ImageResponse
from .common_utils import BedrockError, init_bedrock_client
def image_generation(
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
aimg_generation=False,
):
"""
Bedrock Image Gen endpoint support
"""
### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,
)
### FORMAT IMAGE GENERATION INPUT ###
modelId = model
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body}, # type: ignore
modelId={modelId},
accept="application/json",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="", # boto3 is used for init.
additional_args={
"complete_input_dict": {"model": modelId, "texts": prompt},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
body=body,
modelId=modelId,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
except Exception as e:
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
### FORMAT RESPONSE TO OPENAI FORMAT ###
if response_body is None:
raise Exception("Error in response object format")
if model_response is None:
model_response = ImageResponse()
image_list: List[Image] = []
for artifact in response_body["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response

View file

@ -34,12 +34,14 @@ class AsyncHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
concurrent_limit=1000,
client_alias: Optional[str] = None, # name for client in logs
):
self.timeout = timeout
self.event_hooks = event_hooks
self.client = self.create_client(
timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
)
self.client_alias = client_alias
def create_client(
self,
@ -112,6 +114,7 @@ class AsyncHTTPHandler:
try:
if timeout is None:
timeout = self.timeout
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
)

View file

@ -1,7 +1,8 @@
import json
from typing import Optional
from typing import List, Optional
import litellm
from litellm import verbose_logger
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
@ -9,7 +10,7 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
class ModelResponseIterator:
@ -109,7 +110,17 @@ class ModelResponseIterator:
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
verbose_logger.debug(
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
# Async iterator
def __aiter__(self):
@ -123,6 +134,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
except Exception as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
@ -144,4 +157,14 @@ class ModelResponseIterator:
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
verbose_logger.debug(
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)

View file

@ -0,0 +1,41 @@
"""
Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
"""
import types
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
)
class DeepSeekChatConfig(OpenAIGPTConfig):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
DeepSeek does not support content in list format.
"""
messages = handle_messages_with_content_list_to_str_conversion(messages)
return super()._transform_messages(messages)
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = (
api_base
or get_secret_str("DEEPSEEK_API_BASE")
or "https://api.deepseek.com/beta"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
return api_base, dynamic_api_key
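A standalone sketch of the same fallback order using plain environment variables (litellm itself goes through get_secret_str, which can also read from a secret manager):
import os
api_base = os.environ.get("DEEPSEEK_API_BASE") or "https://api.deepseek.com/beta"
api_key = os.environ.get("DEEPSEEK_API_KEY")
print(api_base)  # defaults to the beta base URL when nothing is configured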

View file

@ -15,6 +15,7 @@ import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from .base import BaseLLM
@ -1183,3 +1184,73 @@ class Huggingface(BaseLLM):
input=input,
encoding=encoding,
)
def _transform_logprobs(
self, hf_response: Optional[List]
) -> Optional[TextCompletionLogprobs]:
"""
Transform Hugging Face logprobs to OpenAI.Completion() format
"""
if hf_response is None:
return None
# Initialize an empty list for the transformed logprobs
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
text_offset=[],
token_logprobs=[],
tokens=[],
top_logprobs=[],
)
# For each Hugging Face response, transform the logprobs
for response in hf_response:
# Extract the relevant information from the response
response_details = response["details"]
top_tokens = response_details.get("top_tokens", {})
for i, token in enumerate(response_details["prefill"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
# stub this to work with llm eval harness
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
_logprob.top_logprobs.append(top_alt_tokens)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details["tokens"]):
# Extract the text of the token
token_text = token["text"]
# Extract the logprob of the token
token_logprob = token["logprob"]
top_alt_tokens = {}
temp_top_logprobs = []
if top_tokens != {}:
temp_top_logprobs = top_tokens[i]
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
for elem in temp_top_logprobs:
text = elem["text"]
logprob = elem["logprob"]
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
_logprob.top_logprobs.append(top_alt_tokens)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
_logprob.text_offset.append(
sum(len(t["text"]) for t in response_details["tokens"][:i])
)
return _logprob
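A hedged sketch of the TGI-style payload this transformer expects (values illustrative; the call path mirrors litellm.huggingface._transform_logprobs used earlier in this diff):
hf_response = [{
    "details": {
        "prefill": [{"text": "Hello", "logprob": -0.1}],
        "tokens": [{"text": " world", "logprob": -0.4}],
        "top_tokens": {},
    }
}]
# _transform_logprobs(hf_response=hf_response) should yield TextCompletionLogprobs with
# tokens ["Hello", " world"], token_logprobs [-0.1, -0.4], and text_offset [0].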

View file

@ -10,6 +10,7 @@ import types
from typing import List, Literal, Optional, Tuple, Union
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
class MistralConfig:
@ -148,3 +149,59 @@ class MistralConfig:
or get_secret_str("MISTRAL_API_KEY")
)
return api_base, dynamic_api_key
@classmethod
def _transform_messages(cls, messages: List[AllMessageValues]):
"""
- handles scenario where content is list and not string
- content list is just text, and no images
- if image passed in, then just return as is (user-intended)
- if `name` is passed, then drop it for mistral API: https://github.com/BerriAI/litellm/issues/6696
Motivation: mistral api doesn't support content as a list
"""
new_messages = []
for m in messages:
special_keys = ["role", "content", "tool_calls", "function_call"]
extra_args = {}
if isinstance(m, dict):
for k, v in m.items():
if k not in special_keys:
extra_args[k] = v
texts = ""
_content = m.get("content")
if _content is not None and isinstance(_content, list):
for c in _content:
_text: Optional[str] = c.get("text")
if c["type"] == "image_url":
return messages
elif c["type"] == "text" and isinstance(_text, str):
texts += _text
elif _content is not None and isinstance(_content, str):
texts = _content
new_m = {"role": m["role"], "content": texts, **extra_args}
if m.get("tool_calls"):
new_m["tool_calls"] = m.get("tool_calls")
new_m = cls._handle_name_in_message(new_m)
new_messages.append(new_m)
return new_messages
@classmethod
def _handle_name_in_message(cls, message: dict) -> dict:
"""
Mistral API only supports `name` in tool messages
If role == tool, then we keep `name`
Otherwise, we drop `name`
"""
if message.get("name") is not None:
if message["role"] == "tool":
message["name"] = message.get("name")
else:
message.pop("name", None)
return message
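A short sketch of the message flattening above (the extra "name" field is hypothetical and gets dropped for non-tool roles):
import litellm
messages = [{
    "role": "user",
    "name": "alice",  # dropped, since role != "tool"
    "content": [{"type": "text", "text": "Hello "}, {"type": "text", "text": "Mistral"}],
}]
print(litellm.MistralConfig._transform_messages(messages))
# expected: [{'role': 'user', 'content': 'Hello Mistral'}]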

View file

@ -0,0 +1,372 @@
"""
OpenAI-like chat completion handler
For handling OpenAI-like chat completions, like IBM WatsonX, etc.
"""
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
from litellm.utils import CustomStreamWrapper, EmbeddingResponse
from ..common_utils import OpenAILikeBase, OpenAILikeError
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
if client is None:
client = litellm.module_level_aclient
response = await client.post(api_base, headers=headers, data=data, stream=True)
if streaming_decoder is not None:
completion_stream: Any = streaming_decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else:
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
if client is None:
client = litellm.module_level_client # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise OpenAILikeError(status_code=response.status_code, message=response.read())
if streaming_decoder is not None:
completion_stream = streaming_decoder.iter_bytes(
response.iter_bytes(chunk_size=1024)
)
else:
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
class OpenAILikeChatHandler(OpenAILikeBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def acompletion_stream_function(
self,
model: str,
messages: list,
custom_llm_provider: str,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
) -> CustomStreamWrapper:
data["stream"] = True
completion_stream = await make_call(
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
streaming_decoder=streaming_decoder,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
custom_llm_provider: str,
print_verbose: Callable,
client: Optional[AsyncHTTPHandler],
encoding,
api_key,
logging_obj,
stream,
data: dict,
base_model: Optional[str],
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> ModelResponse:
if timeout is None:
timeout = httpx.Timeout(timeout=600.0, connect=5.0)
if client is None:
client = litellm.module_level_aclient
try:
response = await client.post(
api_base, headers=headers, data=json.dumps(data), timeout=timeout
)
response.raise_for_status()
response_json = response.json()
except httpx.HTTPStatusError as e:
raise OpenAILikeError(
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException:
raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise OpenAILikeError(status_code=500, message=str(e))
logging_obj.post_call(
input=messages,
api_key="",
original_response=response_json,
additional_args={"complete_input_dict": data},
)
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
return response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: Optional[str],
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[
CustomStreamingDecoder
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
):
custom_endpoint = custom_endpoint or optional_params.pop(
"custom_endpoint", None
)
base_model: Optional[str] = optional_params.pop("base_model", None)
api_base, headers = self._validate_environment(
api_base=api_base,
api_key=api_key,
endpoint_type="chat_completions",
custom_endpoint=custom_endpoint,
headers=headers,
)
stream: bool = optional_params.get("stream", None) or False
optional_params["stream"] = stream
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if acompletion is True:
if client is None or not isinstance(client, AsyncHTTPHandler):
client = None
if (
stream is True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
client=client,
custom_llm_provider=custom_llm_provider,
streaming_decoder=streaming_decoder,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
custom_llm_provider=custom_llm_provider,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
base_model=base_model,
client=client,
)
else:
## COMPLETION CALL
if stream is True:
completion_stream = make_sync_call(
client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
streaming_decoder=streaming_decoder,
)
# completion_stream.__iter__()
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
else:
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(timeout=timeout) # type: ignore
try:
response = client.post(
api_base, headers=headers, data=json.dumps(data)
)
response.raise_for_status()
response_json = response.json()
except httpx.HTTPStatusError as e:
raise OpenAILikeError(
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException:
raise OpenAILikeError(
status_code=408, message="Timeout error occurred."
)
except Exception as e:
raise OpenAILikeError(status_code=500, message=str(e))
logging_obj.post_call(
input=messages,
api_key="",
original_response=response_json,
additional_args={"complete_input_dict": data},
)
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
return response

View file

@ -1,3 +1,5 @@
from typing import Literal, Optional, Tuple
import httpx
@ -10,3 +12,43 @@ class OpenAILikeError(Exception):
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class OpenAILikeBase:
def __init__(self, **kwargs):
pass
def _validate_environment(
self,
api_key: Optional[str],
api_base: Optional[str],
endpoint_type: Literal["chat_completions", "embeddings"],
headers: Optional[dict],
custom_endpoint: Optional[bool],
) -> Tuple[str, dict]:
if api_key is None and headers is None:
raise OpenAILikeError(
status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
if api_base is None:
raise OpenAILikeError(
status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
if headers is None:
headers = {
"Content-Type": "application/json",
}
if api_key is not None:
headers.update({"Authorization": "Bearer {}".format(api_key)})
if not custom_endpoint:
if endpoint_type == "chat_completions":
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings":
api_base = "{}/embeddings".format(api_base)
return api_base, headers
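A hedged sketch of the header/base handling above (OpenAILikeBase is the class defined directly above; its import path is assumed and omitted here):
base = OpenAILikeBase()
api_base, headers = base._validate_environment(
    api_key="sk-example",               # placeholder key
    api_base="https://example.com/v1",
    endpoint_type="chat_completions",
    headers=None,
    custom_endpoint=False,
)
# api_base -> "https://example.com/v1/chat/completions"
# headers  -> {"Content-Type": "application/json", "Authorization": "Bearer sk-example"}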

View file

@ -23,46 +23,13 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.utils import EmbeddingResponse
from ..common_utils import OpenAILikeError
from ..common_utils import OpenAILikeBase, OpenAILikeError
class OpenAILikeEmbeddingHandler:
class OpenAILikeEmbeddingHandler(OpenAILikeBase):
def __init__(self, **kwargs):
pass
def _validate_environment(
self,
api_key: Optional[str],
api_base: Optional[str],
endpoint_type: Literal["chat_completions", "embeddings"],
headers: Optional[dict],
) -> Tuple[str, dict]:
if api_key is None and headers is None:
raise OpenAILikeError(
status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
if api_base is None:
raise OpenAILikeError(
status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
if headers is None:
headers = {
"Content-Type": "application/json",
}
if api_key is not None:
headers.update({"Authorization": "Bearer {}".format(api_key)})
if endpoint_type == "chat_completions":
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings":
api_base = "{}/embeddings".format(api_base)
return api_base, headers
async def aembedding(
self,
input: list,
@ -133,6 +100,7 @@ class OpenAILikeEmbeddingHandler:
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
client=None,
aembedding=None,
custom_endpoint: Optional[bool] = None,
headers: Optional[dict] = None,
) -> EmbeddingResponse:
api_base, headers = self._validate_environment(
@ -140,6 +108,7 @@ class OpenAILikeEmbeddingHandler:
api_key=api_key,
endpoint_type="embeddings",
headers=headers,
custom_endpoint=custom_endpoint,
)
model = model
data = {"model": model, "input": input, **optional_params}

View file

@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
)
def handle_messages_with_content_list_to_str_conversion(
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
"""
Converts list-style message content into a plain string for each message.
"""
for message in messages:
texts = convert_content_list_to_str(message=message)
if texts:
message["content"] = texts
return messages
def convert_content_list_to_str(message: AllMessageValues) -> str:
"""
- handles scenario where content is list and not string

View file

@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
return prompt
def mistral_api_pt(messages):
"""
- handles scenario where content is list and not string
- content list is just text, and no images
- if image passed in, then just return as is (user-intended)
Motivation: mistral api doesn't support content as a list
"""
new_messages = []
for m in messages:
special_keys = ["role", "content", "tool_calls", "function_call"]
extra_args = {}
if isinstance(m, dict):
for k, v in m.items():
if k not in special_keys:
extra_args[k] = v
texts = ""
if m.get("content", None) is not None and isinstance(m["content"], list):
for c in m["content"]:
if c["type"] == "image_url":
return messages
elif c["type"] == "text" and isinstance(c["text"], str):
texts += c["text"]
elif m.get("content", None) is not None and isinstance(m["content"], str):
texts = m["content"]
new_m = {"role": m["role"], "content": texts, **extra_args}
if new_m["role"] == "tool" and m.get("name"):
new_m["name"] = m["name"]
if m.get("tool_calls"):
new_m["tool_calls"] = m["tool_calls"]
new_messages.append(new_m)
return new_messages
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
prompt = ""
@ -1330,7 +1293,10 @@ def convert_to_anthropic_tool_invoke(
def add_cache_control_to_content(
anthropic_content_element: Union[
dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam
dict,
AnthropicMessagesImageParam,
AnthropicMessagesTextParam,
AnthropicMessagesDocumentParam,
],
orignal_content_element: Union[dict, AllMessageValues],
):
@ -1343,6 +1309,32 @@ def add_cache_control_to_content(
return anthropic_content_element
def _anthropic_content_element_factory(
image_chunk: GenericImageParsingChunk,
) -> Union[AnthropicMessagesImageParam, AnthropicMessagesDocumentParam]:
if image_chunk["media_type"] == "application/pdf":
_anthropic_content_element: Union[
AnthropicMessagesDocumentParam, AnthropicMessagesImageParam
] = AnthropicMessagesDocumentParam(
type="document",
source=AnthropicContentParamSource(
type="base64",
media_type=image_chunk["media_type"],
data=image_chunk["data"],
),
)
else:
_anthropic_content_element = AnthropicMessagesImageParam(
type="image",
source=AnthropicContentParamSource(
type="base64",
media_type=image_chunk["media_type"],
data=image_chunk["data"],
),
)
return _anthropic_content_element
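A condensed sketch of the branch above: a base64 chunk whose media type is application/pdf becomes a document block, anything else an image block. Plain dicts are used instead of the typed Anthropic params, so this is illustrative only.

def content_element_factory_sketch(image_chunk: dict) -> dict:
    # Mirrors _anthropic_content_element_factory with plain dicts.
    source = {
        "type": "base64",
        "media_type": image_chunk["media_type"],
        "data": image_chunk["data"],
    }
    if image_chunk["media_type"] == "application/pdf":
        return {"type": "document", "source": source}
    return {"type": "image", "source": source}

print(content_element_factory_sketch({"media_type": "application/pdf", "data": "JVBERi0..."})["type"])  # document
print(content_element_factory_sketch({"media_type": "image/png", "data": "iVBORw0..."})["type"])        # image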
def anthropic_messages_pt( # noqa: PLR0915
messages: List[AllMessageValues],
model: str,
@ -1400,15 +1392,9 @@ def anthropic_messages_pt( # noqa: PLR0915
openai_image_url=m["image_url"]["url"]
)
_anthropic_content_element = AnthropicMessagesImageParam(
type="image",
source=AnthropicImageParamSource(
type="base64",
media_type=image_chunk["media_type"],
data=image_chunk["data"],
),
_anthropic_content_element = (
_anthropic_content_element_factory(image_chunk)
)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_element,
orignal_content_element=dict(m),
@ -2830,7 +2816,7 @@ def prompt_factory(
else:
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
return mistral_api_pt(messages=messages)
return litellm.MistralConfig._transform_messages(messages=messages)
elif custom_llm_provider == "bedrock":
if "amazon.titan-text" in model:
return amazon_titan_pt(messages=messages)

View file

@ -51,6 +51,9 @@ from ..common_utils import (
def _process_gemini_image(image_url: str) -> PartType:
"""
Given an image URL, return the appropriate PartType for Gemini
"""
try:
# GCS URIs
if "gs://" in image_url:
@ -68,9 +71,14 @@ def _process_gemini_image(image_url: str) -> PartType:
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
return PartType(file_data=file_data)
# Direct links
elif "https:/" in image_url or "base64" in image_url:
elif (
"https://" in image_url
and (image_type := _get_image_mime_type_from_url(image_url)) is not None
):
file_data = FileDataType(file_uri=image_url, mime_type=image_type)
return PartType(file_data=file_data)
elif "https://" in image_url or "base64" in image_url:
# https links for unsupported mime types and base64 images
image = convert_to_anthropic_image_obj(image_url)
_blob = BlobType(data=image["data"], mime_type=image["media_type"])
return PartType(inline_data=_blob)
@ -79,6 +87,29 @@ def _process_gemini_image(image_url: str) -> PartType:
raise e
def _get_image_mime_type_from_url(url: str) -> Optional[str]:
"""
Get mime type for common image URLs
See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
Supported by Gemini:
- PNG (`image/png`)
- JPEG (`image/jpeg`)
- WebP (`image/webp`)
Example:
url = https://example.com/image.jpg
Returns: image/jpeg
"""
url = url.lower()
if url.endswith((".jpg", ".jpeg")):
return "image/jpeg"
elif url.endswith(".png"):
return "image/png"
elif url.endswith(".webp"):
return "image/webp"
return None
def _gemini_convert_messages_with_history( # noqa: PLR0915
messages: List[AllMessageValues],
) -> List[ContentType]:

View file

@ -0,0 +1,123 @@
from typing import Callable, Optional, Union
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import WatsonXAIError, _get_api_params
class WatsonXChatHandler(OpenAILikeChatHandler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _prepare_url(
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
) -> str:
if model.startswith("deployment/"):
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
url=api_params["url"],
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
if stream is True
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
endpoint = (
WatsonXAIEndpoint.CHAT_STREAM.value
if stream is True
else WatsonXAIEndpoint.CHAT.value
)
base_url = httpx.URL(api_params["url"])
base_url = base_url.join(endpoint)
full_url = str(
base_url.copy_add_param(key="version", value=api_params["api_version"])
)
return full_url
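For reference, a rough sketch of the URL this method composes, using httpx directly; the region URL, endpoint path, and api version below are example values rather than fields read from a live api_params object.

import httpx

base_url = httpx.URL("https://us-south.ml.cloud.ibm.com")  # example region URL
endpoint = "/ml/v1/text/chat"                              # assumed non-deployment chat path
full_url = str(base_url.join(endpoint).copy_add_param(key="version", value="2024-03-13"))
print(full_url)
# https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2024-03-13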
def _prepare_payload(
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
) -> dict:
payload: dict = {}
if model.startswith("deployment/"):
return payload
payload["model_id"] = model
payload["project_id"] = api_params["project_id"]
return payload
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: Optional[str],
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[
CustomStreamingDecoder
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
):
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
if headers is None:
headers = {}
headers.update(
{
"Authorization": f"Bearer {api_params['token']}",
"Content-Type": "application/json",
"Accept": "application/json",
}
)
stream: Optional[bool] = optional_params.get("stream", False)
## get api url and payload
api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
watsonx_auth_payload = self._prepare_payload(
model=model, api_params=api_params, stream=stream
)
optional_params.update(watsonx_auth_payload)
return super().completion(
model=model,
messages=messages,
api_base=api_base,
custom_llm_provider=custom_llm_provider,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
custom_endpoint=True,
streaming_decoder=streaming_decoder,
)
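With this handler wired in, watsonx chat models go through litellm's standard completion interface. A hedged usage sketch follows; the model id, credentials, and region are placeholders.

import os
import litellm

# Placeholder credentials; substitute real values for your watsonx.ai instance.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "your-ibm-cloud-api-key"
os.environ["WATSONX_PROJECT_ID"] = "your-project-id"

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # assumed "watsonx/<model id>" format
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)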

View file

@ -0,0 +1,82 @@
"""
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
"""
import types
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
class IBMWatsonXChatConfig(OpenAIGPTConfig):
def get_supported_openai_params(self, model: str) -> List:
return [
"temperature", # equivalent to temperature
"max_tokens", # equivalent to max_new_tokens
"top_p", # equivalent to top_p
"frequency_penalty", # equivalent to repetition_penalty
"stop", # equivalent to stop_sequences
"seed", # equivalent to random_seed
"stream", # equivalent to stream
"tools",
"tool_choice", # equivalent to tool_choice + tool_choice_options
"logprobs",
"top_logprobs",
"n",
"presence_penalty",
"response_format",
]
def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
if tool_choice is None:
return False
if isinstance(tool_choice, str):
return tool_choice in ["auto", "none", "required"]
return False
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
## TOOLS ##
_tools = non_default_params.pop("tools", None)
if _tools is not None:
# remove 'additionalProperties' from tools
_tools = _remove_additional_properties(_tools)
# remove 'strict' from tools
_tools = _remove_strict_from_schema(_tools)
if _tools is not None:
non_default_params["tools"] = _tools
## TOOL CHOICE ##
_tool_choice = non_default_params.pop("tool_choice", None)
if self.is_tool_choice_option(_tool_choice):
optional_params["tool_choice_options"] = _tool_choice
elif _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
return super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE") # type: ignore
dynamic_api_key = (
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
) # vllm does not require an api key
return api_base, dynamic_api_key
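A small sketch of the tool_choice handling above: the string options "auto", "none", and "required" are forwarded as tool_choice_options, while an explicit tool dict stays under tool_choice. Plain dicts stand in for the real config object.

def map_tool_choice_sketch(non_default_params: dict, optional_params: dict) -> dict:
    # Mirrors the tool_choice branch of IBMWatsonXChatConfig.map_openai_params.
    tool_choice = non_default_params.pop("tool_choice", None)
    if isinstance(tool_choice, str) and tool_choice in ("auto", "none", "required"):
        optional_params["tool_choice_options"] = tool_choice
    elif tool_choice is not None:
        optional_params["tool_choice"] = tool_choice
    return optional_params

print(map_tool_choice_sketch({"tool_choice": "auto"}, {}))
# {'tool_choice_options': 'auto'}
print(map_tool_choice_sketch({"tool_choice": {"type": "function", "function": {"name": "get_weather"}}}, {}))
# {'tool_choice': {'type': 'function', 'function': {'name': 'get_weather'}}}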

View file

@ -0,0 +1,172 @@
from typing import Callable, Optional, cast
import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams
class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code
self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com"
self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
iam_token_cache = InMemoryCache()
def generate_iam_token(api_key=None, **params) -> str:
result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore
if result is None:
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
data = {
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
}
verbose_logger.debug(
"calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
"https://iam.cloud.ibm.com/identity/token",
headers,
data,
)
response = httpx.post(
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
)
response.raise_for_status()
json_data = response.json()
result = json_data["access_token"]
iam_token_cache.set_cache(
key=api_key,
value=result,
ttl=json_data["expires_in"] - 10, # leave some buffer
)
return cast(str, result)
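A minimal sketch of the caching behavior above, assuming a simple in-memory TTL cache; the IAM token exchange itself is stubbed out so the example runs offline.

import time

# Token cache keyed by api key, mirroring iam_token_cache usage above.
_token_cache: dict = {}  # api_key -> (token, expiry_timestamp)

def get_cached_iam_token(api_key: str) -> str:
    entry = _token_cache.get(api_key)
    if entry and entry[1] > time.time():
        return entry[0]
    # Stand-in for the httpx.post(...) call to https://iam.cloud.ibm.com/identity/token.
    token, expires_in = "fake-iam-token", 3600  # assumed response values
    _token_cache[api_key] = (token, time.time() + expires_in - 10)  # small buffer, as above
    return token

print(get_cached_iam_token("my-wx-api-key"))  # first call "fetches" a token
print(get_cached_iam_token("my-wx-api-key"))  # second call hits the cache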
def _get_api_params(
params: dict,
print_verbose: Optional[Callable] = None,
generate_token: Optional[bool] = True,
) -> WatsonXAPIParams:
"""
Find watsonx.ai credentials in the params or environment variables and return the parameters needed to authenticate the request.
"""
# Load auth variables from params
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None)
token = params.pop("token", None)
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None))
if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
# Load auth variables from environment variables
if url is None:
url = (
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
# credentials parsing
if wx_credentials is not None:
url = wx_credentials.get("url", url)
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
if token is None and api_key is not None and generate_token:
# generate the auth token
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
token = generate_iam_token(api_key)
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
return WatsonXAPIParams(
url=url,
api_key=api_key,
token=cast(str, token),
project_id=project_id,
space_id=space_id,
region_name=region_name,
api_version=api_version,
)
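The credential lookup above falls back through several environment variable names; a condensed sketch of that precedence for the URL, using plain os.environ instead of litellm's secret manager.

import os

def resolve_watsonx_url_sketch(params: dict):
    # Same precedence as _get_api_params: explicit param first, then env fallbacks.
    return (
        params.get("url")
        or params.get("api_base")
        or os.environ.get("WATSONX_API_BASE")
        or os.environ.get("WATSONX_URL")
        or os.environ.get("WX_URL")
        or os.environ.get("WML_URL")
    )

os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
print(resolve_watsonx_url_sketch({}))                                          # env fallback
print(resolve_watsonx_url_sketch({"url": "https://eu-de.ml.cloud.ibm.com"}))   # explicit param wins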

View file

@ -26,22 +26,12 @@ import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from .base import BaseLLM
from .prompt_templates import factory as ptf
class WatsonXAIError(Exception):
def __init__(self, status_code, message, url: Optional[str] = None):
self.status_code = status_code
self.message = message
url = url or "https://https://us-south.ml.cloud.ibm.com"
self.request = httpx.Request(method="POST", url=url)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
from ...base import BaseLLM
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
class IBMWatsonXAIConfig:
@ -140,6 +130,29 @@ class IBMWatsonXAIConfig:
and v is not None
}
def is_watsonx_text_param(self, param: str) -> bool:
"""
Determine if user passed in a watsonx.ai text generation param
"""
text_generation_params = [
"decoding_method",
"max_new_tokens",
"min_new_tokens",
"length_penalty",
"stop_sequences",
"top_k",
"repetition_penalty",
"truncate_input_tokens",
"include_stop_sequences",
"return_options",
"random_seed",
"moderations",
"decoding_method",
"min_tokens",
]
return param in text_generation_params
def get_supported_openai_params(self):
return [
"temperature", # equivalent to temperature
@ -151,6 +164,44 @@ class IBMWatsonXAIConfig:
"stream", # equivalent to stream
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
extra_body = {}
for k, v in non_default_params.items():
if k == "max_tokens":
optional_params["max_new_tokens"] = v
elif k == "stream":
optional_params["stream"] = v
elif k == "temperature":
optional_params["temperature"] = v
elif k == "top_p":
optional_params["top_p"] = v
elif k == "frequency_penalty":
optional_params["repetition_penalty"] = v
elif k == "seed":
optional_params["random_seed"] = v
elif k == "stop":
optional_params["stop_sequences"] = v
elif k == "decoding_method":
extra_body["decoding_method"] = v
elif k == "min_tokens":
extra_body["min_new_tokens"] = v
elif k == "top_k":
extra_body["top_k"] = v
elif k == "truncate_input_tokens":
extra_body["truncate_input_tokens"] = v
elif k == "length_penalty":
extra_body["length_penalty"] = v
elif k == "time_limit":
extra_body["time_limit"] = v
elif k == "return_options":
extra_body["return_options"] = v
if extra_body:
optional_params["extra_body"] = extra_body
return optional_params
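To make the mapping above concrete, a tiny sketch of how a handful of OpenAI-style params land in watsonx text-generation params, with provider-specific keys collected under extra_body.

def map_text_gen_params_sketch(non_default_params: dict) -> dict:
    # Condensed version of IBMWatsonXAIConfig.map_openai_params above.
    optional_params: dict = {}
    extra_body: dict = {}
    mapping = {"max_tokens": "max_new_tokens", "seed": "random_seed",
               "stop": "stop_sequences", "frequency_penalty": "repetition_penalty"}
    passthrough = {"temperature", "top_p", "stream"}
    for k, v in non_default_params.items():
        if k in mapping:
            optional_params[mapping[k]] = v
        elif k in passthrough:
            optional_params[k] = v
        else:  # provider-specific keys (e.g. top_k, length_penalty) go under extra_body
            extra_body[k] = v
    if extra_body:
        optional_params["extra_body"] = extra_body
    return optional_params

print(map_text_gen_params_sketch({"max_tokens": 200, "stop": ["\n"], "top_k": 50}))
# {'max_new_tokens': 200, 'stop_sequences': ['\n'], 'extra_body': {'top_k': 50}}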
def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
@ -212,18 +263,6 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) ->
return prompt
class WatsonXAIEndpoint(str, Enum):
TEXT_GENERATION = "/ml/v1/text/generation"
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
DEPLOYMENT_TEXT_GENERATION_STREAM = (
"/ml/v1/deployments/{deployment_id}/text/generation_stream"
)
EMBEDDINGS = "/ml/v1/text/embeddings"
PROMPTS = "/ml/v1/prompts"
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
class IBMWatsonXAI(BaseLLM):
"""
Class to interface with IBM watsonx.ai API for text generation and embeddings.
@ -247,10 +286,10 @@ class IBMWatsonXAI(BaseLLM):
"""
Get the request parameters for text generation.
"""
api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
# build auth headers
api_token = api_params.get("token")
self.token = api_token
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
@ -294,118 +333,6 @@ class IBMWatsonXAI(BaseLLM):
method="POST", url=url, headers=headers, json=payload, params=request_params
)
def _get_api_params(
self,
params: dict,
print_verbose: Optional[Callable] = None,
generate_token: Optional[bool] = True,
) -> dict:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
"""
# Load auth variables from params
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None)
token = params.pop("token", None)
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
region_name = params.pop("region_name", params.pop("region", None))
if region_name is None:
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", IBMWatsonXAI.api_version)
# Load auth variables from environment variables
if url is None:
url = (
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
# credentials parsing
if wx_credentials is not None:
url = wx_credentials.get("url", url)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
if token is None and api_key is not None and generate_token:
# generate the auth token
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
token = self.generate_iam_token(api_key)
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
url=url,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
)
return {
"url": url,
"api_key": api_key,
"token": token,
"project_id": project_id,
"space_id": space_id,
"region_name": region_name,
"api_version": api_version,
}
def _process_text_gen_response(
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
) -> ModelResponse:
@ -616,9 +543,10 @@ class IBMWatsonXAI(BaseLLM):
input = [input]
if api_key is not None:
optional_params["api_key"] = api_key
api_params = self._get_api_params(optional_params)
api_params = _get_api_params(optional_params)
# build auth headers
api_token = api_params.get("token")
self.token = api_token
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
@ -664,29 +592,9 @@ class IBMWatsonXAI(BaseLLM):
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def generate_iam_token(self, api_key=None, **params):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
data = {
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
}
response = httpx.post(
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
)
response.raise_for_status()
json_data = response.json()
iam_access_token = json_data["access_token"]
self.token = iam_access_token
return iam_access_token
def get_available_models(self, *, ids_only: bool = True, **params):
api_params = self._get_api_params(params)
api_params = _get_api_params(params)
self.token = api_params["token"]
headers = {
"Authorization": f"Bearer {api_params['token']}",
"Content-Type": "application/json",

View file

@ -77,6 +77,7 @@ from litellm.utils import (
read_config_args,
supports_httpx_timeout,
token_counter,
validate_chat_completion_user_messages,
)
from ._logging import verbose_logger
@ -107,9 +108,9 @@ from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere.embed import handler as cohere_embed
@ -157,11 +158,13 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
VertexEmbedding,
)
from .llms.watsonx import IBMWatsonXAI
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .types.llms.openai import (
ChatCompletionAssistantMessage,
ChatCompletionAudioParam,
ChatCompletionModality,
ChatCompletionPredictionContentParam,
ChatCompletionUserMessage,
HttpxBinaryResponseContent,
)
@ -211,6 +214,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
bedrock_image_generation = BedrockImageGeneration()
vertex_chat_completion = VertexLLM()
vertex_embedding = VertexEmbedding()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
@ -220,6 +224,7 @@ vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
####### COMPLETION ENDPOINTS ################
@ -304,6 +309,7 @@ async def acompletion(
max_tokens: Optional[int] = None,
max_completion_tokens: Optional[int] = None,
modalities: Optional[List[ChatCompletionModality]] = None,
prediction: Optional[ChatCompletionPredictionContentParam] = None,
audio: Optional[ChatCompletionAudioParam] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
@ -346,6 +352,7 @@ async def acompletion(
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
max_completion_tokens (integer, optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
modalities (List[ChatCompletionModality], optional): Output types that you would like the model to generate for this request. You can use `["text", "audio"]`
prediction (ChatCompletionPredictionContentParam, optional): Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content.
audio (ChatCompletionAudioParam, optional): Parameters for audio output. Required when audio output is requested with modalities: ["audio"]
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far.
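A hedged sketch of passing the new prediction parameter through the async entry point added here; the model name and file content are placeholders, and the {"type": "content", ...} payload shape follows OpenAI's Predicted Outputs format.

import asyncio
import litellm

async def main():
    # Predicted Outputs: supply the expected bulk of the response up front.
    existing_code = "def add(a, b):\n    return a + b\n"
    response = await litellm.acompletion(
        model="gpt-4o",  # assumed model with Predicted Outputs support
        messages=[{"role": "user", "content": "Rename the function to sum_two."}],
        prediction={"type": "content", "content": existing_code},
    )
    print(response.choices[0].message.content)

asyncio.run(main())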
@ -387,6 +394,7 @@ async def acompletion(
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"modalities": modalities,
"prediction": prediction,
"audio": audio,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
@ -693,6 +701,7 @@ def completion( # type: ignore # noqa: PLR0915
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
modalities: Optional[List[ChatCompletionModality]] = None,
prediction: Optional[ChatCompletionPredictionContentParam] = None,
audio: Optional[ChatCompletionAudioParam] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
@ -737,6 +746,7 @@ def completion( # type: ignore # noqa: PLR0915
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
max_completion_tokens (integer, optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
modalities (List[ChatCompletionModality], optional): Output types that you would like the model to generate for this request. You can use `["text", "audio"]`
prediction (ChatCompletionPredictionContentParam, optional): Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content.
audio (ChatCompletionAudioParam, optional): Parameters for audio output. Required when audio output is requested with modalities: ["audio"]
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far.
@ -843,6 +853,7 @@ def completion( # type: ignore # noqa: PLR0915
"stop",
"max_completion_tokens",
"modalities",
"prediction",
"audio",
"max_tokens",
"presence_penalty",
@ -914,6 +925,9 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name", None
) # support region-based pricing for bedrock
### VALIDATE USER MESSAGES ###
validate_chat_completion_user_messages(messages=messages)
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@ -994,6 +1008,7 @@ def completion( # type: ignore # noqa: PLR0915
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
modalities=modalities,
prediction=prediction,
audio=audio,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
@ -2607,6 +2622,26 @@ def completion( # type: ignore # noqa: PLR0915
## RESPONSE OBJECT
response = response
elif custom_llm_provider == "watsonx":
response = watsonx_chat_completion.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
custom_llm_provider="watsonx",
)
elif custom_llm_provider == "watsonx_text":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonxai.completion(
model=model,
@ -3867,34 +3902,17 @@ async def atext_completion(
custom_llm_provider=custom_llm_provider,
)
else:
transformed_logprobs = None
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e:
print_verbose(f"LiteLLM non blocking exception: {e}")
## TRANSLATE CHAT TO TEXT FORMAT ##
## OpenAI / Azure Text Completion Returns here
if isinstance(response, TextCompletionResponse):
return response
elif asyncio.iscoroutine(response):
response = await response
text_completion_response = TextCompletionResponse()
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
text_completion_response["model"] = response.get("model", None)
text_choices = TextChoices()
text_choices["text"] = response["choices"][0]["message"]["content"]
text_choices["index"] = response["choices"][0]["index"]
text_choices["logprobs"] = transformed_logprobs
text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
text_completion_response["choices"] = [text_choices]
text_completion_response["usage"] = response.get("usage", None)
text_completion_response._hidden_params = HiddenParams(
**response._hidden_params
text_completion_response = litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
text_completion_response=text_completion_response,
response=response,
custom_llm_provider=custom_llm_provider,
)
return text_completion_response
except Exception as e:
@ -4156,29 +4174,17 @@ def text_completion( # noqa: PLR0915
return response
elif isinstance(response, TextCompletionStreamWrapper):
return response
transformed_logprobs = None
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e:
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
# OpenAI Text / Azure Text will return here
if isinstance(response, TextCompletionResponse):
return response
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
text_completion_response["model"] = response.get("model", None)
text_choices = TextChoices()
text_choices["text"] = response["choices"][0]["message"]["content"]
text_choices["index"] = response["choices"][0]["index"]
text_choices["logprobs"] = transformed_logprobs
text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
text_completion_response["choices"] = [text_choices]
text_completion_response["usage"] = response.get("usage", None)
text_completion_response._hidden_params = HiddenParams(**response._hidden_params)
text_completion_response = (
litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
response=response,
text_completion_response=text_completion_response,
)
)
return text_completion_response
@ -4314,9 +4320,9 @@ async def amoderation(
else:
_openai_client = openai_client
if model is not None:
response = await openai_client.moderations.create(input=input, model=model)
response = await _openai_client.moderations.create(input=input, model=model)
else:
response = await openai_client.moderations.create(input=input)
response = await _openai_client.moderations.create(input=input)
return response
@ -4442,6 +4448,7 @@ def image_generation( # noqa: PLR0915
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(
model=model,
n=n,
quality=quality,
response_format=response_format,
@ -4534,7 +4541,7 @@ def image_generation( # noqa: PLR0915
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock_image_generation.image_generation(
model_response = bedrock_image_generation.image_generation( # type: ignore
model=model,
prompt=prompt,
timeout=timeout,

View file

@ -80,6 +80,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_prompt_caching": true
},
@ -94,6 +95,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_prompt_caching": true
},
@ -108,7 +110,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"o1-mini-2024-09-12": {
@ -122,7 +124,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"o1-preview": {
@ -136,7 +138,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"o1-preview-2024-09-12": {
@ -150,7 +152,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"chatgpt-4o-latest": {
@ -190,6 +192,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_prompt_caching": true
},
@ -461,6 +464,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true
},
"ft:gpt-4o-mini-2024-07-18": {
@ -473,6 +477,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true
},
"ft:davinci-002": {
@ -652,7 +657,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"azure/o1-mini-2024-09-12": {
@ -666,7 +671,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"azure/o1-preview": {
@ -680,7 +685,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"azure/o1-preview-2024-09-12": {
@ -694,7 +699,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true,
"supports_vision": false,
"supports_prompt_caching": true
},
"azure/gpt-4o": {
@ -721,6 +726,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true
},
"azure/gpt-4o-2024-05-13": {
@ -746,6 +752,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true
},
"azure/global-standard/gpt-4o-mini": {
@ -758,6 +765,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true
},
"azure/gpt-4o-mini": {
@ -771,6 +779,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_prompt_caching": true
},
@ -785,6 +794,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_response_schema": true,
"supports_vision": true,
"supports_prompt_caching": true
},
@ -1109,6 +1119,52 @@
"supports_function_calling": true,
"mode": "chat"
},
"azure_ai/mistral-large-2407": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "azure_ai",
"supports_function_calling": true,
"mode": "chat",
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview"
},
"azure_ai/ministral-3b": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000004,
"output_cost_per_token": 0.00000004,
"litellm_provider": "azure_ai",
"supports_function_calling": true,
"mode": "chat",
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.ministral-3b-2410-offer?tab=Overview"
},
"azure_ai/Llama-3.2-11B-Vision-Instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 2048,
"input_cost_per_token": 0.00000037,
"output_cost_per_token": 0.00000037,
"litellm_provider": "azure_ai",
"supports_function_calling": true,
"supports_vision": true,
"mode": "chat",
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/metagenai.meta-llama-3-2-11b-vision-instruct-offer?tab=Overview"
},
"azure_ai/Llama-3.2-90B-Vision-Instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 2048,
"input_cost_per_token": 0.00000204,
"output_cost_per_token": 0.00000204,
"litellm_provider": "azure_ai",
"supports_function_calling": true,
"supports_vision": true,
"mode": "chat",
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/metagenai.meta-llama-3-2-90b-vision-instruct-offer?tab=Overview"
},
"azure_ai/Meta-Llama-3-70B-Instruct": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -1148,6 +1204,105 @@
"mode": "chat",
"source":"https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-405b-instruct-offer?tab=PlansAndPrice"
},
"azure_ai/Phi-3.5-mini-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000052,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3.5-vision-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000052,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": true,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3.5-MoE-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000016,
"output_cost_per_token": 0.00000064,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-mini-4k-instruct": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000052,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-mini-128k-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.00000052,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-small-8k-instruct": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.0000006,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-small-128k-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.0000006,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-medium-4k-instruct": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000017,
"output_cost_per_token": 0.00000068,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/Phi-3-medium-128k-instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000017,
"output_cost_per_token": 0.00000068,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_vision": false,
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
},
"azure_ai/cohere-rerank-v3-multilingual": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@ -1730,6 +1885,22 @@
"supports_assistant_prefill": true,
"supports_prompt_caching": true
},
"claude-3-5-haiku-20241022": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"cache_creation_input_token_cost": 0.00000125,
"cache_read_input_token_cost": 0.0000001,
"litellm_provider": "anthropic",
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 264,
"supports_assistant_prefill": true,
"supports_prompt_caching": true,
"supports_pdf_input": true
},
"claude-3-opus-20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -2195,16 +2366,16 @@
"input_cost_per_image": 0.00032875,
"input_cost_per_audio_per_second": 0.00003125,
"input_cost_per_video_per_second": 0.00032875,
"input_cost_per_token": 0.000000078125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_token": 0.00000125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_image_above_128k_tokens": 0.0006575,
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
"input_cost_per_token_above_128k_tokens": 0.00000015625,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.0000003125,
"input_cost_per_token_above_128k_tokens": 0.0000025,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.000005,
"output_cost_per_character": 0.00000125,
"output_cost_per_token_above_128k_tokens": 0.000000625,
"output_cost_per_token_above_128k_tokens": 0.00001,
"output_cost_per_character_above_128k_tokens": 0.0000025,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
@ -2221,16 +2392,16 @@
"input_cost_per_image": 0.00032875,
"input_cost_per_audio_per_second": 0.00003125,
"input_cost_per_video_per_second": 0.00032875,
"input_cost_per_token": 0.000000078125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_token": 0.00000125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_image_above_128k_tokens": 0.0006575,
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
"input_cost_per_token_above_128k_tokens": 0.00000015625,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.0000003125,
"input_cost_per_token_above_128k_tokens": 0.0000025,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.000005,
"output_cost_per_character": 0.00000125,
"output_cost_per_token_above_128k_tokens": 0.000000625,
"output_cost_per_token_above_128k_tokens": 0.00001,
"output_cost_per_character_above_128k_tokens": 0.0000025,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
@ -2247,16 +2418,16 @@
"input_cost_per_image": 0.00032875,
"input_cost_per_audio_per_second": 0.00003125,
"input_cost_per_video_per_second": 0.00032875,
"input_cost_per_token": 0.000000078125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_token": 0.00000125,
"input_cost_per_character": 0.0000003125,
"input_cost_per_image_above_128k_tokens": 0.0006575,
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
"input_cost_per_token_above_128k_tokens": 0.00000015625,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.0000003125,
"input_cost_per_token_above_128k_tokens": 0.0000025,
"input_cost_per_character_above_128k_tokens": 0.000000625,
"output_cost_per_token": 0.000005,
"output_cost_per_character": 0.00000125,
"output_cost_per_token_above_128k_tokens": 0.000000625,
"output_cost_per_token_above_128k_tokens": 0.00001,
"output_cost_per_character_above_128k_tokens": 0.0000025,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
@ -2356,17 +2527,17 @@
"input_cost_per_image": 0.00002,
"input_cost_per_video_per_second": 0.00002,
"input_cost_per_audio_per_second": 0.000002,
"input_cost_per_token": 0.000000004688,
"input_cost_per_token": 0.000000075,
"input_cost_per_character": 0.00000001875,
"input_cost_per_token_above_128k_tokens": 0.000001,
"input_cost_per_character_above_128k_tokens": 0.00000025,
"input_cost_per_image_above_128k_tokens": 0.00004,
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
"output_cost_per_token": 0.0000000046875,
"output_cost_per_character": 0.00000001875,
"output_cost_per_token_above_128k_tokens": 0.000000009375,
"output_cost_per_character_above_128k_tokens": 0.0000000375,
"output_cost_per_token": 0.0000003,
"output_cost_per_character": 0.000000075,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"output_cost_per_character_above_128k_tokens": 0.00000015,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
@ -2420,17 +2591,17 @@
"input_cost_per_image": 0.00002,
"input_cost_per_video_per_second": 0.00002,
"input_cost_per_audio_per_second": 0.000002,
"input_cost_per_token": 0.000000004688,
"input_cost_per_token": 0.000000075,
"input_cost_per_character": 0.00000001875,
"input_cost_per_token_above_128k_tokens": 0.000001,
"input_cost_per_character_above_128k_tokens": 0.00000025,
"input_cost_per_image_above_128k_tokens": 0.00004,
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
"output_cost_per_token": 0.0000000046875,
"output_cost_per_character": 0.00000001875,
"output_cost_per_token_above_128k_tokens": 0.000000009375,
"output_cost_per_character_above_128k_tokens": 0.0000000375,
"output_cost_per_token": 0.0000003,
"output_cost_per_character": 0.000000075,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"output_cost_per_character_above_128k_tokens": 0.00000015,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
@ -2452,17 +2623,17 @@
"input_cost_per_image": 0.00002,
"input_cost_per_video_per_second": 0.00002,
"input_cost_per_audio_per_second": 0.000002,
"input_cost_per_token": 0.000000004688,
"input_cost_per_token": 0.000000075,
"input_cost_per_character": 0.00000001875,
"input_cost_per_token_above_128k_tokens": 0.000001,
"input_cost_per_character_above_128k_tokens": 0.00000025,
"input_cost_per_image_above_128k_tokens": 0.00004,
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
"output_cost_per_token": 0.0000000046875,
"output_cost_per_character": 0.00000001875,
"output_cost_per_token_above_128k_tokens": 0.000000009375,
"output_cost_per_character_above_128k_tokens": 0.0000000375,
"output_cost_per_token": 0.0000003,
"output_cost_per_character": 0.000000075,
"output_cost_per_token_above_128k_tokens": 0.0000006,
"output_cost_per_character_above_128k_tokens": 0.00000015,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
@ -2484,7 +2655,7 @@
"input_cost_per_image": 0.00002,
"input_cost_per_video_per_second": 0.00002,
"input_cost_per_audio_per_second": 0.000002,
"input_cost_per_token": 0.000000004688,
"input_cost_per_token": 0.000000075,
"input_cost_per_character": 0.00000001875,
"input_cost_per_token_above_128k_tokens": 0.000001,
"input_cost_per_character_above_128k_tokens": 0.00000025,
@ -2643,6 +2814,17 @@
"supports_vision": true,
"supports_assistant_prefill": true
},
"vertex_ai/claude-3-5-haiku@20241022": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true
},
"vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -2686,14 +2868,15 @@
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/meta/llama-3.2-90b-vision-instruct-maas": {
"max_tokens": 8192,
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"max_output_tokens": 2048,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"supports_system_messages": true,
"supports_vision": true,
"source": "https://console.cloud.google.com/vertex-ai/publishers/meta/model-garden/llama-3.2-90b-vision-instruct-maas"
},
"vertex_ai/mistral-large@latest": {
@ -3615,6 +3798,14 @@
"supports_function_calling": true,
"supports_vision": true
},
"openrouter/anthropic/claude-3-5-haiku": {
"max_tokens": 200000,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -3627,6 +3818,17 @@
"supports_vision": true,
"tool_use_system_prompt_tokens": 264
},
"openrouter/anthropic/claude-3-5-haiku-20241022": {
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"tool_use_system_prompt_tokens": 264
},
"anthropic/claude-3-5-sonnet-20241022": {
"max_tokens": 8192,
"max_input_tokens": 200000,
@ -3747,7 +3949,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
"supports_vision": false
},
"openrouter/openai/o1-mini-2024-09-12": {
"max_tokens": 65536,
@ -3759,7 +3961,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
"supports_vision": false
},
"openrouter/openai/o1-preview": {
"max_tokens": 32768,
@ -3771,7 +3973,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
"supports_vision": false
},
"openrouter/openai/o1-preview-2024-09-12": {
"max_tokens": 32768,
@ -3783,7 +3985,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
"supports_vision": false
},
"openrouter/openai/gpt-4o": {
"max_tokens": 4096,
@ -4330,9 +4532,9 @@
"supports_vision": true
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 4096,
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -4352,6 +4554,17 @@
"supports_function_calling": true,
"supports_vision": true
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_function_calling": true
},
"anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -4386,9 +4599,9 @@
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 4096,
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -4408,6 +4621,17 @@
"supports_function_calling": true,
"supports_vision": true
},
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_function_calling": true
},
"us.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -4442,9 +4666,9 @@
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"max_tokens": 4096,
"max_tokens": 8192,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -4464,6 +4688,16 @@
"supports_function_calling": true,
"supports_vision": true
},
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000005,
"litellm_provider": "bedrock",
"mode": "chat",
"supports_function_calling": true
},
"eu.anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
@ -5378,6 +5612,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.sd3-large-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.08,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

View file

@ -1,5 +1,5 @@
model_list:
- model_name: claude-3-5-sonnet-20240620
- model_name: "*"
litellm_params:
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
@ -10,63 +10,113 @@ model_list:
output_cost_per_token: 0.000015 # 15$/M
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
api_key: my-fake-key
- model_name: gemini-1.5-flash-002
- model_name: fake-openai-endpoint-2
litellm_params:
model: gemini/gemini-1.5-flash-002
# litellm_settings:
# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
# callbacks: ["otel", "prometheus"]
# default_redis_batch_cache_expiry: 10
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
stream_timeout: 0.001
timeout: 1
rpm: 1
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
## bedrock chat completions
- model_name: "*anthropic.claude*"
litellm_params:
model: bedrock/*anthropic.claude*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
guardrailConfig:
"guardrailIdentifier": "h4dsqwhp6j66"
"guardrailVersion": "2"
"trace": "enabled"
## bedrock embeddings
- model_name: "*amazon.titan-embed-*"
litellm_params:
model: bedrock/amazon.titan-embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "*cohere.embed-*"
litellm_params:
model: bedrock/cohere.embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "bedrock/*"
litellm_params:
model: bedrock/*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
rpm: 480
timeout: 300
stream_timeout: 60
litellm_settings:
cache: True
cache_params:
type: redis
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
callbacks: ["otel", "prometheus"]
default_redis_batch_cache_expiry: 10
# default_team_settings:
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
# success_callback: ["langfuse"]
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
# disable caching on the actual API call
supported_call_types: []
# litellm_settings:
# cache: True
# cache_params:
# type: redis
# see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
host: os.environ/REDIS_HOST
port: os.environ/REDIS_PORT
password: os.environ/REDIS_PASSWORD
# # disable caching on the actual API call
# supported_call_types: []
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
# see https://docs.litellm.ai/docs/proxy/prometheus
callbacks: ['prometheus', 'otel']
# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
# host: os.environ/REDIS_HOST
# port: os.environ/REDIS_PORT
# password: os.environ/REDIS_PASSWORD
# # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
failure_callback: ['sentry']
service_callback: ['prometheus_system']
# redact_user_api_key_info: true
# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
# # see https://docs.litellm.ai/docs/proxy/prometheus
# callbacks: ['otel']
router_settings:
routing_strategy: latency-based-routing
routing_strategy_args:
# only assign 40% of traffic to the fastest deployment to avoid overloading it
lowest_latency_buffer: 0.4
# # router_settings:
# # routing_strategy: latency-based-routing
# # routing_strategy_args:
# # # only assign 40% of traffic to the fastest deployment to avoid overloading it
# # lowest_latency_buffer: 0.4
# consider last five minutes of calls for latency calculation
ttl: 300
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
redis_password: os.environ/REDIS_PASSWORD
# # # consider last five minutes of calls for latency calculation
# # ttl: 300
# # redis_host: os.environ/REDIS_HOST
# # redis_port: os.environ/REDIS_PORT
# # redis_password: os.environ/REDIS_PASSWORD
# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
# # general_settings:
# # master_key: os.environ/LITELLM_MASTER_KEY
# # database_url: os.environ/DATABASE_URL
# # disable_master_key_return: true
# # # alerting: ['slack', 'email']
# # alerting: ['email']
# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
general_settings:
master_key: os.environ/LITELLM_MASTER_KEY
database_url: os.environ/DATABASE_URL
disable_master_key_return: true
# alerting: ['slack', 'email']
alerting: ['email']
# # # Batch write spend updates every 60s
# # proxy_batch_write_at: 60
# Batch write spend updates every 60s
proxy_batch_write_at: 60
# see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
# our api keys rarely change
user_api_key_cache_ttl: 3600
# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
# # # our api keys rarely change
# # user_api_key_cache_ttl: 3600

View file

@ -436,15 +436,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
"""
admin_jwt_scope: str = "litellm_proxy_admin"
admin_allowed_routes: List[
Literal[
"openai_routes",
"info_routes",
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
]
] = [
admin_allowed_routes: List[str] = [
"management_routes",
"spend_tracking_routes",
"global_spend_tracking_routes",
@ -1902,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
auth_error = "auth_error"
internal_server_error = "internal_server_error"
bad_request_error = "bad_request_error"
not_found_error = "not_found_error"
class SSOUserDefinedValues(TypedDict):

View file

@ -13,11 +13,13 @@ import traceback
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Literal, Optional
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import (
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
@ -30,7 +32,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from .auth_checks_organization import organization_role_based_access_check
@ -42,6 +44,10 @@ if TYPE_CHECKING:
else:
Span = Any
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = 5 # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
@ -284,7 +290,7 @@ def get_actual_routes(allowed_routes: list) -> list:
return actual_routes
@log_to_opentelemetry
@log_db_metrics
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
@ -383,7 +389,33 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
return False
@log_to_opentelemetry
def _should_check_db(
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
) -> bool:
"""
Prevent calling db repeatedly for items that don't exist in the db.
"""
current_time = time.time()
# if key doesn't exist in last_db_access_time -> check db
if key not in last_db_access_time:
return True
elif (
last_db_access_time[key][0] is not None
): # check db for non-null values (for refresh operations)
return True
elif last_db_access_time[key][0] is None:
if current_time - last_db_access_time[key][1] >= db_cache_expiry:
return True
return False
def _update_last_db_access_time(
key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
):
last_db_access_time[key] = (value, time.time())
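A self-contained sketch of the caching behaviour these two helpers implement: each key stores a `(value, timestamp)` tuple, non-null values always allow a DB refresh, and null (not-found) results suppress further DB lookups until `db_cache_expiry` seconds have passed. The names `should_check_db` and `record_access` are illustrative, not the litellm API:

import time
from collections import OrderedDict
from typing import Any, Optional

DB_CACHE_EXPIRY = 5  # seconds, mirrors db_cache_expiry above

def should_check_db(key: str, cache: OrderedDict) -> bool:
    if key not in cache:
        return True                      # never seen -> hit the DB
    value, last_checked = cache[key]
    if value is not None:
        return True                      # known row -> allow refresh
    return (time.time() - last_checked) >= DB_CACHE_EXPIRY  # negative result -> back off

def record_access(key: str, value: Optional[Any], cache: OrderedDict, max_size: int = 100) -> None:
    cache[key] = (value, time.time())
    while len(cache) > max_size:         # crude stand-in for LimitedSizeOrderedDict
        cache.popitem(last=False)

cache: OrderedDict = OrderedDict()
assert should_check_db("user_id:missing", cache) is True
record_access("user_id:missing", None, cache)              # DB said: no such user
assert should_check_db("user_id:missing", cache) is False  # skip the DB for the next 5s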
@log_db_metrics
async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
@ -412,11 +444,20 @@ async def get_user_object(
if prisma_client is None:
raise Exception("No db connected")
try:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
db_access_time_key = "user_id:{}".format(user_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
)
else:
response = None
if response is None:
if user_id_upsert:
response = await prisma_client.db.litellm_usertable.create(
@ -444,6 +485,13 @@ async def get_user_object(
# save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=response_dict,
last_db_access_time=last_db_access_time,
)
return _response
except Exception as e: # if user not in db
raise ValueError(
@ -514,7 +562,13 @@ async def _delete_cache_key_object(
)
@log_to_opentelemetry
@log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
@ -544,7 +598,7 @@ async def get_team_object(
):
cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
key=key
key=key, parent_otel_span=parent_otel_span
)
)
@ -564,9 +618,18 @@ async def get_team_object(
# else, check db
try:
response = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
db_access_time_key = "team_id:{}".format(team_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None
if response is None:
raise Exception
@ -580,6 +643,14 @@ async def get_team_object(
proxy_logging_obj=proxy_logging_obj,
)
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=_response,
last_db_access_time=last_db_access_time,
)
return _response
except Exception:
raise Exception(
@ -587,7 +658,7 @@ async def get_team_object(
)
@log_to_opentelemetry
@log_db_metrics
async def get_key_object(
hashed_token: str,
prisma_client: Optional[PrismaClient],
@ -608,16 +679,16 @@ async def get_key_object(
# check if in cache
key = hashed_token
cached_team_obj: Optional[UserAPIKeyAuth] = None
if cached_team_obj is None:
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
key=key
)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return UserAPIKeyAuth(**cached_team_obj)
elif isinstance(cached_team_obj, UserAPIKeyAuth):
return cached_team_obj
if cached_key_obj is not None:
if isinstance(cached_key_obj, dict):
return UserAPIKeyAuth(**cached_key_obj)
elif isinstance(cached_key_obj, UserAPIKeyAuth):
return cached_key_obj
if check_cache_only:
raise Exception(
@ -647,13 +718,55 @@ async def get_key_object(
)
return _response
except httpx.ConnectError as e:
return await _handle_failed_db_connection_for_get_key_object(e=e)
except Exception:
raise Exception(
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
)
@log_to_opentelemetry
async def _handle_failed_db_connection_for_get_key_object(
e: Exception,
) -> UserAPIKeyAuth:
"""
Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
Use this if you don't want failed DB queries to block LLM API requests
Returns:
- UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
Raises:
- Original Exception in all other cases
"""
from litellm.proxy.proxy_server import (
general_settings,
litellm_proxy_admin_name,
proxy_logging_obj,
)
# If this flag is on, requests failing to connect to the DB will be allowed
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
# log this as a DB failure on prometheus
proxy_logging_obj.service_logging_obj.service_failure_hook(
service=ServiceTypes.DB,
call_type="get_key_object",
error=e,
duration=0.0,
)
return UserAPIKeyAuth(
key_name="failed-to-connect-to-db",
token="failed-to-connect-to-db",
user_id=litellm_proxy_admin_name,
)
else:
# raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
raise e
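The behaviour added here reduces to: if `general_settings.allow_requests_on_db_unavailable` is true, a connection failure while reading a virtual key yields a placeholder identity instead of failing the request; otherwise the original error is re-raised. A standalone sketch of that decision (the `AuthResult` type and the `"default_user_id"` placeholder are invented for illustration):

from dataclasses import dataclass

@dataclass
class AuthResult:
    key_name: str
    token: str
    user_id: str

def handle_db_connect_error(error: Exception, allow_requests_on_db_unavailable: bool) -> AuthResult:
    # Serve the request with a placeholder identity, or surface the original error.
    if allow_requests_on_db_unavailable:
        return AuthResult(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id="default_user_id",  # stand-in for litellm_proxy_admin_name
        )
    raise error

print(handle_db_connect_error(ConnectionError("db down"), allow_requests_on_db_unavailable=True))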
@log_db_metrics
async def get_org_object(
org_id: str,
prisma_client: Optional[PrismaClient],

View file

@ -5,6 +5,9 @@ import json
import os
import traceback
from datetime import datetime
from typing import Optional
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler
@ -21,7 +24,7 @@ class LicenseCheck:
def __init__(self) -> None:
self.license_str = os.getenv("LITELLM_LICENSE", None)
verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
self.http_handler = HTTPHandler()
self.http_handler = HTTPHandler(timeout=15)
self.public_key = None
self.read_public_key()
@ -44,23 +47,46 @@ class LicenseCheck:
verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
def _verify(self, license_str: str) -> bool:
verbose_proxy_logger.debug(
"litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
self.base_url, license_str
)
)
url = "{}/verify_license/{}".format(self.base_url, license_str)
response: Optional[httpx.Response] = None
try: # don't impact user, if call fails
response = self.http_handler.get(url=url)
num_retries = 3
for i in range(num_retries):
try:
response = self.http_handler.get(url=url)
if response is None:
raise Exception("No response from license server")
response.raise_for_status()
except httpx.HTTPStatusError:
if i == num_retries - 1:
raise
response.raise_for_status()
if response is None:
raise Exception("No response from license server")
response_json = response.json()
premium = response_json["verify"]
assert isinstance(premium, bool)
verbose_proxy_logger.debug(
"litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
license_str, premium
)
)
return premium
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License via api. - {}".format(
str(e)
verbose_proxy_logger.exception(
"litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
license_str, str(e)
)
)
return False
@ -72,7 +98,7 @@ class LicenseCheck:
"""
try:
verbose_proxy_logger.debug(
"litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - {}".format(
"litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
self.license_str
)
)

View file

@ -44,14 +44,8 @@ class RouteChecks:
route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route
if route == "/key/info":
# check if user can access this route
query_params = request.query_params
key = query_params.get("key")
if key is not None and hash_token(token=key) != api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
)
# handled by function itself
pass
elif route == "/user/info":
# check if user can access this route
query_params = request.query_params

View file

@ -58,7 +58,7 @@ from litellm.proxy.auth.auth_checks import (
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
log_db_metrics,
)
from litellm.proxy.auth.auth_utils import (
_get_request_ip_address,
@ -703,12 +703,17 @@ async def user_api_key_auth( # noqa: PLR0915
)
if is_master_key_valid:
_user_api_key_obj = UserAPIKeyAuth(
api_key=master_key,
_user_api_key_obj = _return_user_api_key_auth_obj(
user_obj=None,
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id=litellm_proxy_admin_name,
api_key=master_key,
parent_otel_span=parent_otel_span,
**end_user_params,
valid_token_dict={
**end_user_params,
"user_id": litellm_proxy_admin_name,
},
route=route,
start_time=start_time,
)
await _cache_key_object(
hashed_token=hash_token(master_key),
@ -1127,11 +1132,13 @@ async def user_api_key_auth( # noqa: PLR0915
api_key = valid_token.token
# Add hashed token to cache
await _cache_key_object(
hashed_token=api_key,
user_api_key_obj=valid_token,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
asyncio.create_task(
_cache_key_object(
hashed_token=api_key,
user_api_key_obj=valid_token,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
)
valid_token_dict = valid_token.model_dump(exclude_none=True)
@ -1227,6 +1234,7 @@ def _return_user_api_key_auth_obj(
valid_token_dict: dict,
route: str,
start_time: datetime,
user_role: Optional[LitellmUserRoles] = None,
) -> UserAPIKeyAuth:
end_time = datetime.now()
user_api_key_service_logger_obj.service_success_hook(
@ -1238,7 +1246,7 @@ def _return_user_api_key_auth_obj(
parent_otel_span=parent_otel_span,
)
retrieved_user_role = (
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
)
user_api_key_kwargs = {

View file

@ -3,18 +3,25 @@ import os
from litellm._logging import verbose_proxy_logger
LITELLM_SALT_KEY = os.getenv("LITELLM_SALT_KEY", None)
verbose_proxy_logger.debug(
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
)
def _get_salt_key():
from litellm.proxy.proxy_server import master_key
salt_key = os.getenv("LITELLM_SALT_KEY", None)
if salt_key is None:
verbose_proxy_logger.debug(
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
)
salt_key = master_key
return salt_key
def encrypt_value_helper(value: str):
from litellm.proxy.proxy_server import master_key
signing_key = LITELLM_SALT_KEY
if LITELLM_SALT_KEY is None:
signing_key = master_key
signing_key = _get_salt_key()
try:
if isinstance(value, str):
@ -35,9 +42,7 @@ def encrypt_value_helper(value: str):
def decrypt_value_helper(value: str):
from litellm.proxy.proxy_server import master_key
signing_key = LITELLM_SALT_KEY
if LITELLM_SALT_KEY is None:
signing_key = master_key
signing_key = _get_salt_key()
try:
if isinstance(value, str):

View file

@ -0,0 +1,138 @@
"""
Handles logging DB success/failure to ServiceLogger()
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
"""
from datetime import datetime
from functools import wraps
from typing import Callable, Dict, Tuple
from litellm._service_logger import ServiceTypes
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
def log_db_metrics(func):
"""
Decorator to log the duration of a DB related function to ServiceLogger()
Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure
Args:
func: The function to be decorated
Returns:
Result from the decorated function
Raises:
Exception: If the decorated function raises an exception
"""
@wraps(func)
async def wrapper(*args, **kwargs):
from prisma.errors import PrismaError
start_time: datetime = datetime.now()
try:
result = await func(*args, **kwargs)
end_time: datetime = datetime.now()
from litellm.proxy.proxy_server import proxy_logging_obj
if "PROXY" not in func.__name__:
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span", None),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
elif (
# in litellm custom callbacks kwargs is passed as arg[0]
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
args is not None
and len(args) > 0
and isinstance(args[0], dict)
):
passed_kwargs = args[0]
parent_otel_span = _get_parent_otel_span_from_kwargs(
kwargs=passed_kwargs
)
if parent_otel_span is not None:
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.BATCH_WRITE_TO_DB,
call_type=func.__name__,
parent_otel_span=parent_otel_span,
duration=0.0,
start_time=start_time,
end_time=end_time,
event_metadata=metadata,
)
# end of logging to otel
return result
except Exception as e:
end_time: datetime = datetime.now()
await _handle_logging_db_exception(
e=e,
func=func,
kwargs=kwargs,
args=args,
start_time=start_time,
end_time=end_time,
)
raise e
return wrapper
def _is_exception_related_to_db(e: Exception) -> bool:
"""
Returns True if the exception is related to the DB
"""
import httpx
from prisma.errors import PrismaError
return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
async def _handle_logging_db_exception(
e: Exception,
func: Callable,
kwargs: Dict,
args: Tuple,
start_time: datetime,
end_time: datetime,
) -> None:
from litellm.proxy.proxy_server import proxy_logging_obj
# don't log this as a DB Service Failure, if the DB did not raise an exception
if _is_exception_related_to_db(e) is not True:
return
await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
error=e,
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span"),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
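For context, a stripped-down analogue of the decorator this new module defines: wrap an async DB call, measure its duration, report successes to a logging hook, and only report failures that look DB-related. The `report` callback is a toy stand-in for ServiceLogger(), not the real litellm interface:

import asyncio
import time
from functools import wraps

def timed_db_call(report):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            start = time.monotonic()
            try:
                result = await func(*args, **kwargs)
                report(call_type=func.__name__, duration=time.monotonic() - start, success=True)
                return result
            except (ConnectionError, TimeoutError):  # stand-ins for PrismaError / httpx errors
                report(call_type=func.__name__, duration=time.monotonic() - start, success=False)
                raise
        return wrapper
    return decorator

def report(**kwargs):  # toy hook; ServiceLogger() would export to Prometheus/OTEL/Datadog
    print(kwargs)

@timed_db_call(report)
async def get_user_row(user_id: str):
    await asyncio.sleep(0.01)  # pretend DB latency
    return {"user_id": user_id}

asyncio.run(get_user_row("u1"))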

View file

@ -1,4 +1,5 @@
import copy
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from fastapi import Request
@ -6,6 +7,7 @@ from starlette.datastructures import Headers
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.proxy._types import (
AddTeamCallback,
CommonProxyErrors,
@ -16,11 +18,15 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import get_request_route
from litellm.types.services import ServiceTypes
from litellm.types.utils import (
StandardLoggingUserAPIKeyMetadata,
SupportedCacheControls,
)
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
if TYPE_CHECKING:
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
@ -471,7 +477,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
### END-USER SPECIFIC PARAMS ###
if user_api_key_dict.allowed_model_region is not None:
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
start_time = time.time()
## [Enterprise Only]
# Add User-IP Address
requester_ip_address = ""
@ -539,6 +545,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
verbose_proxy_logger.debug(
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
)
end_time = time.time()
await service_logger_obj.async_service_success_hook(
service=ServiceTypes.PROXY_PRE_CALL,
duration=end_time - start_time,
call_type="add_litellm_data_to_request",
start_time=start_time,
end_time=end_time,
parent_otel_span=user_api_key_dict.parent_otel_span,
)
return data

View file

@ -32,7 +32,7 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret
router = APIRouter()
@ -734,13 +734,37 @@ async def info_key_fn(
raise Exception(
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
if key is None:
key = user_api_key_dict.api_key
key_info = await prisma_client.get_data(token=key)
# default to using Auth token if no key is passed in
key = key or user_api_key_dict.api_key
hashed_key: Optional[str] = key
if key is not None:
hashed_key = _hash_token_if_needed(token=key)
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_key}, # type: ignore
include={"litellm_budget_table": True},
)
if key_info is None:
raise ProxyException(
message="Key not found in database",
type=ProxyErrorTypes.not_found_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
if (
_can_user_query_key_info(
user_api_key_dict=user_api_key_dict,
key=key,
key_info=key_info,
)
is not True
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
status_code=status.HTTP_403_FORBIDDEN,
detail="You are not allowed to access this key's info. Your role={}".format(
user_api_key_dict.user_role
),
)
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try:
@ -1540,6 +1564,27 @@ async def key_health(
)
def _can_user_query_key_info(
user_api_key_dict: UserAPIKeyAuth,
key: Optional[str],
key_info: LiteLLM_VerificationToken,
) -> bool:
"""
Helper to check if the user has access to the key's info
"""
if (
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
return True
elif user_api_key_dict.api_key == key:
return True
# user can query their own key info
elif key_info.user_id == user_api_key_dict.user_id:
return True
return False
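The rule above boils down to: proxy admins (and admin viewers) may inspect any key; everyone else may only inspect a key that is literally their own or that belongs to their user_id. A tiny illustrative check; the `Caller`/`KeyInfo` shapes and the role strings are stand-ins, not litellm types:

from dataclasses import dataclass
from typing import Optional

ADMIN_ROLES = {"proxy_admin", "proxy_admin_viewer"}  # stand-ins for LitellmUserRoles values

@dataclass
class Caller:
    api_key: str
    user_id: str
    user_role: str

@dataclass
class KeyInfo:
    user_id: Optional[str]

def can_query_key_info(caller: Caller, requested_key: Optional[str], key_info: KeyInfo) -> bool:
    if caller.user_role in ADMIN_ROLES:
        return True
    if requested_key is not None and caller.api_key == requested_key:
        return True                                # querying your own key
    return key_info.user_id == caller.user_id      # key belongs to the same user

caller = Caller(api_key="sk-abc", user_id="u1", user_role="internal_user")
assert can_query_key_info(caller, "sk-abc", KeyInfo(user_id="u2")) is True
assert can_query_key_info(caller, "sk-other", KeyInfo(user_id="u2")) is False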
async def test_key_logging(
user_api_key_dict: UserAPIKeyAuth,
request: Request,
@ -1599,7 +1644,9 @@ async def test_key_logging(
details=f"Logging test failed: {str(e)}",
)
await asyncio.sleep(1) # wait for callbacks to run
await asyncio.sleep(
2
) # wait for callbacks to run, callbacks use batching so wait for the flush event
# Check if any logger exceptions were triggered
log_contents = log_capture_string.getvalue()

View file

@ -1281,12 +1281,20 @@ async def list_team(
where={"team_id": team.team_id}
)
returned_responses.append(
TeamListResponseObject(
**team.model_dump(),
team_memberships=_team_memberships,
keys=keys,
try:
returned_responses.append(
TeamListResponseObject(
**team.model_dump(),
team_memberships=_team_memberships,
keys=keys,
)
)
)
except Exception as e:
team_exception = """Invalid team object for team_id: {}. team_object={}.
Error: {}
""".format(
team.team_id, team.model_dump(), str(e)
)
raise HTTPException(status_code=400, detail={"error": team_exception})
return returned_responses

View file

@ -694,6 +694,9 @@ def run_server( # noqa: PLR0915
import litellm
if detailed_debug is True:
litellm._turn_on_debug()
# DO NOT DELETE - enables global variables to work across files
from litellm.proxy.proxy_server import app # noqa

View file

@ -1,11 +1,12 @@
model_list:
- model_name: gpt-4o
- model_name: fake-openai-endpoint
litellm_params:
model: openai/gpt-5
model: openai/fake
api_key: os.environ/OPENAI_API_KEY
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
alerting: ["slack"]
alerting_threshold: 0.001
litellm_settings:
callbacks: ["gcs_bucket"]

View file

@ -125,7 +125,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@ -1198,7 +1198,7 @@ async def update_cache( # noqa: PLR0915
await _update_team_cache()
asyncio.create_task(
user_api_key_cache.async_batch_set_cache(
user_api_key_cache.async_set_cache_pipeline(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=parent_otel_span,
@ -1257,7 +1257,7 @@ class ProxyConfig:
"""
def __init__(self) -> None:
pass
self.config: Dict[str, Any] = {}
def is_yaml(self, config_file_path: str) -> bool:
if not os.path.isfile(config_file_path):
@ -1271,9 +1271,6 @@ class ProxyConfig:
) -> dict:
"""
Given a config file path, load the config from the file.
If `store_model_in_db` is True, then read the DB and update the config with the DB values.
Args:
config_file_path (str): path to the config file
Returns:
@ -1299,40 +1296,6 @@ class ProxyConfig:
"litellm_settings": {},
}
## DB
if prisma_client is not None and (
general_settings.get("store_model_in_db", False) is True
or store_model_in_db is True
):
_tasks = []
keys = [
"general_settings",
"router_settings",
"litellm_settings",
"environment_variables",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
for response in responses:
if response is not None:
param_name = getattr(response, "param_name", None)
param_value = getattr(response, "param_value", None)
if param_name is not None and param_value is not None:
# check if param_name is already in the config
if param_name in config:
if isinstance(config[param_name], dict):
config[param_name].update(param_value)
else:
config[param_name] = param_value
else:
# if it's not in the config - then add it
config[param_name] = param_value
return config
async def save_config(self, new_config: dict):
@ -1398,8 +1361,10 @@ class ProxyConfig:
- for a given team id
- return the relevant completion() call params
"""
# load existing config
config = await self.get_config()
config = self.config
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", {})
all_teams_config = litellm_settings.get("default_team_settings", None)
@ -1451,7 +1416,9 @@ class ProxyConfig:
dict: config
"""
global prisma_client, store_model_in_db
# Load existing config
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
@ -1473,12 +1440,21 @@ class ProxyConfig:
else:
# default to file
config = await self._get_config_from_file(config_file_path=config_file_path)
## UPDATE CONFIG WITH DB
if prisma_client is not None:
config = await self._update_config_from_db(
config=config,
prisma_client=prisma_client,
store_model_in_db=store_model_in_db,
)
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
config = self._check_for_os_environ_vars(config=config)
self.config = config
return config
async def load_config( # noqa: PLR0915
@ -2290,6 +2266,55 @@ class ProxyConfig:
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
async def _update_config_from_db(
self,
prisma_client: PrismaClient,
config: dict,
store_model_in_db: Optional[bool],
):
if store_model_in_db is not True:
verbose_proxy_logger.info(
"'store_model_in_db' is not True, skipping db updates"
)
return config
_tasks = []
keys = [
"general_settings",
"router_settings",
"litellm_settings",
"environment_variables",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
for response in responses:
if response is not None:
param_name = getattr(response, "param_name", None)
verbose_proxy_logger.info(f"loading {param_name} settings from db")
if param_name == "litellm_settings":
verbose_proxy_logger.info(
f"litellm_settings: {response.param_value}"
)
param_value = getattr(response, "param_value", None)
if param_name is not None and param_value is not None:
# check if param_name is already in the config
if param_name in config:
if isinstance(config[param_name], dict):
config[param_name].update(param_value)
else:
config[param_name] = param_value
else:
# if it's not in the config - then add it
config[param_name] = param_value
return config
async def add_deployment(
self,
prisma_client: PrismaClient,
@ -2843,7 +2868,7 @@ class ProxyStartupEvent:
if (
proxy_logging_obj is not None
and proxy_logging_obj.slack_alerting_instance is not None
and proxy_logging_obj.slack_alerting_instance.alerting is not None
and prisma_client is not None
):
print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa
@ -2891,7 +2916,7 @@ class ProxyStartupEvent:
scheduler.start()
@classmethod
def _setup_prisma_client(
async def _setup_prisma_client(
cls,
database_url: Optional[str],
proxy_logging_obj: ProxyLogging,
@ -2910,11 +2935,15 @@ class ProxyStartupEvent:
except Exception as e:
raise e
await prisma_client.connect()
## Add necessary views to proxy ##
asyncio.create_task(
prisma_client.check_view_exists()
) # check if all necessary views exist. Don't block execution
# run a health check to ensure the DB is ready
await prisma_client.health_check()
return prisma_client
@ -2931,12 +2960,21 @@ async def startup_event():
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = ProxyStartupEvent._setup_prisma_client(
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
premium_user
)
)
if premium_user is False:
premium_user = _license_check.is_premium()
### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
@ -2984,21 +3022,6 @@ async def startup_event():
if isinstance(worker_config, dict):
await initialize(**worker_config)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
premium_user
)
)
if premium_user is False:
premium_user = _license_check.is_premium()
verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
premium_user
)
)
ProxyStartupEvent._initialize_startup_logging(
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
@ -3021,9 +3044,6 @@ async def startup_event():
prompt_injection_detection_obj.update_environment(router=llm_router)
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
if prisma_client is not None:
await prisma_client.connect()
if prisma_client is not None and master_key is not None:
ProxyStartupEvent._add_master_key_hash_to_db(
master_key=master_key,
@ -8723,7 +8743,7 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
if k == "alert_to_webhook_url":
# check if slack is already enabled. if not, enable it
if "alerting" not in _existing_settings:
_existing_settings["alerting"].append("slack")
_existing_settings = {"alerting": ["slack"]}
elif isinstance(_existing_settings["alerting"], list):
if "slack" not in _existing_settings["alerting"]:
_existing_settings["alerting"].append("slack")

View file

@ -65,6 +65,7 @@ async def route_request(
Common helper to route the request
"""
router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data or "api_base" in data:
return getattr(litellm, f"{route_type}")(**data)

View file

@ -55,10 +55,6 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
@ -77,6 +73,7 @@ from litellm.proxy.db.create_views import (
create_missing_views,
should_create_missing_views,
)
from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
@ -137,83 +134,6 @@ def safe_deep_copy(data):
return new_data
def log_to_opentelemetry(func):
"""
Decorator to log the duration of a DB related function to ServiceLogger()
Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
"""
@wraps(func)
async def wrapper(*args, **kwargs):
start_time: datetime = datetime.now()
try:
result = await func(*args, **kwargs)
end_time: datetime = datetime.now()
from litellm.proxy.proxy_server import proxy_logging_obj
if "PROXY" not in func.__name__:
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span", None),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
elif (
# in litellm custom callbacks kwargs is passed as arg[0]
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
args is not None
and len(args) > 0
and isinstance(args[0], dict)
):
passed_kwargs = args[0]
parent_otel_span = _get_parent_otel_span_from_kwargs(
kwargs=passed_kwargs
)
if parent_otel_span is not None:
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.BATCH_WRITE_TO_DB,
call_type=func.__name__,
parent_otel_span=parent_otel_span,
duration=0.0,
start_time=start_time,
end_time=end_time,
event_metadata=metadata,
)
# end of logging to otel
return result
except Exception as e:
from litellm.proxy.proxy_server import proxy_logging_obj
end_time: datetime = datetime.now()
await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
error=e,
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span"),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
raise e
return wrapper
class InternalUsageCache:
def __init__(self, dual_cache: DualCache):
self.dual_cache: DualCache = dual_cache
@ -255,7 +175,7 @@ class InternalUsageCache:
local_only: bool = False,
**kwargs,
) -> None:
return await self.dual_cache.async_batch_set_cache(
return await self.dual_cache.async_set_cache_pipeline(
cache_list=cache_list,
local_only=local_only,
litellm_parent_otel_span=litellm_parent_otel_span,
@ -1083,19 +1003,16 @@ class PrismaClient:
proxy_logging_obj: ProxyLogging,
http_client: Optional[Any] = None,
):
verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
self.iam_token_db_auth: Optional[bool] = str_to_bool(
os.getenv("IAM_TOKEN_DB_AUTH")
)
verbose_proxy_logger.debug("Creating Prisma Client..")
try:
from prisma import Prisma # type: ignore
except Exception:
raise Exception("Unable to find Prisma binaries.")
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
if http_client is not None:
self.db = PrismaWrapper(
original_prisma=Prisma(http=http_client),
@ -1114,7 +1031,7 @@ class PrismaClient:
else False
),
) # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
verbose_proxy_logger.debug("Success - Created Prisma Client")
def hash_token(self, token: str):
# Hash the string using SHA-256
@ -1400,6 +1317,7 @@ class PrismaClient:
return
@log_db_metrics
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
@ -1465,7 +1383,7 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
@log_to_opentelemetry
@log_db_metrics
async def get_data( # noqa: PLR0915
self,
token: Optional[Union[str, list]] = None,
@ -1506,9 +1424,7 @@ class PrismaClient:
# check if plain text or hash
if token is not None:
if isinstance(token, str):
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
hashed_token = _hash_token_if_needed(token=token)
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
@ -1575,8 +1491,7 @@ class PrismaClient:
if token is not None:
where_filter["token"] = {}
if isinstance(token, str):
if token.startswith("sk-"):
token = self.hash_token(token=token)
token = _hash_token_if_needed(token=token)
where_filter["token"]["in"] = [token]
elif isinstance(token, list):
hashed_tokens = []
@ -1712,9 +1627,7 @@ class PrismaClient:
# check if plain text or hash
if token is not None:
if isinstance(token, str):
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
hashed_token = _hash_token_if_needed(token=token)
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
@ -1994,8 +1907,7 @@ class PrismaClient:
if token is not None:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
token = _hash_token_if_needed(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
@ -2347,11 +2259,7 @@ class PrismaClient:
"""
start_time = time.time()
try:
sql_query = """
SELECT 1
FROM "LiteLLM_VerificationToken"
LIMIT 1
"""
sql_query = "SELECT 1"
# Execute the raw query
# The asterisk before `user_id_list` unpacks the list into separate arguments
@ -2510,6 +2418,18 @@ def hash_token(token: str):
return hashed_token
def _hash_token_if_needed(token: str) -> str:
"""
Hash the token if it's a string and starts with "sk-"
Else return the token as is
"""
if token.startswith("sk-"):
return hash_token(token=token)
else:
return token
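A self-contained sketch of the idea behind `_hash_token_if_needed`, assuming `hash_token` is a SHA-256 hex digest of the raw key (litellm's exact hashing may differ):

import hashlib

def hash_token(token: str) -> str:
    # assumption: plain SHA-256 hex digest of the raw key string
    return hashlib.sha256(token.encode("utf-8")).hexdigest()

def hash_token_if_needed(token: str) -> str:
    # raw virtual keys start with "sk-"; anything else is assumed to be hashed already
    return hash_token(token) if token.startswith("sk-") else token

print(hash_token_if_needed("sk-1234"))   # hashed before the DB lookup
print(hash_token_if_needed("abc123"))    # already-hashed tokens pass through unchanged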
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)

View file

@ -339,11 +339,7 @@ class Router:
cache_config: Dict[str, Any] = {}
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None
and redis_port is not None
and redis_password is not None
):
if redis_url is not None or (redis_host is not None and redis_port is not None):
cache_type = "redis"
if redis_url is not None:
@ -556,6 +552,10 @@ class Router:
self.initialize_assistants_endpoint()
self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
def initialize_assistants_endpoint(self):
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
@ -585,6 +585,7 @@ class Router:
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):
verbose_router_logger.info(f"Routing strategy: {routing_strategy}")
if (
routing_strategy == RoutingStrategy.LEAST_BUSY.value
or routing_strategy == RoutingStrategy.LEAST_BUSY
@ -912,6 +913,7 @@ class Router:
logging_obj=logging_obj,
parent_otel_span=parent_otel_span,
)
response = await _response
## CHECK CONTENT FILTER ERROR ##
@ -1681,78 +1683,6 @@ class Router:
)
raise e
async def amoderation(self, model: str, input: str, **kwargs):
try:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._amoderation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
asyncio.create_task(
send_llm_exception_alert(
litellm_router_instance=self,
request_kwargs=kwargs,
error_traceback_str=traceback.format_exc(),
original_exception=e,
)
)
raise e
async def _amoderation(self, model: str, input: str, **kwargs):
model_name = None
try:
verbose_router_logger.debug(
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
model_name = data["model"]
model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1
timeout: Optional[Union[float, int]] = self._get_timeout(
kwargs=kwargs,
data=data,
)
response = await litellm.amoderation(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
"timeout": timeout,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def arerank(self, model: str, **kwargs):
try:
kwargs["model"] = model
@ -2608,20 +2538,46 @@ class Router:
return final_results
#### ASSISTANTS API ####
#### PASSTHROUGH API ####
def factory_function(self, original_function: Callable):
async def _pass_through_moderation_endpoint_factory(
self,
original_function: Callable,
**kwargs,
):
if (
"model" in kwargs
and self.get_model_list(model_name=kwargs["model"]) is not None
):
deployment = await self.async_get_available_deployment(
model=kwargs["model"]
)
kwargs["model"] = deployment["litellm_params"]["model"]
return await original_function(**kwargs)
def factory_function(
self,
original_function: Callable,
call_type: Literal["assistants", "moderation"] = "assistants",
):
async def new_function(
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional["AsyncOpenAI"] = None,
**kwargs,
):
return await self._pass_through_assistants_endpoint_factory(
original_function=original_function,
custom_llm_provider=custom_llm_provider,
client=client,
**kwargs,
)
if call_type == "assistants":
return await self._pass_through_assistants_endpoint_factory(
original_function=original_function,
custom_llm_provider=custom_llm_provider,
client=client,
**kwargs,
)
elif call_type == "moderation":
return await self._pass_through_moderation_endpoint_factory( # type: ignore
original_function=original_function,
**kwargs,
)
return new_function
@ -2961,14 +2917,14 @@ class Router:
raise
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
retry_after = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
# sleeps for the length of the timeout
await asyncio.sleep(_timeout)
await asyncio.sleep(retry_after)
for current_attempt in range(num_retries):
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
@ -3598,6 +3554,15 @@ class Router:
# Catch all - if any exceptions default to cooling down
return True
def _has_default_fallbacks(self) -> bool:
if self.fallbacks is None:
return False
for fallback in self.fallbacks:
if isinstance(fallback, dict):
if "*" in fallback:
return True
return False
def _should_raise_content_policy_error(
self, model: str, response: ModelResponse, kwargs: dict
) -> bool:
@ -3614,6 +3579,7 @@ class Router:
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
if content_policy_fallbacks is not None:
fallback_model_group = None
@ -3624,6 +3590,8 @@ class Router:
if fallback_model_group is not None:
return True
elif self._has_default_fallbacks(): # default fallbacks set
return True
verbose_router_logger.info(
"Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
@ -4178,7 +4146,9 @@ class Router:
model = _model
## GET LITELLM MODEL INFO - raises exception, if model is not mapped
model_info = litellm.get_model_info(model=model)
model_info = litellm.get_model_info(
model="{}/{}".format(custom_llm_provider, model)
)
## CHECK USER SET MODEL INFO
user_model_info = deployment.get("model_info", {})
@ -4849,7 +4819,7 @@ class Router:
)
continue
except Exception as e:
verbose_router_logger.error("An error occurs - {}".format(str(e)))
verbose_router_logger.exception("An error occurs - {}".format(str(e)))
_litellm_params = deployment.get("litellm_params", {})
model_id = deployment.get("model_info", {}).get("id", "")
@ -5048,10 +5018,12 @@ class Router:
)
if len(healthy_deployments) == 0:
raise ValueError(
"{}. You passed in model={}. There is no 'model_name' with this string ".format(
RouterErrors.no_deployments_available.value, model
)
raise litellm.BadRequestError(
message="You passed in model={}. There is no 'model_name' with this string ".format(
model
),
model=model,
llm_provider="",
)
if litellm.model_alias_map and model in litellm.model_alias_map:

View file

@ -180,7 +180,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
@ -195,7 +194,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
deployment_rpm,
local_result,
),
headers={"retry-after": 60}, # type: ignore
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
@ -221,7 +220,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
deployment_rpm,
result,
),
headers={"retry-after": 60}, # type: ignore
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)

View file

@ -61,6 +61,24 @@ class PatternMatchRouter:
# return f"^{regex}$"
return re.escape(pattern).replace(r"\*", "(.*)")
def _return_pattern_matched_deployments(
self, matched_pattern: Match, deployments: List[Dict]
) -> List[Dict]:
new_deployments = []
for deployment in deployments:
new_deployment = copy.deepcopy(deployment)
new_deployment["litellm_params"]["model"] = (
PatternMatchRouter.set_deployment_model_name(
matched_pattern=matched_pattern,
litellm_deployment_litellm_model=deployment["litellm_params"][
"model"
],
)
)
new_deployments.append(new_deployment)
return new_deployments
def route(self, request: Optional[str]) -> Optional[List[Dict]]:
"""
Route a requested model to the corresponding llm deployments based on the regex pattern
@ -79,8 +97,11 @@ class PatternMatchRouter:
if request is None:
return None
for pattern, llm_deployments in self.patterns.items():
if re.match(pattern, request):
return llm_deployments
pattern_match = re.match(pattern, request)
if pattern_match:
return self._return_pattern_matched_deployments(
matched_pattern=pattern_match, deployments=llm_deployments
)
except Exception as e:
verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
@ -96,12 +117,28 @@ class PatternMatchRouter:
E.g.:
Case 1:
model_name: llmengine/* (can be any regex pattern or wildcard pattern)
litellm_params:
model: openai/*
if model_name = "llmengine/foo" -> model = "openai/foo"
Case 2:
model_name: llmengine/fo::*::static::*
litellm_params:
model: openai/fo::*::static::*
if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
Case 3:
model_name: *meta.llama3*
litellm_params:
model: bedrock/meta.llama3*
if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
"""
## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
if "*" not in litellm_deployment_litellm_model:
return litellm_deployment_litellm_model
@ -112,10 +149,9 @@ class PatternMatchRouter:
dynamic_segments = matched_pattern.groups()
if len(dynamic_segments) > wildcard_count:
raise ValueError(
f"More wildcards in the deployment model name than the pattern. Wildcard count: {wildcard_count}, dynamic segments count: {len(dynamic_segments)}"
)
return (
matched_pattern.string
) # default to the user input, if unable to map based on wildcards.
# Replace the corresponding wildcards in the litellm model pattern with extracted segments
for segment in dynamic_segments:
litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
@ -165,12 +201,7 @@ class PatternMatchRouter:
"""
pattern_match = self.get_pattern(model, custom_llm_provider)
if pattern_match:
provider_deployments = []
for deployment in pattern_match:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return provider_deployments
return pattern_match
return []
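The new `_return_pattern_matched_deployments` / `set_deployment_model_name` path maps a wildcard `model_name` onto the deployment's own wildcard `model` by reusing the regex capture groups. A condensed, standalone sketch of that mapping, using the same escape-and-replace regex construction shown above; the helper names here are invented for illustration:

import re
from typing import Optional

def pattern_to_regex(pattern: str) -> str:
    # "llmengine/*" -> r"llmengine/(.*)"
    return re.escape(pattern).replace(r"\*", "(.*)")

def map_wildcard_model(model_name_pattern: str, deployment_model: str, request: str) -> Optional[str]:
    match = re.match(pattern_to_regex(model_name_pattern), request)
    if match is None:
        return None
    if "*" not in deployment_model:
        return deployment_model            # nothing to substitute
    result = deployment_model
    for segment in match.groups():         # fill deployment wildcards with captured segments
        result = result.replace("*", segment, 1)
    return result

# Single wildcard: llmengine/* -> openai/*
assert map_wildcard_model("llmengine/*", "openai/*", "llmengine/foo") == "openai/foo"
# Two wildcards, matched segment by segment
assert map_wildcard_model("llmengine/fo::*::static::*", "openai/fo::*::static::*",
                          "llmengine/fo::bar::static::baz") == "openai/fo::bar::static::baz"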

Some files were not shown because too many files have changed in this diff.