Merge branch 'main' into custom_validation_docs

Tyler Wagner 2025-03-20 15:49:36 -04:00 committed by GitHub
commit caacd7426c
369 changed files with 19043 additions and 5075 deletions


@ -49,7 +49,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install openai==1.66.1
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@ -71,7 +71,7 @@ jobs:
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
pip install "pytest-xdist==3.6.1"
pip install "websockets==10.4"
pip install "websockets==13.1.0"
pip uninstall posthog -y
- save_cache:
paths:
@ -168,7 +168,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install openai==1.66.1
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@ -189,6 +189,7 @@ jobs:
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
pip install "websockets==13.1.0"
- save_cache:
paths:
- ./venv
@ -267,7 +268,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install openai==1.66.1
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@ -288,6 +289,7 @@ jobs:
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "jsonschema==4.22.0"
pip install "websockets==13.1.0"
- save_cache:
paths:
- ./venv
@ -511,7 +513,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.54.0
pip install openai==1.66.1
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@ -678,6 +680,48 @@ jobs:
paths:
- llm_translation_coverage.xml
- llm_translation_coverage
llm_responses_api_testing:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-cov==5.0.0"
pip install "pytest-asyncio==0.21.1"
pip install "respx==0.21.1"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/llm_responses_api_testing --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml llm_responses_api_coverage.xml
mv .coverage llm_responses_api_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- llm_responses_api_coverage.xml
- llm_responses_api_coverage
litellm_mapped_tests:
docker:
- image: cimg/python:3.11
@ -1234,7 +1278,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.54.0 "
pip install "openai==1.66.1"
- run:
name: Install Grype
command: |
@ -1309,13 +1353,13 @@ jobs:
command: |
pwd
ls
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
no_output_timeout: 120m
# Store test results
- store_test_results:
path: test-results
e2e_openai_misc_endpoints:
e2e_openai_endpoints:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
@ -1370,7 +1414,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.54.0 "
pip install "openai==1.66.1"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image
@ -1432,7 +1476,7 @@ jobs:
command: |
pwd
ls
python -m pytest -s -vv tests/openai_misc_endpoints_tests --junitxml=test-results/junit.xml --durations=5
python -m pytest -s -vv tests/openai_endpoints_tests --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
# Store test results
@ -1492,7 +1536,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.54.0 "
pip install "openai==1.66.1"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@ -1921,7 +1965,7 @@ jobs:
pip install "pytest-asyncio==0.21.1"
pip install "google-cloud-aiplatform==1.43.0"
pip install aiohttp
pip install "openai==1.54.0 "
pip install "openai==1.66.1"
pip install "assemblyai==0.37.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
@ -2068,7 +2112,7 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage litellm_proxy_security_tests_coverage
coverage combine llm_translation_coverage llm_responses_api_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage litellm_proxy_security_tests_coverage
coverage xml
- codecov/upload:
file: ./coverage.xml
@ -2197,7 +2241,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install "openai==1.54.0 "
pip install "openai==1.66.1"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"
@ -2387,7 +2431,7 @@ workflows:
only:
- main
- /litellm_.*/
- e2e_openai_misc_endpoints:
- e2e_openai_endpoints:
filters:
branches:
only:
@ -2429,6 +2473,12 @@ workflows:
only:
- main
- /litellm_.*/
- llm_responses_api_testing:
filters:
branches:
only:
- main
- /litellm_.*/
- litellm_mapped_tests:
filters:
branches:
@ -2468,6 +2518,7 @@ workflows:
- upload-coverage:
requires:
- llm_translation_testing
- llm_responses_api_testing
- litellm_mapped_tests
- batches_testing
- litellm_utils_testing
@ -2522,10 +2573,11 @@ workflows:
requires:
- local_testing
- build_and_test
- e2e_openai_misc_endpoints
- e2e_openai_endpoints
- load_testing
- test_bad_database_url
- llm_translation_testing
- llm_responses_api_testing
- litellm_mapped_tests
- batches_testing
- litellm_utils_testing


@ -1,5 +1,5 @@
# used by CI/CD testing
openai==1.54.0
openai==1.66.1
python-dotenv
tiktoken
importlib_metadata


@ -80,7 +80,6 @@ jobs:
permissions:
contents: read
packages: write
#
steps:
- name: Checkout repository
uses: actions/checkout@v4
@ -112,7 +111,11 @@ jobs:
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
tags: |
${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }}
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm:main-stable', env.REGISTRY) || '' }}
labels: ${{ steps.meta.outputs.labels }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -151,8 +154,12 @@ jobs:
context: .
file: ./docker/Dockerfile.database
push: true
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
labels: ${{ steps.meta-database.outputs.labels }}
tags: |
${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-database:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-database:main-stable', env.REGISTRY) || '' }}
labels: ${{ steps.meta-database.outputs.labels }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
build-and-push-image-non_root:
@ -190,7 +197,11 @@ jobs:
context: .
file: ./docker/Dockerfile.non_root
push: true
tags: ${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.release_type }}
tags: |
${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.release_type }}
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-non_root:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-non_root:main-stable', env.REGISTRY) || '' }}
labels: ${{ steps.meta-non_root.outputs.labels }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -229,7 +240,11 @@ jobs:
context: .
file: ./litellm-js/spend-logs/Dockerfile
push: true
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
tags: |
${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-spend_logs:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-spend_logs:main-stable', env.REGISTRY) || '' }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
build-and-push-helm-chart:

.gitignore

@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html
litellm/proxy/application.log
tests/llm_translation/vertex_test_account.json
tests/llm_translation/test_vertex_key.json
litellm/proxy/migrations/0_init/migration.sql
litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*


@ -1,7 +1,7 @@
# LiteLLM Makefile
# Simple Makefile for running tests and basic development tasks
.PHONY: help test test-unit test-integration
.PHONY: help test test-unit test-integration lint format
# Default target
help:
@ -9,6 +9,14 @@ help:
@echo " make test - Run all tests"
@echo " make test-unit - Run unit tests"
@echo " make test-integration - Run integration tests"
@echo " make test-unit-helm - Run helm unit tests"
install-dev:
poetry install --with dev
lint: install-dev
poetry run pip install types-requests types-setuptools types-redis types-PyYAML
cd litellm && poetry run mypy . --ignore-missing-imports
# Testing
test:
@ -18,4 +26,7 @@ test-unit:
poetry run pytest tests/litellm/
test-integration:
poetry run pytest tests/ -k "not litellm"
poetry run pytest tests/ -k "not litellm"
test-unit-helm:
helm unittest -f 'tests/*.yaml' deploy/charts/litellm-helm


@ -18,7 +18,7 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.4.1
version: 0.4.2
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to


@ -22,6 +22,8 @@ If `db.useStackgresOperator` is used (not yet implemented):
| Name | Description | Value |
| ---------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
| `replicaCount` | The number of LiteLLM Proxy pods to be deployed | `1` |
| `masterkeySecretName` | The name of the Kubernetes Secret that contains the Master API Key for LiteLLM. If not specified, the generated Secret name is used. | N/A |
| `masterkeySecretKey` | The key within the Kubernetes Secret that contains the Master API Key for LiteLLM. If not specified, `masterkey` is used as the key. | N/A |
| `masterkey` | The Master API Key for LiteLLM. If not specified, a random key is generated. | N/A |
| `environmentSecrets` | An optional array of Secret object names. The keys and values in these secrets will be presented to the LiteLLM proxy pod as environment variables. See below for an example Secret object. | `[]` |
| `environmentConfigMaps` | An optional array of ConfigMap object names. The keys and values in these configmaps will be presented to the LiteLLM proxy pod as environment variables. See below for an example Secret object. | `[]` |


@ -78,8 +78,8 @@ spec:
- name: PROXY_MASTER_KEY
valueFrom:
secretKeyRef:
name: {{ include "litellm.fullname" . }}-masterkey
key: masterkey
name: {{ .Values.masterkeySecretName | default (printf "%s-masterkey" (include "litellm.fullname" .)) }}
key: {{ .Values.masterkeySecretKey | default "masterkey" }}
{{- if .Values.redis.enabled }}
- name: REDIS_HOST
value: {{ include "litellm.redis.serviceName" . }}


@ -1,3 +1,4 @@
{{- if not .Values.masterkeySecretName }}
{{ $masterkey := (.Values.masterkey | default (randAlphaNum 17)) }}
apiVersion: v1
kind: Secret
@ -5,4 +6,5 @@ metadata:
name: {{ include "litellm.fullname" . }}-masterkey
data:
masterkey: {{ $masterkey | b64enc }}
type: Opaque
type: Opaque
{{- end }}


@ -52,3 +52,31 @@ tests:
- equal:
path: spec.template.spec.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms[0].matchExpressions[0].values[0]
value: antarctica-east1
- it: should work without masterkeySecretName or masterkeySecretKey
template: deployment.yaml
set:
masterkeySecretName: ""
masterkeySecretKey: ""
asserts:
- contains:
path: spec.template.spec.containers[0].env
content:
name: PROXY_MASTER_KEY
valueFrom:
secretKeyRef:
name: RELEASE-NAME-litellm-masterkey
key: masterkey
- it: should work with masterkeySecretName and masterkeySecretKey
template: deployment.yaml
set:
masterkeySecretName: my-secret
masterkeySecretKey: my-key
asserts:
- contains:
path: spec.template.spec.containers[0].env
content:
name: PROXY_MASTER_KEY
valueFrom:
secretKeyRef:
name: my-secret
key: my-key


@ -0,0 +1,18 @@
suite: test masterkey secret
templates:
- secret-masterkey.yaml
tests:
- it: should create a secret if masterkeySecretName is not set
template: secret-masterkey.yaml
set:
masterkeySecretName: ""
asserts:
- isKind:
of: Secret
- it: should not create a secret if masterkeySecretName is set
template: secret-masterkey.yaml
set:
masterkeySecretName: my-secret
asserts:
- hasDocuments:
count: 0


@ -75,6 +75,12 @@ ingress:
# masterkey: changeit
# if set, use this secret for the master key; otherwise, autogenerate a new one
masterkeySecretName: ""
# if set, use this secret key for the master key; otherwise, use the default key
masterkeySecretKey: ""
# The elements within proxy_config are rendered as config.yaml for the proxy
# Examples: https://github.com/BerriAI/litellm/tree/main/litellm/proxy/example_config_yaml
# Reference: https://docs.litellm.ai/docs/proxy/configs


@ -20,10 +20,18 @@ services:
STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
env_file:
- .env # Load local .env file
depends_on:
- db # Indicates that this service depends on the 'db' service, ensuring 'db' starts first
healthcheck: # Defines the health check configuration for the container
test: [ "CMD", "curl", "-f", "http://localhost:4000/health/liveliness || exit 1" ] # Command to execute for health check
interval: 30s # Perform health check every 30 seconds
timeout: 10s # Health check command times out after 10 seconds
retries: 3 # Retry up to 3 times if health check fails
start_period: 40s # Wait 40 seconds after container start before beginning health checks
db:
image: postgres
image: postgres:16
restart: always
environment:
POSTGRES_DB: litellm
@ -31,6 +39,8 @@ services:
POSTGRES_PASSWORD: dbpassword9090
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data # Persists Postgres data across container restarts
healthcheck:
test: ["CMD-SHELL", "pg_isready -d litellm -U llmproxy"]
interval: 1s
@ -53,6 +63,8 @@ services:
volumes:
prometheus_data:
driver: local
postgres_data:
name: litellm_postgres_data # Named volume for Postgres data persistence
# ...rest of your docker-compose config if any


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] `/v1/messages`
# /v1/messages [BETA]
LiteLLM provides a BETA endpoint that follows the spec of Anthropic's `/v1/messages` endpoint.


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Assistants API
# /assistants
Covers Threads, Messages, Assistants.


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Batches API
# /batches
Covers Batches, Files


@ -3,7 +3,13 @@ import TabItem from '@theme/TabItem';
# Prompt Caching
For OpenAI + Anthropic + Deepseek, LiteLLM follows the OpenAI prompt caching usage object format:
Supported Providers:
- OpenAI (`openai/`)
- Anthropic API (`anthropic/`)
- Bedrock (`bedrock/`, `bedrock/invoke/`, `bedrock/converse`) ([all Bedrock models that support prompt caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html))
- Deepseek API (`deepseek/`)
For the supported providers, LiteLLM follows the OpenAI prompt caching usage object format:
```bash
"usage": {
@ -499,4 +505,4 @@ curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \
</TabItem>
</Tabs>
This checks our maintained [model info/cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
This checks our maintained [model info/cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Embeddings
# /embeddings
## Quick Start
```python


@ -34,9 +34,9 @@ You can use our cloud product where we setup a dedicated instance for you.
Professional Support can assist with LLM/Provider integrations, deployment, upgrade management, and LLM Provider troubleshooting. We can't solve issues in your own infrastructure, but we will guide you in fixing them.
- 1 hour for Sev0 issues
- 6 hours for Sev1
- 24h for Sev2-Sev3 between 7am 7pm PT (Monday through Saturday)
- 1 hour for Sev0 issues - 100% production traffic is failing
- 6 hours for Sev1 - <100% production traffic is failing
- 24h for Sev2-Sev3 between 7am and 7pm PT (Monday through Saturday) - setup issues, e.g. Redis working on our end but not on your infrastructure.
- 72h SLA for patching vulnerabilities in the software.
**We can offer custom SLAs** based on your needs and the severity of the issue


@ -8,7 +8,7 @@ Here are the core requirements for any PR submitted to LiteLLM
- [ ] Add testing, **Adding at least 1 test is a hard requirement** - [see details](#2-adding-testing-to-your-pr)
- [ ] Ensure your PR passes the following tests:
- [ ] [Unit Tests](#3-running-unit-tests)
- [ ] Formatting / Linting Tests
- [ ] [Formatting / Linting Tests](#35-running-linting-tests)
- [ ] Keep scope as isolated as possible. As a general rule, your changes should address 1 specific problem at a time
@ -56,6 +56,16 @@ run the following command on the root of the litellm directory
make test-unit
```
## 3.5 Running Linting Tests
Run the following command from the root of the litellm directory:
```shell
make lint
```
LiteLLM uses `mypy` for linting. On CI/CD we also run `black` for formatting.
## 4. Submit a PR with your changes!
- push your fork to your GitHub repo

View file

@ -2,7 +2,7 @@
import TabItem from '@theme/TabItem';
import Tabs from '@theme/Tabs';
# Files API
# /files
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [Beta] Fine-tuning API
# /fine_tuning
:::info


@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Moderation
# /moderations
### Usage


@ -79,6 +79,7 @@ aws_session_name: Optional[str],
aws_profile_name: Optional[str],
aws_role_name: Optional[str],
aws_web_identity_token: Optional[str],
aws_bedrock_runtime_endpoint: Optional[str],
```
### 2. Start the proxy
@ -1262,6 +1263,473 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
</Tabs>
## Bedrock Imported Models (Deepseek, Deepseek R1)
### Deepseek R1
This is a separate route, as the chat template is different.
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/deepseek_r1/{your-model-arn}
messages=[{"role": "user", "content": "Tell me a joke"}],
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: DeepSeek-R1-Distill-Llama-70B
litellm_params:
model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
### Deepseek (not R1)
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
messages=[{"role": "user", "content": "Tell me a joke"}],
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: DeepSeek-R1-Distill-Llama-70B
litellm_params:
model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
## Provisioned throughput models
To use provisioned throughput Bedrock models pass
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
- `model_id=provisioned-model-arn`
Completion
```python
import litellm
response = litellm.completion(
model="bedrock/anthropic.claude-instant-v1",
model_id="provisioned-model-arn",
messages=[{"content": "Hello, how are you?", "role": "user"}]
)
```
Embedding
```python
import litellm
response = litellm.embedding(
model="bedrock/amazon.titan-embed-text-v1",
model_id="provisioned-model-arn",
input=["hi"],
)
```
## Supported AWS Bedrock Models
Here's an example of using a bedrock model with LiteLLM. For a complete list, refer to the [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
| Model Name | Command | Required Environment Variables |
|----------------------------|------------------------------------------------------------------|------------------------------------------------------------------|
| Anthropic Claude-V3.5 Sonnet | `completion(model='bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 sonnet | `completion(model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 Haiku | `completion(model='bedrock/anthropic.claude-3-haiku-20240307-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 Opus | `completion(model='bedrock/anthropic.claude-3-opus-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Amazon Titan Express | `completion(model='bedrock/amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Cohere Command | `completion(model='bedrock/cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 Jamba-Instruct | `completion(model='bedrock/ai21.jamba-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 13b | `completion(model='bedrock/meta.llama2-13b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 70b | `completion(model='bedrock/meta.llama2-70b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mistral 7B Instruct | `completion(model='bedrock/mistral.mistral-7b-instruct-v0:2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mixtral 8x7B Instruct | `completion(model='bedrock/mistral.mixtral-8x7b-instruct-v0:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
## Bedrock Embedding
### API keys
This can be set as env variables or passed as **params to litellm.embedding()**
```python
import os
os.environ["AWS_ACCESS_KEY_ID"] = "" # Access key
os.environ["AWS_SECRET_ACCESS_KEY"] = "" # Secret access key
os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
### Usage
```python
from litellm import embedding
response = embedding(
model="bedrock/amazon.titan-embed-text-v1",
input=["good morning from litellm"],
)
print(response)
```
## Supported AWS Bedrock Embedding Models
| Model Name | Usage | Supported Additional OpenAI params |
|----------------------|---------------------------------------------|-----|
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py#L59) |
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py#L53)
| Titan Multimodal Embeddings | `embedding(model="bedrock/amazon.titan-embed-image-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py#L28) |
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
### Advanced - [Drop Unsupported Params](https://docs.litellm.ai/docs/completion/drop_params#openai-proxy-usage)
### Advanced - [Pass model/provider-specific Params](https://docs.litellm.ai/docs/completion/provider_specific_params#proxy-usage)
## Image Generation
Use this for Stable Diffusion and Amazon Nova Canvas on Bedrock.
### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import image_generation
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = image_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
)
print(f"response: {response}")
```
**Set optional params**
```python
import os
from litellm import image_generation
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = image_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
### OPENAI-COMPATIBLE ###
size="128x512", # width=128, height=512
### PROVIDER-SPECIFIC ### see `AmazonStabilityConfig` in bedrock.py for all params
seed=30
)
print(f"response: {response}")
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: amazon.nova-canvas-v1:0
litellm_params:
model: bedrock/amazon.nova-canvas-v1:0
aws_region_name: "us-east-1"
aws_secret_access_key: my-key # OPTIONAL - all boto3 auth params supported
aws_access_key_id: my-id # OPTIONAL - all boto3 auth params supported
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer $LITELLM_VIRTUAL_KEY' \
-d '{
"model": "amazon.nova-canvas-v1:0",
"prompt": "A cute baby sea otter"
}'
```
</TabItem>
</Tabs>
## Supported AWS Bedrock Image Generation Models
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Stable Diffusion 3 - v0 | `image_generation(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
| Stable Diffusion XL - v0 | `image_generation(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
| Stable Diffusion XL - v1 | `image_generation(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
## Rerank API
Use Bedrock's Rerank API in the Cohere `/rerank` format.
Supported Cohere Rerank Params
- `model` - the foundation model ARN
- `query` - the query to rerank against
- `documents` - the list of documents to rerank
- `top_n` - the number of results to return
<Tabs>
<TabItem label="SDK" value="sdk">
```python
from litellm import rerank
import os
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = rerank(
model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", # provide the model ARN - get this here https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
query="hello",
documents=["hello", "world"],
top_n=2,
)
print(response)
```
</TabItem>
<TabItem label="PROXY" value="proxy">
1. Setup config.yaml
```yaml
model_list:
- model_name: bedrock-rerank
litellm_params:
model: bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start proxy server
```bash
litellm --config config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl http://0.0.0.0:4000/rerank \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "bedrock-rerank",
"query": "What is the capital of the United States?",
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country."
],
"top_n": 3
}'
```
</TabItem>
</Tabs>
## Bedrock Application Inference Profile
Use Bedrock Application Inference Profile to track costs for projects on AWS.
You can either pass it in the model name - `model="bedrock/arn:...` or as a separate `model_id="arn:..` param.
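For the first option, here is a minimal sketch; it reuses the placeholder application inference profile ARN from the `model_id` examples in this section, so substitute your own ARN and credentials.
```python
from litellm import completion
import os

os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# Sketch only: the ARN below is the same placeholder used elsewhere in this section.
response = completion(
    model="bedrock/arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)
```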
### Set via `model_id`
<Tabs>
<TabItem label="SDK" value="sdk">
```python
from litellm import completion
import os
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = completion(
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
messages=[{"role": "user", "content": "Hello, how are you?"}],
model_id="arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0",
)
print(response)
```
</TabItem>
<TabItem label="PROXY" value="proxy">
1. Setup config.yaml
```yaml
model_list:
- model_name: anthropic-claude-3-5-sonnet
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
# Set the application inference profile ARN in the model_id parameter
model_id: arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer $LITELLM_API_KEY' \
-d '{
"model": "anthropic-claude-3-5-sonnet",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "List 5 important events in the XIX century"
}
]
}
]
}'
```
</TabItem>
</Tabs>
## Boto3 - Authentication
### Passing credentials as parameters - Completion()
@ -1497,7 +1965,7 @@ response = completion(
aws_bedrock_client=bedrock,
)
```
## Calling via Internal Proxy
## Calling via Internal Proxy (not bedrock url compatible)
Use the `bedrock/converse_like/model` endpoint to call a Bedrock Converse-compatible model via your internal proxy.
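As a rough illustration, a hedged SDK sketch of this pattern is below. The model name and URL are placeholders, not values shipped with LiteLLM, and it assumes your internal proxy accepts the standard `api_base`/`api_key` parameters; adapt both to your environment.
```python
from litellm import completion

# Hedged sketch: "some-custom-model" and the api_base URL are placeholders.
response = completion(
    model="bedrock/converse_like/some-custom-model",  # converse_like route described above
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_base="https://some-api-url/models",  # your internal proxy endpoint
    api_key="sk-1234",                       # whatever auth your proxy expects
)
print(response)
```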
@ -1563,359 +2031,3 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
```bash
https://some-api-url/models
```
## Bedrock Imported Models (Deepseek, Deepseek R1)
### Deepseek R1
This is a separate route, as the chat template is different.
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/deepseek_r1/{your-model-arn}
messages=[{"role": "user", "content": "Tell me a joke"}],
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: DeepSeek-R1-Distill-Llama-70B
litellm_params:
model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
### Deepseek (not R1)
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
messages=[{"role": "user", "content": "Tell me a joke"}],
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: DeepSeek-R1-Distill-Llama-70B
litellm_params:
model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
## Provisioned throughput models
To use provisioned throughput Bedrock models pass
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
- `model_id=provisioned-model-arn`
Completion
```python
import litellm
response = litellm.completion(
model="bedrock/anthropic.claude-instant-v1",
model_id="provisioned-model-arn",
messages=[{"content": "Hello, how are you?", "role": "user"}]
)
```
Embedding
```python
import litellm
response = litellm.embedding(
model="bedrock/amazon.titan-embed-text-v1",
model_id="provisioned-model-arn",
input=["hi"],
)
```
## Supported AWS Bedrock Models
Here's an example of using a bedrock model with LiteLLM. For a complete list, refer to the [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
| Model Name | Command |
|----------------------------|------------------------------------------------------------------|
| Anthropic Claude-V3.5 Sonnet | `completion(model='bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 sonnet | `completion(model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 Haiku | `completion(model='bedrock/anthropic.claude-3-haiku-20240307-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V3 Opus | `completion(model='bedrock/anthropic.claude-3-opus-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Amazon Titan Express | `completion(model='bedrock/amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Cohere Command | `completion(model='bedrock/cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 Jamba-Instruct | `completion(model='bedrock/ai21.jamba-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 13b | `completion(model='bedrock/meta.llama2-13b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 70b | `completion(model='bedrock/meta.llama2-70b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mistral 7B Instruct | `completion(model='bedrock/mistral.mistral-7b-instruct-v0:2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mixtral 8x7B Instruct | `completion(model='bedrock/mistral.mixtral-8x7b-instruct-v0:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
## Bedrock Embedding
### API keys
This can be set as env variables or passed as **params to litellm.embedding()**
```python
import os
os.environ["AWS_ACCESS_KEY_ID"] = "" # Access key
os.environ["AWS_SECRET_ACCESS_KEY"] = "" # Secret access key
os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
### Usage
```python
from litellm import embedding
response = embedding(
model="bedrock/amazon.titan-embed-text-v1",
input=["good morning from litellm"],
)
print(response)
```
## Supported AWS Bedrock Embedding Models
| Model Name | Usage | Supported Additional OpenAI params |
|----------------------|---------------------------------------------|-----|
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py#L59) |
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py#L53)
| Titan Multimodal Embeddings | `embedding(model="bedrock/amazon.titan-embed-image-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py#L28) |
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
### Advanced - [Drop Unsupported Params](https://docs.litellm.ai/docs/completion/drop_params#openai-proxy-usage)
### Advanced - [Pass model/provider-specific Params](https://docs.litellm.ai/docs/completion/provider_specific_params#proxy-usage)
## Image Generation
Use this for stable diffusion on bedrock
### Usage
```python
import os
from litellm import image_generation
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = image_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
)
print(f"response: {response}")
```
**Set optional params**
```python
import os
from litellm import image_generation
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = image_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v0",
### OPENAI-COMPATIBLE ###
size="128x512", # width=128, height=512
### PROVIDER-SPECIFIC ### see `AmazonStabilityConfig` in bedrock.py for all params
seed=30
)
print(f"response: {response}")
```
## Supported AWS Bedrock Image Generation Models
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
## Rerank API
Use Bedrock's Rerank API in the Cohere `/rerank` format.
Supported Cohere Rerank Params
- `model` - the foundation model ARN
- `query` - the query to rerank against
- `documents` - the list of documents to rerank
- `top_n` - the number of results to return
<Tabs>
<TabItem label="SDK" value="sdk">
```python
from litellm import rerank
import os
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = rerank(
model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", # provide the model ARN - get this here https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
query="hello",
documents=["hello", "world"],
top_n=2,
)
print(response)
```
</TabItem>
<TabItem label="PROXY" value="proxy">
1. Setup config.yaml
```yaml
model_list:
- model_name: bedrock-rerank
litellm_params:
model: bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start proxy server
```bash
litellm --config config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl http://0.0.0.0:4000/rerank \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "bedrock-rerank",
"query": "What is the capital of the United States?",
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country."
],
"top_n": 3
}'
```
</TabItem>
</Tabs>


@ -57,7 +57,7 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
# litellm proxy call
response = completion(
model="litellm_proxy/your-model-name",
messages,
messages=messages,
api_base = "your-litellm-proxy-url",
api_key = "your-litellm-proxy-api-key"
)
@ -76,7 +76,7 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
# openai call
response = completion(
model="litellm_proxy/your-model-name",
messages,
messages=messages,
api_base = "your-litellm-proxy-url",
stream=True
)


@ -0,0 +1,90 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Snowflake
| Property | Details |
|-------|-------|
| Description | The Snowflake Cortex LLM REST API lets you access the COMPLETE function via HTTP POST requests|
| Provider Route on LiteLLM | `snowflake/` |
| Link to Provider Doc | [Snowflake ↗](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api) |
| Base URL | [https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete/](https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions` |
Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with LiteLLM, you can call them through our Hugging Face provider.
Find the Arctic Embed models [here](https://huggingface.co/collections/Snowflake/arctic-embed-661fd57d50fab5fc314e4c18) on Hugging Face.
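For illustration, a hedged sketch of that workaround is below. It assumes the Hugging Face provider route (`huggingface/<model-id>`) with a Hugging Face API token; the model id is one of the Arctic Embed checkpoints from the collection linked above.
```python
import os
from litellm import embedding

# Assumption: the token env var name may differ by LiteLLM version (e.g. HF_TOKEN).
os.environ["HUGGINGFACE_API_KEY"] = "YOUR HUGGING FACE TOKEN"

# Call an Arctic Embed checkpoint through the Hugging Face provider route.
response = embedding(
    model="huggingface/Snowflake/snowflake-arctic-embed-m",
    input=["good morning from litellm"],
)
print(response)
```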
## Supported OpenAI Parameters
```
"temperature",
"max_tokens",
"top_p",
"response_format"
```
## API Keys
Snowflake does not have API keys. Instead, you access the Snowflake API with your JWT token and account identifier.
```python
import os
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
```
## Usage
```python
import os
from litellm import completion
## set ENV variables
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
# Snowflake call
response = completion(
model="snowflake/mistral-7b",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
## Usage with LiteLLM Proxy
#### 1. Required env variables
```bash
export SNOWFLAKE_JWT=""
export SNOWFLAKE_ACCOUNT_ID=""
```
#### 2. Start the proxy
```yaml
model_list:
- model_name: mistral-7b
litellm_params:
model: snowflake/mistral-7b
api_key: YOUR_API_KEY
api_base: https://YOUR-ACCOUNT-ID.snowflakecomputing.com/api/v2/cortex/inference:complete
```
```bash
litellm --config /path/to/config.yaml
```
#### 3. Test it
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "snowflake/mistral-7b",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}
'
```


@ -10,17 +10,13 @@ Role-based access control (RBAC) is based on Organizations, Teams and Internal U
## Roles
**Admin Roles**
- `proxy_admin`: admin over the platform
- `proxy_admin_viewer`: can login, view all keys, view all spend. **Cannot** create keys/delete keys/add new users
**Organization Roles**
- `org_admin`: admin over the organization. Can create teams and users within their organization
**Internal User Roles**
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
| Role Type | Role Name | Permissions |
|-----------|-----------|-------------|
| **Admin** | `proxy_admin` | Admin over the platform |
| | `proxy_admin_viewer` | Can login, view all keys, view all spend. **Cannot** create keys/delete keys/add new users |
| **Organization** | `org_admin` | Admin over the organization. Can create teams and users within their organization |
| **Internal User** | `internal_user` | Can login, view/create/delete their own keys, view their spend. **Cannot** add new users |
| | `internal_user_viewer` | Can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users |
## Onboarding Organizations


@ -177,7 +177,7 @@ general_settings:
| use_x_forwarded_for | str | If true, uses the X-Forwarded-For header to get the client IP address |
| service_account_settings | List[Dict[str, Any]] | Set `service_account_settings` if you want to create settings that only apply to service account keys (Doc on service accounts)[./service_accounts.md] |
| image_generation_model | str | The default model to use for image generation - ignores model set in request |
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
| store_model_in_db | boolean | If true, enables storing model + credential information in the DB. |
| store_prompts_in_spend_logs | boolean | If true, allows prompts and responses to be stored in the spend logs table. |
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
@ -499,9 +499,11 @@ router_settings:
| SMTP_USERNAME | Username for SMTP authentication (do not set if SMTP does not require auth)
| SPEND_LOGS_URL | URL for retrieving spend logs
| SSL_CERTIFICATE | Path to the SSL certificate file
| SSL_SECURITY_LEVEL | [BETA] Security level for SSL/TLS connections. E.g. `DEFAULT@SECLEVEL=1`
| SSL_VERIFY | Flag to enable or disable SSL certificate verification
| SUPABASE_KEY | API key for Supabase service
| SUPABASE_URL | Base URL for Supabase instance
| STORE_MODEL_IN_DB | If true, enables storing model + credential information in the DB.
| TEST_EMAIL_ADDRESS | Email address used for testing purposes
| UI_LOGO_PATH | Path to the logo image used in the UI
| UI_PASSWORD | Password for accessing the UI
@ -513,4 +515,3 @@ router_settings:
| UPSTREAM_LANGFUSE_SECRET_KEY | Secret key for upstream Langfuse authentication
| USE_AWS_KMS | Flag to enable AWS Key Management Service for encryption
| WEBHOOK_URL | URL for receiving webhooks from external services

View file

@ -448,6 +448,34 @@ model_list:
s/o to [@David Manouchehri](https://www.linkedin.com/in/davidmanouchehri/) for helping with this.
### Centralized Credential Management
Define credentials once and reuse them across multiple models. This helps with:
- Secret rotation
- Reducing config duplication
```yaml
model_list:
- model_name: gpt-4o
litellm_params:
model: azure/gpt-4o
litellm_credential_name: default_azure_credential # Reference credential below
credential_list:
- credential_name: default_azure_credential
credential_values:
api_key: os.environ/AZURE_API_KEY # Load from environment
api_base: os.environ/AZURE_API_BASE
api_version: "2023-05-15"
credential_info:
description: "Production credentials for EU region"
```
#### Key Parameters
- `credential_name`: Unique identifier for the credential set
- `credential_values`: Key-value pairs of credentials/secrets (supports `os.environ/` syntax)
- `credential_info`: Key-value pairs of user-provided information about the credential. No specific keys are required, but the dictionary must exist.
### Load API Keys from Secret Managers (Azure Vault, etc)
[**Using Secret Managers with LiteLLM Proxy**](../secret)
@ -641,4 +669,4 @@ docker run --name litellm-proxy \
ghcr.io/berriai/litellm-database:main-latest
```
</TabItem>
</Tabs>
</Tabs>

View file

@ -0,0 +1,194 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Custom Prompt Management
Connect LiteLLM to your prompt management system with custom hooks.
## Overview
<Image
img={require('../../img/custom_prompt_management.png')}
style={{width: '100%', display: 'block', margin: '2rem auto'}}
/>
## How it works
## Quick Start
### 1. Create Your Custom Prompt Manager
Create a class that inherits from `CustomPromptManagement` to handle prompt retrieval and formatting:
**Example Implementation**
Create a new file called `custom_prompt.py` and add this code. The key method here is `get_chat_completion_prompt`, where you can implement custom logic to retrieve and format prompts based on the `prompt_id` and `prompt_variables`.
```python
from typing import List, Tuple, Optional
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
class MyCustomPromptManagement(CustomPromptManagement):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Retrieve and format prompts based on prompt_id.
Returns:
- model: The model to use
- messages: The formatted messages
- non_default_params: Optional parameters like temperature
"""
# Example matching the diagram: Add system message for prompt_id "1234"
if prompt_id == "1234":
# Prepend system message while preserving existing messages
new_messages = [
{"role": "system", "content": "Be a good Bot!"},
] + messages
return model, new_messages, non_default_params
# Default: Return original messages if no prompt_id match
return model, messages, non_default_params
prompt_management = MyCustomPromptManagement()
```
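If you use the LiteLLM SDK directly (no proxy), a minimal sketch of wiring up the same class is shown below. This assumes `prompt_id` is forwarded to the hook just as the proxy forwards it, and that `OPENAI_API_KEY` is set in your environment:
```python
import litellm

from custom_prompt import prompt_management  # the instance defined above

# Register the custom prompt manager as a callback
litellm.callbacks = [prompt_management]

response = litellm.completion(
    model="gpt-4",  # assumes OPENAI_API_KEY is set in the environment
    messages=[{"role": "user", "content": "hi"}],
    prompt_id="1234",  # triggers the "Be a good Bot!" system message above
)
print(response.choices[0].message.content)
```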
### 2. Configure Your Prompt Manager in LiteLLM `config.yaml`
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/gpt-4
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
callbacks: custom_prompt.prompt_management # sets litellm.callbacks = [prompt_management]
```
### 3. Start LiteLLM Gateway
<Tabs>
<TabItem value="docker" label="Docker Run">
Mount your `custom_prompt.py` on the LiteLLM Docker container.
```shell
docker run -d \
  -p 4000:4000 \
  -e OPENAI_API_KEY=$OPENAI_API_KEY \
  --name my-app \
  -v $(pwd)/config.yaml:/app/config.yaml \
  -v $(pwd)/custom_prompt.py:/app/custom_prompt.py \
  my-app:latest \
  --config /app/config.yaml \
  --port 4000 \
  --detailed_debug
```
</TabItem>
<TabItem value="py" label="litellm pip">
```shell
litellm --config config.yaml --detailed_debug
```
</TabItem>
</Tabs>
### 4. Test Your Custom Prompt Manager
When you pass `prompt_id="1234"`, the custom prompt manager will add a system message "Be a good Bot!" to your conversation:
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
from openai import OpenAI
client = OpenAI(
api_key="sk-1234",
base_url="http://0.0.0.0:4000"
)
response = client.chat.completions.create(
model="gemini-1.5-pro",
messages=[{"role": "user", "content": "hi"}],
prompt_id="1234"
)
print(response.choices[0].message.content)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
chat = ChatOpenAI(
model="gpt-4",
openai_api_key="sk-1234",
openai_api_base="http://0.0.0.0:4000",
extra_body={
"prompt_id": "1234"
}
)
messages = []
response = chat(messages)
print(response.content)
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl -X POST http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gemini-1.5-pro",
"messages": [{"role": "user", "content": "hi"}],
"prompt_id": "1234"
}'
```
</TabItem>
</Tabs>
The request will be transformed from:
```json
{
"model": "gemini-1.5-pro",
"messages": [{"role": "user", "content": "hi"}],
"prompt_id": "1234"
}
```
To:
```json
{
"model": "gemini-1.5-pro",
"messages": [
{"role": "system", "content": "Be a good Bot!"},
{"role": "user", "content": "hi"}
]
}
```

View file

@ -37,7 +37,7 @@ guardrails:
- guardrail_name: aim-protected-app
litellm_params:
guardrail: aim
mode: pre_call # 'during_call' is also available
mode: [pre_call, post_call] # 'during_call' is also available
api_key: os.environ/AIM_API_KEY
api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
```

View file

@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Prompt Management
# Prompt Management
:::info
@ -12,9 +12,10 @@ This feature is currently in beta, and might change unexpectedly. We expect this
Run experiments or change the specific model (e.g. from gpt-4o to a gpt-4o-mini finetune) from your prompt management tool (e.g. Langfuse) instead of making changes in the application.
Supported Integrations:
- [Langfuse](https://langfuse.com/docs/prompts/get-started)
- [Humanloop](../observability/humanloop)
| Supported Integrations | Link |
|------------------------|------|
| Langfuse | [Get Started](https://langfuse.com/docs/prompts/get-started) |
| Humanloop | [Get Started](../observability/humanloop) |
## Quick Start

View file

@ -102,7 +102,19 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
</TabItem>
</Tabs>
## Advanced - Set Accepted JWT Scope Names
## Advanced
### Multiple OIDC providers
Use this if you want LiteLLM to validate your JWT against multiple OIDC providers (e.g. Google Cloud, GitHub Auth)
Set `JWT_PUBLIC_KEY_URL` in your environment to a comma-separated list of URLs for your OIDC providers.
```bash
export JWT_PUBLIC_KEY_URL="https://demo.duendesoftware.com/.well-known/openid-configuration/jwks,https://accounts.google.com/.well-known/openid-configuration/jwks"
```
### Set Accepted JWT Scope Names
Change the string in the JWT 'scopes' field that LiteLLM evaluates to check whether a user has admin access.
@ -114,7 +126,7 @@ general_settings:
admin_jwt_scope: "litellm-proxy-admin"
```
## Tracking End-Users / Internal Users / Team / Org
### Tracking End-Users / Internal Users / Team / Org
Set the field in the JWT token that corresponds to a LiteLLM user / team / org.
@ -156,7 +168,7 @@ scope: ["litellm-proxy-admin",...]
scope: "litellm-proxy-admin ..."
```
## Control model access with Teams
### Control model access with Teams
1. Specify the JWT field that contains the team ids the user belongs to.
@ -207,7 +219,7 @@ OIDC Auth for API: [**See Walkthrough**](https://www.loom.com/share/00fe2deab59a
- If all checks pass, allow the request
## Advanced - Custom Validate
### Custom JWT Validate
This section shows how to add custom logic to intercept and validate the JWT token.
@ -215,7 +227,7 @@ This can occur when there is additional logic that is needed to execute against
> _Note_: You can expect the JWT to have already gone through the usual public key decryption, token decoding, and expiration-time checks before the custom validation function runs.
### 1. Setup custom validate function
#### 1. Setup custom validate function
```python
from typing import Any, Literal
@ -236,7 +248,7 @@ def my_custom_validate(token: dict[str, Any]) -> Literal[True]:
return True
```
### 2. Setup config.yaml
#### 2. Setup config.yaml
```yaml
general_settings:
@ -249,7 +261,7 @@ general_settings:
custom_validate: custom_validate.my_custom_validate # 👈 custom validate function
```
### 3. Test the flow
#### 3. Test the flow
**Expected JWT**
@ -271,7 +283,7 @@ general_settings:
}
```
## Advanced - Allowed Routes
### Allowed Routes
Configure which routes a JWT can access via the config.
@ -303,7 +315,7 @@ general_settings:
team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes
```
## Advanced - Caching Public Keys
### Caching Public Keys
Control how long public keys are cached for (in seconds).
@ -317,7 +329,7 @@ general_settings:
public_key_ttl: 600 # 👈 KEY CHANGE
```
## Advanced - Custom JWT Field
### Custom JWT Field
Set a custom field in which the team_id exists. By default, the 'client_id' field is checked.
@ -329,14 +341,7 @@ general_settings:
team_id_jwt_field: "client_id" # 👈 KEY CHANGE
```
## All Params
[**See Code**](https://github.com/BerriAI/litellm/blob/b204f0c01c703317d812a1553363ab0cb989d5b6/litellm/proxy/_types.py#L95)
## Advanced - Block Teams
### Block Teams
To block all requests for a certain team id, use `/team/block`
@ -363,7 +368,7 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
```
## Advanced - Upsert Users + Allowed Email Domains
### Upsert Users + Allowed Email Domains
Automatically give users who belong to a specific email domain access to the proxy.
@ -501,3 +506,7 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
]
}'
```
## All JWT Params
[**See Code**](https://github.com/BerriAI/litellm/blob/b204f0c01c703317d812a1553363ab0cb989d5b6/litellm/proxy/_types.py#L95)

View file

@ -0,0 +1,55 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Adding LLM Credentials
You can add LLM provider credentials on the UI. Once you add credentials, you can re-use them when adding new models.
## Add a credential + model
### 1. Navigate to LLM Credentials page
Go to Models -> LLM Credentials -> Add Credential
<Image img={require('../../img/ui_cred_add.png')} />
### 2. Add credentials
Select your LLM provider, enter your API Key and click "Add Credential"
**Note: Credential fields are provider-specific. If you select Vertex AI, you will see `Vertex Project`, `Vertex Location` and `Vertex Credentials` fields.**
<Image img={require('../../img/ui_add_cred_2.png')} />
### 3. Use credentials when adding a model
Go to Add Model -> Existing Credentials -> Select your credential in the dropdown
<Image img={require('../../img/ui_cred_3.png')} />
## Create a Credential from an existing model
Use this if you have already created a model and want to store the model credentials for future use
### 1. Select model to create a credential from
Go to Models -> Select your model -> Credential -> Create Credential
<Image img={require('../../img/ui_cred_4.png')} />
### 2. Use new credential when adding a model
Go to Add Model -> Existing Credentials -> Select your credential in the dropdown
<Image img={require('../../img/use_model_cred.png')} />
## Frequently Asked Questions
How are credentials stored?
Credentials in the DB are encrypted/decrypted using `LITELLM_SALT_KEY`, if set. If not, then they are encrypted using `LITELLM_MASTER_KEY`. These keys should be kept secret and not shared with others.

View file

@ -0,0 +1,55 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# UI Logs Page
View Spend, Token Usage, Key, Team Name for Each Request to LiteLLM
<Image img={require('../../img/ui_request_logs.png')}/>
## Overview
| Log Type | Tracked by Default |
|----------|-------------------|
| Success Logs | ✅ Yes |
| Error Logs | ✅ Yes |
| Request/Response Content Stored | ❌ No by Default, **opt in with `store_prompts_in_spend_logs`** |
**By default LiteLLM does not track the request and response content.**
## Tracking - Request / Response Content in Logs Page
If you want to view request and response content on LiteLLM Logs, you need to opt in with this setting
```yaml
general_settings:
store_prompts_in_spend_logs: true
```
<Image img={require('../../img/ui_request_logs_content.png')}/>
## Stop storing Error Logs in DB
If you do not want to store error logs in DB, you can opt out with this setting
```yaml
general_settings:
disable_error_logs: True # Only disable writing error logs to DB, regular spend logs will still be written unless `disable_spend_logs: True`
```
## Stop storing Spend Logs in DB
If you do not want to store spend logs in DB, you can opt out with this setting
```yaml
general_settings:
disable_spend_logs: True # Disable writing spend logs to DB
```

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Realtime Endpoints
# /realtime
Use this to load balance across Azure + OpenAI.

View file

@ -3,11 +3,20 @@ import TabItem from '@theme/TabItem';
# 'Thinking' / 'Reasoning Content'
:::info
Requires LiteLLM v1.63.0+
:::
Supported Providers:
- Deepseek (`deepseek/`)
- Anthropic API (`anthropic/`)
- Bedrock (Anthropic + Deepseek) (`bedrock/`)
- Vertex AI (Anthropic) (`vertexai/`)
- OpenRouter (`openrouter/`)
LiteLLM will standardize the `reasoning_content` in the response and `thinking_blocks` in the assistant message.
```python
"message": {

View file

@ -1,4 +1,4 @@
# Rerank
# /rerank
:::tip

View file

@ -0,0 +1,117 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# /responses [Beta]
LiteLLM provides a BETA endpoint that follows the spec of [OpenAI's `/responses` API](https://platform.openai.com/docs/api-reference/responses).
| Feature | Supported | Notes |
|---------|-----------|--------|
| Cost Tracking | ✅ | Works with all supported models |
| Logging | ✅ | Works across all integrations |
| End-user Tracking | ✅ | |
| Streaming | ✅ | |
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Supported LiteLLM Versions | 1.63.8+ | |
| Supported LLM providers | `openai` | |
## Usage
### Create a model response
<Tabs>
<TabItem value="litellm-sdk" label="LiteLLM SDK">
#### Non-streaming
```python
import litellm
# Non-streaming response
response = litellm.responses(
model="gpt-4o",
input="Tell me a three sentence bedtime story about a unicorn.",
max_output_tokens=100
)
print(response)
```
#### Streaming
```python
import litellm
# Streaming response
response = litellm.responses(
model="gpt-4o",
input="Tell me a three sentence bedtime story about a unicorn.",
stream=True
)
for event in response:
print(event)
```
</TabItem>
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">
First, add this to your litellm proxy config.yaml:
```yaml
model_list:
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
```
Start your LiteLLM proxy:
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
Then use the OpenAI SDK pointed to your proxy:
#### Non-streaming
```python
from openai import OpenAI
# Initialize client with your proxy URL
client = OpenAI(
base_url="http://localhost:4000", # Your proxy URL
api_key="your-api-key" # Your proxy API key
)
# Non-streaming response
response = client.responses.create(
model="gpt-4o",
input="Tell me a three sentence bedtime story about a unicorn."
)
print(response)
```
#### Streaming
```python
from openai import OpenAI
# Initialize client with your proxy URL
client = OpenAI(
base_url="http://localhost:4000", # Your proxy URL
api_key="your-api-key" # Your proxy API key
)
# Streaming response
response = client.responses.create(
model="gpt-4o",
input="Tell me a three sentence bedtime story about a unicorn.",
stream=True
)
for event in response:
print(event)
```
</TabItem>
</Tabs>

View file

@ -830,7 +830,7 @@ asyncio.run(router_acompletion())
Set `weight` on a deployment to pick one deployment more often than others.
This works across **ALL** routing strategies.
This works with the **simple-shuffle** routing strategy (the default, if no routing strategy is selected).
<Tabs>
<TabItem value="sdk" label="SDK">

View file

@ -96,7 +96,7 @@ litellm --config /path/to/config.yaml
```
### Using K/V pairs in 1 AWS Secret
#### Using K/V pairs in 1 AWS Secret
You can read multiple keys from a single AWS Secret using the `primary_secret_name` parameter:

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Text Completion
# /completions
### Usage
<Tabs>

Binary file not shown.

After

Width:  |  Height:  |  Size: 346 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 371 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 283 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 567 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 344 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 282 KiB

View file

@ -706,12 +706,13 @@
}
},
"node_modules/@babel/helpers": {
"version": "7.26.0",
"resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.26.0.tgz",
"integrity": "sha512-tbhNuIxNcVb21pInl3ZSjksLCvgdZy9KwJ8brv993QtIVKJBBkYXz4q4ZbAv31GdnC+R90np23L5FbEBlthAEw==",
"version": "7.26.10",
"resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.26.10.tgz",
"integrity": "sha512-UPYc3SauzZ3JGgj87GgZ89JVdC5dj0AoetR5Bw6wj4niittNyFh6+eOGonYvJ1ao6B8lEa3Q3klS7ADZ53bc5g==",
"license": "MIT",
"dependencies": {
"@babel/template": "^7.25.9",
"@babel/types": "^7.26.0"
"@babel/template": "^7.26.9",
"@babel/types": "^7.26.10"
},
"engines": {
"node": ">=6.9.0"
@ -796,11 +797,12 @@
}
},
"node_modules/@babel/parser": {
"version": "7.26.3",
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.3.tgz",
"integrity": "sha512-WJ/CvmY8Mea8iDXo6a7RK2wbmJITT5fN3BEkRuFlxVyNx8jOKIIhmC4fSkTcPcf8JyavbBwIe6OpiCOBXt/IcA==",
"version": "7.26.10",
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.10.tgz",
"integrity": "sha512-6aQR2zGE/QFi8JpDLjUZEPYOs7+mhKXm86VaKFiLP35JQwQb6bwUE+XbvkH0EptsYhbNBSUGaUBLKqxH1xSgsA==",
"license": "MIT",
"dependencies": {
"@babel/types": "^7.26.3"
"@babel/types": "^7.26.10"
},
"bin": {
"parser": "bin/babel-parser.js"
@ -2157,9 +2159,10 @@
}
},
"node_modules/@babel/runtime-corejs3": {
"version": "7.26.0",
"resolved": "https://registry.npmjs.org/@babel/runtime-corejs3/-/runtime-corejs3-7.26.0.tgz",
"integrity": "sha512-YXHu5lN8kJCb1LOb9PgV6pvak43X2h4HvRApcN5SdWeaItQOzfn1hgP6jasD6KWQyJDBxrVmA9o9OivlnNJK/w==",
"version": "7.26.10",
"resolved": "https://registry.npmjs.org/@babel/runtime-corejs3/-/runtime-corejs3-7.26.10.tgz",
"integrity": "sha512-uITFQYO68pMEYR46AHgQoyBg7KPPJDAbGn4jUTIRgCFJIp88MIBUianVOplhZDEec07bp9zIyr4Kp0FCyQzmWg==",
"license": "MIT",
"dependencies": {
"core-js-pure": "^3.30.2",
"regenerator-runtime": "^0.14.0"
@ -2169,13 +2172,14 @@
}
},
"node_modules/@babel/template": {
"version": "7.25.9",
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.25.9.tgz",
"integrity": "sha512-9DGttpmPvIxBb/2uwpVo3dqJ+O6RooAFOS+lB+xDqoE2PVCE8nfoHMdZLpfCQRLwvohzXISPZcgxt80xLfsuwg==",
"version": "7.26.9",
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz",
"integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==",
"license": "MIT",
"dependencies": {
"@babel/code-frame": "^7.25.9",
"@babel/parser": "^7.25.9",
"@babel/types": "^7.25.9"
"@babel/code-frame": "^7.26.2",
"@babel/parser": "^7.26.9",
"@babel/types": "^7.26.9"
},
"engines": {
"node": ">=6.9.0"
@ -2199,9 +2203,10 @@
}
},
"node_modules/@babel/types": {
"version": "7.26.3",
"resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.3.tgz",
"integrity": "sha512-vN5p+1kl59GVKMvTHt55NzzmYVxprfJD+ql7U9NFIfKCBkYE55LYtS+WtPlaYOyzydrKI8Nezd+aZextrd+FMA==",
"version": "7.26.10",
"resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.10.tgz",
"integrity": "sha512-emqcG3vHrpxUKTrxcblR36dcrcoRDvKmnL/dCL6ZsHaShW80qxCAcNhzQZrpeM765VzEos+xOi4s+r4IXzTwdQ==",
"license": "MIT",
"dependencies": {
"@babel/helper-string-parser": "^7.25.9",
"@babel/helper-validator-identifier": "^7.25.9"

View file

@ -0,0 +1,180 @@
---
title: v1.63.11-stable
slug: v1.63.11-stable
date: 2025-03-15T10:00:00
authors:
- name: Krrish Dholakia
title: CEO, LiteLLM
url: https://www.linkedin.com/in/krish-d/
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
- name: Ishaan Jaffer
title: CTO, LiteLLM
url: https://www.linkedin.com/in/reffajnaahsi/
image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg
tags: [credential management, thinking content, responses api, snowflake]
hide_table_of_contents: false
---
import Image from '@theme/IdealImage';
These are the changes since `v1.63.2-stable`.
This release is primarily focused on:
- [Beta] Responses API Support
- Snowflake Cortex Support, Amazon Nova Image Generation
- UI - Credential Management, re-use credentials when adding new models
- UI - Test Connection to LLM Provider before adding a model
:::info
This release will be live on 03/16/2025
:::
<!-- <Image img={require('../../img/release_notes/v16311_release.jpg')} /> -->
## Known Issues
- 🚨 Known issue on Azure OpenAI - We don't recommend upgrading if you use Azure OpenAI. This version failed our Azure OpenAI load test
## Docker Run LiteLLM Proxy
```
docker run \
-e STORE_MODEL_IN_DB=True \
-p 4000:4000 \
ghcr.io/berriai/litellm:main-v1.63.11-stable
```
## Demo Instance
Here's a Demo Instance to test changes:
- Instance: https://demo.litellm.ai/
- Login Credentials:
- Username: admin
- Password: sk-1234
## New Models / Updated Models
- Image Generation support for Amazon Nova Canvas [Getting Started](https://docs.litellm.ai/docs/providers/bedrock#image-generation)
- Add pricing for Jamba new models [PR](https://github.com/BerriAI/litellm/pull/9032/files)
- Add pricing for Amazon EU models [PR](https://github.com/BerriAI/litellm/pull/9056/files)
- Add Bedrock Deepseek R1 model pricing [PR](https://github.com/BerriAI/litellm/pull/9108/files)
- Update Gemini pricing: Gemma 3, Flash 2 thinking update, LearnLM [PR](https://github.com/BerriAI/litellm/pull/9190/files)
- Mark Cohere Embedding 3 models as Multimodal [PR](https://github.com/BerriAI/litellm/pull/9176/commits/c9a576ce4221fc6e50dc47cdf64ab62736c9da41)
- Add Azure Data Zone pricing [PR](https://github.com/BerriAI/litellm/pull/9185/files#diff-19ad91c53996e178c1921cbacadf6f3bae20cfe062bd03ee6bfffb72f847ee37)
- LiteLLM Tracks cost for `azure/eu` and `azure/us` models
## LLM Translation
<Image img={require('../../img/release_notes/responses_api.png')} />
1. **New Endpoints**
- [Beta] POST `/responses` API. [Getting Started](https://docs.litellm.ai/docs/response_api)
2. **New LLM Providers**
- Snowflake Cortex [Getting Started](https://docs.litellm.ai/docs/providers/snowflake)
3. **New LLM Features**
- Support OpenRouter `reasoning_content` on streaming [Getting Started](https://docs.litellm.ai/docs/reasoning_content)
4. **Bug Fixes**
- OpenAI: Return `code`, `param` and `type` on bad request error [More information on litellm exceptions](https://docs.litellm.ai/docs/exception_mapping)
- Bedrock: Fix converse chunk parsing to only return empty dict on tool use [PR](https://github.com/BerriAI/litellm/pull/9166)
- Bedrock: Support extra_headers [PR](https://github.com/BerriAI/litellm/pull/9113)
- Azure: Fix Function Calling Bug & Update Default API Version to `2025-02-01-preview` [PR](https://github.com/BerriAI/litellm/pull/9191)
- Azure: Fix AI services URL [PR](https://github.com/BerriAI/litellm/pull/9185)
- Vertex AI: Handle HTTP 201 status code in response [PR](https://github.com/BerriAI/litellm/pull/9193)
- Perplexity: Fix incorrect streaming response [PR](https://github.com/BerriAI/litellm/pull/9081)
- Triton: Fix streaming completions bug [PR](https://github.com/BerriAI/litellm/pull/8386)
- Deepgram: Support bytes.IO when handling audio files for transcription [PR](https://github.com/BerriAI/litellm/pull/9071)
- Ollama: Fix "system" role has become unacceptable [PR](https://github.com/BerriAI/litellm/pull/9261)
- All Providers (Streaming): Fix String `data:` stripped from entire content in streamed responses [PR](https://github.com/BerriAI/litellm/pull/9070)
## Spend Tracking Improvements
1. Support Bedrock converse cache token tracking [Getting Started](https://docs.litellm.ai/docs/completion/prompt_caching)
2. Cost Tracking for Responses API [Getting Started](https://docs.litellm.ai/docs/response_api)
3. Fix Azure Whisper cost tracking [Getting Started](https://docs.litellm.ai/docs/audio_transcription)
## UI
### Re-Use Credentials on UI
You can now onboard LLM provider credentials on LiteLLM UI. Once these credentials are added you can re-use them when adding new models [Getting Started](https://docs.litellm.ai/docs/proxy/ui_credentials)
<Image img={require('../../img/release_notes/credentials.jpg')} />
### Test Connections before adding models
Before adding a model, you can test the connection to the LLM provider to verify you have set up your API Base + API Key correctly.
<Image img={require('../../img/release_notes/litellm_test_connection.gif')} />
### General UI Improvements
1. Add Models Page
- Allow adding Cerebras, Sambanova, Perplexity, Fireworks, Openrouter, TogetherAI Models, Text-Completion OpenAI on Admin UI
- Allow adding EU OpenAI models
- Fix: Instantly show edit + deletes to models
2. Keys Page
- Fix: Instantly show newly created keys on Admin UI (don't require refresh)
- Fix: Allow clicking into Top Keys when showing users Top API Key
- Fix: Allow Filter Keys by Team Alias, Key Alias and Org
- UI Improvements: Show 100 Keys Per Page, Use full height, increase width of key alias
3. Users Page
- Fix: Show correct count of internal user keys on Users Page
- Fix: Metadata not updating in Team UI
4. Logs Page
- UI Improvements: Keep expanded log in focus on LiteLLM UI
- UI Improvements: Minor improvements to logs page
- Fix: Allow internal user to query their own logs
- Allow switching off storing Error Logs in DB [Getting Started](https://docs.litellm.ai/docs/proxy/ui_logs)
5. Sign In/Sign Out
- Fix: Correctly use `PROXY_LOGOUT_URL` when set [Getting Started](https://docs.litellm.ai/docs/proxy/self_serve#setting-custom-logout-urls)
## Security
1. Support for Rotating Master Keys [Getting Started](https://docs.litellm.ai/docs/proxy/master_key_rotations)
2. Fix: Internal User Viewer Permissions, don't allow `internal_user_viewer` role to see `Test Key Page` or `Create Key Button` [More information on role based access controls](https://docs.litellm.ai/docs/proxy/access_control)
3. Emit audit logs on All user + model Create/Update/Delete endpoints [Getting Started](https://docs.litellm.ai/docs/proxy/multiple_admins)
4. JWT
- Support multiple JWT OIDC providers [Getting Started](https://docs.litellm.ai/docs/proxy/token_auth)
- Fix JWT access with Groups not working when team is assigned All Proxy Models access
5. Using K/V pairs in 1 AWS Secret [Getting Started](https://docs.litellm.ai/docs/secret#using-kv-pairs-in-1-aws-secret)
## Logging Integrations
1. Prometheus: Track Azure LLM API latency metric [Getting Started](https://docs.litellm.ai/docs/proxy/prometheus#request-latency-metrics)
2. Athina: Added tags, user_feedback and model_options to additional_keys which can be sent to Athina [Getting Started](https://docs.litellm.ai/docs/observability/athina_integration)
## Performance / Reliability improvements
1. Redis + litellm router - Fix Redis cluster mode for litellm router [PR](https://github.com/BerriAI/litellm/pull/9010)
## General Improvements
1. OpenWebUI Integration - display `thinking` tokens
- Guide on getting started with LiteLLM x OpenWebUI. [Getting Started](https://docs.litellm.ai/docs/tutorials/openweb_ui)
- Display `thinking` tokens on OpenWebUI (Bedrock, Anthropic, Deepseek) [Getting Started](https://docs.litellm.ai/docs/tutorials/openweb_ui#render-thinking-content-on-openweb-ui)
<Image img={require('../../img/litellm_thinking_openweb.gif')} />
## Complete Git Diff
[Here's the complete git diff](https://github.com/BerriAI/litellm/compare/v1.63.2-stable...v1.63.11-stable)

View file

@ -101,7 +101,9 @@ const sidebars = {
"proxy/admin_ui_sso",
"proxy/self_serve",
"proxy/public_teams",
"proxy/custom_sso"
"proxy/custom_sso",
"proxy/ui_credentials",
"proxy/ui_logs"
],
},
{
@ -231,6 +233,7 @@ const sidebars = {
"providers/sambanova",
"providers/custom_llm_server",
"providers/petals",
"providers/snowflake"
],
},
{
@ -273,7 +276,7 @@ const sidebars = {
items: [
{
type: "category",
label: "Chat",
label: "/chat/completions",
link: {
type: "generated-index",
title: "Chat Completions",
@ -286,12 +289,13 @@ const sidebars = {
"completion/usage",
],
},
"response_api",
"text_completion",
"embedding/supported_embedding",
"anthropic_unified",
{
type: "category",
label: "Image",
label: "/images",
items: [
"image_generation",
"image_variations",
@ -299,7 +303,7 @@ const sidebars = {
},
{
type: "category",
label: "Audio",
label: "/audio",
"items": [
"audio_transcription",
"text_to_speech",
@ -361,8 +365,12 @@ const sidebars = {
],
},
{
type: "doc",
id: "proxy/prompt_management"
type: "category",
label: "[Beta] Prompt Management",
items: [
"proxy/prompt_management",
"proxy/custom_prompt_management"
],
},
{
type: "category",

View file

@ -163,7 +163,7 @@ class AporiaGuardrail(CustomGuardrail):
pass
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -173,6 +173,7 @@ class AporiaGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
from litellm.proxy.common_utils.callback_utils import (

View file

@ -94,6 +94,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -107,6 +107,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -126,6 +126,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -31,7 +31,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
#### CALL HOOKS - proxy only ####
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -41,6 +41,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
text = ""

View file

@ -8,12 +8,14 @@ import os
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
from litellm.caching.llm_caching_handler import LLMClientCache
from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
from litellm.types.utils import (
ImageObject,
BudgetConfig,
all_litellm_params,
all_litellm_params as _litellm_completion_params,
CredentialItem,
) # maintain backwards compatibility for root param
from litellm._logging import (
set_verbose,
@ -180,6 +182,7 @@ cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
snowflake_key: Optional[str] = None
common_cloud_provider_auth_params: dict = {
"params": ["project", "region_name", "token"],
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
@ -189,15 +192,17 @@ ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
disable_add_transform_inline_image_block: bool = False
in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
in_memory_llm_clients_cache: LLMClientCache = LLMClientCache()
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
AZURE_DEFAULT_API_VERSION = "2025-02-01-preview" # this is updated to the latest
### DEFAULT WATSONX API VERSION ###
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
### COHERE EMBEDDINGS DEFAULT TYPE ###
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
### CREDENTIALS ###
credential_list: List[CredentialItem] = []
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
openai_moderations_model_name: Optional[str] = None
@ -412,6 +417,7 @@ cerebras_models: List = []
galadriel_models: List = []
sambanova_models: List = []
assemblyai_models: List = []
snowflake_models: List = []
def is_bedrock_pricing_only_model(key: str) -> bool:
@ -565,6 +571,8 @@ def add_known_models():
assemblyai_models.append(key)
elif value.get("litellm_provider") == "jina_ai":
jina_ai_models.append(key)
elif value.get("litellm_provider") == "snowflake":
snowflake_models.append(key)
add_known_models()
@ -594,6 +602,7 @@ ollama_models = ["llama2"]
maritalk_models = ["maritalk"]
model_list = (
open_ai_chat_completion_models
+ open_ai_text_completion_models
@ -638,6 +647,7 @@ model_list = (
+ azure_text_models
+ assemblyai_models
+ jina_ai_models
+ snowflake_models
)
model_list_set = set(model_list)
@ -693,6 +703,7 @@ models_by_provider: dict = {
"sambanova": sambanova_models,
"assemblyai": assemblyai_models,
"jina_ai": jina_ai_models,
"snowflake": snowflake_models,
}
# mapping for those models which have larger equivalents
@ -809,6 +820,7 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig
from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.snowflake.chat.transformation import SnowflakeConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
@ -899,6 +911,7 @@ from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
from .llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config,
@ -921,11 +934,14 @@ from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from .llms.openai.chat.o_series_transformation import (
OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility
OpenAIOSeriesConfig,
)
from .llms.snowflake.chat.transformation import SnowflakeConfig
openaiOSeriesConfig = OpenAIOSeriesConfig()
from .llms.openai.chat.gpt_transformation import (
OpenAIGPTConfig,
@ -1010,6 +1026,7 @@ from .batches.main import *
from .batch_completion.main import * # type: ignore
from .rerank_api.main import *
from .llms.anthropic.experimental_pass_through.messages.handler import *
from .responses.main import *
from .realtime_api.main import _arealtime
from .fine_tuning.main import *
from .files.main import *

View file

@ -182,9 +182,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
)
verbose_logger.debug(
"init_redis_cluster: startup nodes are being initialized."
)
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
@ -307,7 +305,6 @@ def get_redis_async_client(
return _init_async_redis_sentinel(redis_kwargs)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
)

View file

@ -15,6 +15,7 @@ import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
exception_type,
get_litellm_params,
get_llm_provider,
get_secret,
supports_httpx_timeout,
@ -86,6 +87,7 @@ def get_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -169,6 +171,7 @@ def get_assistants(
max_retries=optional_params.max_retries,
client=client,
aget_assistants=aget_assistants, # type: ignore
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -270,6 +273,7 @@ def create_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -371,6 +375,7 @@ def create_assistants(
client=client,
async_create_assistants=async_create_assistants,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -445,6 +450,8 @@ def delete_assistant(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
async_delete_assistants: Optional[bool] = kwargs.pop(
"async_delete_assistants", None
)
@ -544,6 +551,7 @@ def delete_assistant(
max_retries=optional_params.max_retries,
client=client,
async_delete_assistants=async_delete_assistants,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -639,6 +647,7 @@ def create_thread(
"""
acreate_thread = kwargs.get("acreate_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -731,6 +740,7 @@ def create_thread(
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -795,7 +805,7 @@ def get_thread(
"""Get the thread object, given a thread_id"""
aget_thread = kwargs.pop("aget_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@ -884,6 +894,7 @@ def get_thread(
max_retries=optional_params.max_retries,
client=client,
aget_thread=aget_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -972,6 +983,7 @@ def add_message(
_message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
litellm_params_dict = get_litellm_params(**kwargs)
optional_params = GenericLiteLLMParams(**kwargs)
message_data = get_optional_params_add_message(
@ -1068,6 +1080,7 @@ def add_message(
max_retries=optional_params.max_retries,
client=client,
a_add_message=a_add_message,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -1139,6 +1152,7 @@ def get_messages(
) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -1225,6 +1239,7 @@ def get_messages(
max_retries=optional_params.max_retries,
client=client,
aget_messages=aget_messages,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -1337,6 +1352,7 @@ def run_thread(
"""Run a given thread + assistant."""
arun_thread = kwargs.pop("arun_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -1437,6 +1453,7 @@ def run_thread(
max_retries=optional_params.max_retries,
client=client,
arun_thread=arun_thread,
litellm_params=litellm_params_dict,
) # type: ignore
else:
raise litellm.exceptions.BadRequestError(

View file

@ -111,6 +111,7 @@ def create_batch(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
_is_async = kwargs.pop("acreate_batch", False) is True
litellm_params = get_litellm_params(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -217,6 +218,7 @@ def create_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
@ -320,15 +322,12 @@ def retrieve_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
**kwargs,
)
litellm_logging_obj.update_environment_variables(
model=None,
@ -424,6 +423,7 @@ def retrieve_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
retrieve_batch_data=_retrieve_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
@ -526,6 +526,10 @@ def list_batches(
try:
# set API KEY
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
@ -603,6 +607,7 @@ def list_batches(
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(
@ -678,6 +683,10 @@ def cancel_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@ -765,6 +774,7 @@ def cancel_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
cancel_batch_data=_cancel_batch_request,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(

View file

@ -790,6 +790,7 @@ class LLMCachingHandler:
- Else append the chunk to self.async_streaming_chunks
"""
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
@ -800,7 +801,6 @@ class LLMCachingHandler:
streaming_chunks=self.async_streaming_chunks,
is_async=True,
)
# if a complete_streaming_response is assembled, add it to the cache
if complete_streaming_response is not None:
await self.async_set_cache(

View file

@ -0,0 +1,40 @@
"""
Add the event loop to the cache key, to prevent event loop closed errors.
"""
import asyncio
from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.
If none, use the key as is.
"""
try:
event_loop = asyncio.get_event_loop()
stringified_event_loop = str(id(event_loop))
return f"{key}-{stringified_event_loop}"
except Exception: # handle no current event loop
return key
def set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().set_cache(key, value, **kwargs)
async def async_set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_set_cache(key, value, **kwargs)
def get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().get_cache(key, **kwargs)
async def async_get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_get_cache(key, **kwargs)

View file

@ -54,6 +54,7 @@ class RedisCache(BaseCache):
redis_flush_size: Optional[int] = 100,
namespace: Optional[str] = None,
startup_nodes: Optional[List] = None, # for redis-cluster
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
**kwargs,
):
@ -70,6 +71,9 @@ class RedisCache(BaseCache):
redis_kwargs["password"] = password
if startup_nodes is not None:
redis_kwargs["startup_nodes"] = startup_nodes
if socket_timeout is not None:
redis_kwargs["socket_timeout"] = socket_timeout
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
@ -556,6 +560,7 @@ class RedisCache(BaseCache):
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,

View file

@ -7,6 +7,7 @@ DEFAULT_MAX_RETRIES = 2
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
)
DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
@ -18,6 +19,7 @@ SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
#### Networking settings ####
request_timeout: float = 6000 # time in seconds
STREAM_SSE_DONE_STRING: str = "[DONE]"
LITELLM_CHAT_PROVIDERS = [
"openai",

View file

@ -44,7 +44,12 @@ from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_ro
from litellm.llms.vertex_ai.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
ResponseAPIUsage,
ResponsesAPIResponse,
)
from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import (
CallTypesLiteral,
@ -464,6 +469,13 @@ def _get_usage_object(
return usage_obj
def _is_known_usage_objects(usage_obj):
"""Returns True if the usage obj is a known Usage type"""
return isinstance(usage_obj, litellm.Usage) or isinstance(
usage_obj, ResponseAPIUsage
)
def _infer_call_type(
call_type: Optional[CallTypesLiteral], completion_response: Any
) -> Optional[CallTypesLiteral]:
@ -573,9 +585,7 @@ def completion_cost( # noqa: PLR0915
base_model=base_model,
)
verbose_logger.debug(
f"completion_response _select_model_name_for_cost_calc: {model}"
)
verbose_logger.info(f"selected model name for cost calculation: {model}")
if completion_response is not None and (
isinstance(completion_response, BaseModel)
@ -587,8 +597,8 @@ def completion_cost( # noqa: PLR0915
)
else:
usage_obj = getattr(completion_response, "usage", {})
if isinstance(usage_obj, BaseModel) and not isinstance(
usage_obj, litellm.Usage
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
usage_obj=usage_obj
):
setattr(
completion_response,
@ -601,6 +611,14 @@ def completion_cost( # noqa: PLR0915
_usage = usage_obj.model_dump()
else:
_usage = usage_obj
if ResponseAPILoggingUtils._is_response_api_usage(_usage):
_usage = (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
_usage
).model_dump()
)
# get input/output tokens from completion_response
prompt_tokens = _usage.get("prompt_tokens", 0)
completion_tokens = _usage.get("completion_tokens", 0)
@ -790,6 +808,23 @@ def completion_cost( # noqa: PLR0915
raise e
def get_response_cost_from_hidden_params(
hidden_params: Union[dict, BaseModel]
) -> Optional[float]:
if isinstance(hidden_params, BaseModel):
_hidden_params_dict = hidden_params.model_dump()
else:
_hidden_params_dict = hidden_params
additional_headers = _hidden_params_dict.get("additional_headers", {})
if additional_headers and "x-litellm-response-cost" in additional_headers:
response_cost = additional_headers["x-litellm-response-cost"]
if response_cost is None:
return None
return float(additional_headers["x-litellm-response-cost"])
return None
def response_cost_calculator(
response_object: Union[
ModelResponse,
@ -799,6 +834,7 @@ def response_cost_calculator(
TextCompletionResponse,
HttpxBinaryResponseContent,
RerankResponse,
ResponsesAPIResponse,
],
model: str,
custom_llm_provider: Optional[str],
@ -825,7 +861,7 @@ def response_cost_calculator(
base_model: Optional[str] = None,
custom_pricing: Optional[bool] = None,
prompt: str = "",
) -> Optional[float]:
) -> float:
"""
Returns
- float or None: cost of response
@ -837,6 +873,14 @@ def response_cost_calculator(
else:
if isinstance(response_object, BaseModel):
response_object._hidden_params["optional_params"] = optional_params
if hasattr(response_object, "_hidden_params"):
provider_response_cost = get_response_cost_from_hidden_params(
response_object._hidden_params
)
if provider_response_cost is not None:
return provider_response_cost
response_cost = completion_cost(
completion_response=response_object,
model=model,

View file

@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
)
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
from litellm.utils import get_litellm_params, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
@ -546,6 +546,7 @@ def create_file(
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -630,6 +631,7 @@ def create_file(
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""

View file

@ -1,31 +1,37 @@
import json
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
def set_attributes(span: Span, kwargs, response_obj):
from openinference.semconv.trace import (
from litellm.integrations._types.open_inference import (
MessageAttributes,
OpenInferenceSpanKindValues,
SpanAttributes,
)
try:
litellm_params = kwargs.get("litellm_params", {}) or {}
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
#############################################
############ LLM CALL METADATA ##############
#############################################
metadata = litellm_params.get("metadata", {}) or {}
span.set_attribute(SpanAttributes.METADATA, str(metadata))
if standard_logging_payload and (
metadata := standard_logging_payload["metadata"]
):
span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata))
#############################################
########## LLM Request Attributes ###########
@ -62,13 +68,12 @@ def set_attributes(span: Span, kwargs, response_obj):
msg.get("content", ""),
)
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_payload and (model_params := standard_logging_payload["model_parameters"]):
if standard_logging_payload and (
model_params := standard_logging_payload["model_parameters"]
):
# The Generative AI Provider: Azure, OpenAI, etc.
span.set_attribute(
SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(model_params)
SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
)
if model_params.get("user"):
@ -80,7 +85,7 @@ def set_attributes(span: Span, kwargs, response_obj):
########## LLM Response Attributes ##########
# https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
#############################################
if hasattr(response_obj, 'get'):
if hasattr(response_obj, "get"):
for choice in response_obj.get("choices", []):
response_message = choice.get("message", {})
span.set_attribute(

View file

@ -3,31 +3,38 @@ arize AI is OTEL compatible
this file has Arize ai specific helper functions
"""
import os
from typing import TYPE_CHECKING, Any
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm.integrations.arize import _utils
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.types.integrations.arize import ArizeConfig
from litellm.types.services import ServiceLoggerPayload
if TYPE_CHECKING:
from litellm.types.integrations.arize import Protocol as _Protocol
from opentelemetry.trace import Span as _Span
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
Span = _Span
else:
Protocol = Any
Span = Any
class ArizeLogger:
class ArizeLogger(OpenTelemetry):
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return
@staticmethod
def set_arize_attributes(span: Span, kwargs, response_obj):
_utils.set_attributes(span, kwargs, response_obj)
return
@staticmethod
def get_arize_config() -> ArizeConfig:
@ -43,11 +50,6 @@ class ArizeLogger:
space_key = os.environ.get("ARIZE_SPACE_KEY")
api_key = os.environ.get("ARIZE_API_KEY")
if not space_key:
raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
if not api_key:
raise ValueError("ARIZE_API_KEY not found in environment variables")
grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
@ -55,13 +57,13 @@ class ArizeLogger:
protocol: Protocol = "otlp_grpc"
if grpc_endpoint:
protocol="otlp_grpc"
endpoint=grpc_endpoint
protocol = "otlp_grpc"
endpoint = grpc_endpoint
elif http_endpoint:
protocol="otlp_http"
endpoint=http_endpoint
protocol = "otlp_http"
endpoint = http_endpoint
else:
protocol="otlp_grpc"
protocol = "otlp_grpc"
endpoint = "https://otlp.arize.com/v1"
return ArizeConfig(
@ -71,4 +73,33 @@ class ArizeLogger:
endpoint=endpoint,
)
async def async_service_success_hook(
self,
payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
pass
async def async_service_failure_hook(
self,
payload: ServiceLoggerPayload,
error: Optional[str] = "",
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
pass
def create_litellm_proxy_request_started_span(
self,
start_time: datetime,
headers: dict,
):
"""Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
pass
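For reference, a minimal way to exercise this logger; the env var names come from get_arize_config above, while the model and message values are placeholders:

import os
import litellm

os.environ["ARIZE_SPACE_KEY"] = "<space-key>"
os.environ["ARIZE_API_KEY"] = "<api-key>"
# optionally override the default gRPC endpoint:
# os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"

litellm.callbacks = ["arize"]
litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
)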

View file

@ -1,7 +1,16 @@
#### What this does ####
# On success, logs events to Promptlayer
import traceback
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
List,
Literal,
Optional,
Tuple,
Union,
)
from pydantic import BaseModel
@ -14,6 +23,7 @@ from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
StandardCallbackDynamicParams,
StandardLoggingPayload,
)
@ -239,6 +249,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
) -> Any:
pass
@ -250,6 +261,15 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
) -> Any:
pass
async def async_post_call_streaming_iterator_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: Any,
request_data: dict,
) -> AsyncGenerator[ModelResponseStream, None]:
async for item in response:
yield item
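As a sketch of how this default pass-through can be overridden, a hypothetical subclass that counts chunks before forwarding them (the class name is illustrative only):

class ChunkCountingLogger(CustomLogger):
    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        seen = 0
        async for item in response:
            seen += 1
            yield item  # forward each chunk unchanged
        # `seen` now holds the number of streamed chunks (inspect or log it here)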
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):

View file

@ -0,0 +1,49 @@
from typing import List, Optional, Tuple
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.prompt_management_base import (
PromptManagementBase,
PromptManagementClient,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
class CustomPromptManagement(CustomLogger, PromptManagementBase):
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Returns:
- model: str - the model to use (can be pulled from prompt management tool)
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
- non_default_params: dict - optional params (e.g. temperature, max_tokens) to apply, updated as needed (can be pulled from prompt management tool)
"""
return model, messages, non_default_params
@property
def integration_name(self) -> str:
return "custom-prompt-management"
def should_run_prompt_management(
self,
prompt_id: str,
dynamic_callback_params: StandardCallbackDynamicParams,
) -> bool:
return True
def _compile_prompt_helper(
self,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> PromptManagementClient:
raise NotImplementedError(
"Custom prompt management does not support compile prompt helper"
)
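A hedged usage sketch: a subclass that rewrites the request based on prompt_id. The class name, the system-message text, and the idea of registering the instance as a callback so get_custom_loggers_for_type can find it are illustrative assumptions:

class StaticPromptManager(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        # prepend a system message derived from the prompt id
        system_msg = {"role": "system", "content": f"[prompt:{prompt_id}] Be concise."}
        return model, [system_msg, *messages], non_default_params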

View file

@ -10,6 +10,7 @@ from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Function,
StandardCallbackDynamicParams,
StandardLoggingPayload,
)
@ -311,6 +312,8 @@ class OpenTelemetry(CustomLogger):
)
_parent_context, parent_otel_span = self._get_span_context(kwargs)
self._add_dynamic_span_processor_if_needed(kwargs)
# Span 1: Request sent to litellm SDK
span = self.tracer.start_span(
name=self._get_span_name(kwargs),
@ -341,6 +344,45 @@ class OpenTelemetry(CustomLogger):
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _add_dynamic_span_processor_if_needed(self, kwargs):
"""
Helper method to add a span processor with dynamic headers if needed.
This allows for per-request configuration of telemetry exporters by
extracting headers from standard_callback_dynamic_params.
"""
from opentelemetry import trace
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params")
)
if not standard_callback_dynamic_params:
return
# Extract headers from dynamic params
dynamic_headers = {}
# Handle Arize headers
if standard_callback_dynamic_params.get("arize_space_key"):
dynamic_headers["space_key"] = standard_callback_dynamic_params.get(
"arize_space_key"
)
if standard_callback_dynamic_params.get("arize_api_key"):
dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
"arize_api_key"
)
# Only create a span processor if we have headers to use
if len(dynamic_headers) > 0:
from opentelemetry.sdk.trace import TracerProvider
provider = trace.get_tracer_provider()
if isinstance(provider, TracerProvider):
span_processor = self._get_span_processor(
dynamic_headers=dynamic_headers
)
provider.add_span_processor(span_processor)
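Roughly, the per-request flow this enables looks like the sketch below; treating arize_space_key / arize_api_key as plain completion kwargs that end up in standard_callback_dynamic_params is an assumption for illustration:

import litellm

litellm.callbacks = ["arize"]

# if these kwargs are surfaced as standard_callback_dynamic_params, a span
# processor carrying the matching Arize headers is attached for this request only
litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    arize_space_key="<space-key>",  # placeholder
    arize_api_key="<api-key>",      # placeholder
)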
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -443,14 +485,12 @@ class OpenTelemetry(CustomLogger):
self, span: Span, kwargs, response_obj: Optional[Any]
):
try:
if self.callback_name == "arize":
from litellm.integrations.arize.arize import ArizeLogger
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return
elif self.callback_name == "arize_phoenix":
if self.callback_name == "arize_phoenix":
from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
ArizePhoenixLogger.set_arize_phoenix_attributes(
span, kwargs, response_obj
)
return
elif self.callback_name == "langtrace":
from litellm.integrations.langtrace import LangtraceAttributes
@ -779,7 +819,7 @@ class OpenTelemetry(CustomLogger):
carrier = {"traceparent": traceparent}
return TraceContextTextMapPropagator().extract(carrier=carrier), None
def _get_span_processor(self):
def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterGRPC,
)
@ -799,10 +839,9 @@ class OpenTelemetry(CustomLogger):
self.OTEL_ENDPOINT,
self.OTEL_HEADERS,
)
_split_otel_headers = {}
if self.OTEL_HEADERS is not None and isinstance(self.OTEL_HEADERS, str):
_split_otel_headers = self.OTEL_HEADERS.split("=")
_split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]}
_split_otel_headers = OpenTelemetry._get_headers_dictionary(
headers=dynamic_headers or self.OTEL_HEADERS
)
if isinstance(self.OTEL_EXPORTER, SpanExporter):
verbose_logger.debug(
@ -844,6 +883,25 @@ class OpenTelemetry(CustomLogger):
)
return BatchSpanProcessor(ConsoleSpanExporter())
@staticmethod
def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> Dict[str, str]:
"""
Normalize headers passed as either a "key=value" string or a dict into a Dict[str, str].
"""
_split_otel_headers: Dict[str, str] = {}
if headers:
if isinstance(headers, str):
# when passed HEADERS="x-honeycomb-team=B85YgLm96******"
# Split only on first '=' occurrence
parts = headers.split("=", 1)
if len(parts) == 2:
_split_otel_headers = {parts[0]: parts[1]}
else:
_split_otel_headers = {}
elif isinstance(headers, dict):
_split_otel_headers = headers
return _split_otel_headers
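A quick sanity example for the string form (header value is a placeholder); dict inputs are returned unchanged:

OpenTelemetry._get_headers_dictionary("x-honeycomb-team=B85YgLm96xxxx")
# -> {"x-honeycomb-team": "B85YgLm96xxxx"}
OpenTelemetry._get_headers_dictionary({"api_key": "abc"})
# -> {"api_key": "abc"}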
async def async_management_endpoint_success_hook(
self,
logging_payload: ManagementEndpointLoggingPayload,
@ -948,3 +1006,18 @@ class OpenTelemetry(CustomLogger):
)
management_endpoint_span.set_status(Status(StatusCode.ERROR))
management_endpoint_span.end(end_time=_end_time_ns)
def create_litellm_proxy_request_started_span(
self,
start_time: datetime,
headers: dict,
) -> Optional[Span]:
"""
Create a span for the received proxy server request.
"""
return self.tracer.start_span(
name="Received Proxy Server Request",
start_time=self._to_ns(start_time),
context=self.get_traceparent_from_header(headers=headers),
kind=self.span_kind.SERVER,
)

View file

@ -0,0 +1,34 @@
"""Utils for accessing credentials."""
from typing import List
import litellm
from litellm.types.utils import CredentialItem
class CredentialAccessor:
@staticmethod
def get_credential_values(credential_name: str) -> dict:
"""Safe accessor for credentials."""
if not litellm.credential_list:
return {}
for credential in litellm.credential_list:
if credential.credential_name == credential_name:
return credential.credential_values.copy()
return {}
@staticmethod
def upsert_credentials(credentials: List[CredentialItem]):
"""Add a credential to the list of credentials."""
credential_names = [cred.credential_name for cred in litellm.credential_list]
for credential in credentials:
if credential.credential_name in credential_names:
# Find and replace the existing credential in the list
for i, existing_cred in enumerate(litellm.credential_list):
if existing_cred.credential_name == credential.credential_name:
litellm.credential_list[i] = credential
break
else:
litellm.credential_list.append(credential)
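A minimal usage sketch, assuming litellm.credential_list has been initialized; the import path and the CredentialItem fields shown are assumptions for illustration:

import litellm
from litellm.types.utils import CredentialItem

litellm.credential_list = litellm.credential_list or []
CredentialAccessor.upsert_credentials(
    [
        CredentialItem(
            credential_name="shared-azure-creds",     # placeholder name
            credential_values={"api_key": "sk-..."},  # placeholder values
            credential_info={},
        )
    ]
)
CredentialAccessor.get_credential_values("shared-azure-creds")
# -> {"api_key": "sk-..."} (returned as a copy; unknown names return {})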

View file

@ -127,7 +127,7 @@ def exception_type( # type: ignore # noqa: PLR0915
completion_kwargs={},
extra_kwargs={},
):
"""Maps an LLM Provider Exception to OpenAI Exception Format"""
if any(
isinstance(original_exception, exc_type)
for exc_type in litellm.LITELLM_EXCEPTION_TYPES

View file

@ -58,6 +58,8 @@ def get_litellm_params(
async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
merge_reasoning_content_in_choices: Optional[bool] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> dict:
litellm_params = {
@ -99,5 +101,14 @@ def get_litellm_params(
"async_call": async_call,
"ssl_verify": ssl_verify,
"merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
"api_version": api_version,
"azure_ad_token": kwargs.get("azure_ad_token"),
"tenant_id": kwargs.get("tenant_id"),
"client_id": kwargs.get("client_id"),
"client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
}
return litellm_params

View file

@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider
)
if custom_llm_provider:
if (
model.split("/")[0] == custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = model.replace("{}/".format(custom_llm_provider), "")
return model, custom_llm_provider, dynamic_api_key, api_base
if custom_llm_provider and (
model.split("/")[0] != custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = custom_llm_provider + "/" + model
if api_key and api_key.startswith("os.environ/"):
dynamic_api_key = get_secret_str(api_key)
# check if llm provider part of model name
if (
model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list_set
@ -571,6 +569,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "snowflake":
api_base = (
api_base
or get_secret_str("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
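Given the env vars referenced here, a plausible call shape for the new provider (the model name and account id are placeholders):

import os
import litellm

os.environ["SNOWFLAKE_ACCOUNT_ID"] = "my-account"  # builds the default api_base
os.environ["SNOWFLAKE_JWT"] = "<jwt-token>"        # used as the dynamic api key

litellm.completion(
    model="snowflake/mistral-large",
    messages=[{"role": "user", "content": "hi"}],
)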

View file

@ -29,6 +29,7 @@ from litellm.batches.batch_utils import _handle_completed_batch
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.mlflow import MlflowLogger
@ -39,11 +40,14 @@ from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
)
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
AllMessageValues,
Batch,
FineTuningJob,
HttpxBinaryResponseContent,
ResponseCompletedEvent,
ResponsesAPIResponse,
)
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@ -73,11 +77,11 @@ from litellm.types.utils import (
from litellm.utils import _get_base_model_from_metadata, executor, print_verbose
from ..integrations.argilla import ArgillaLogger
from ..integrations.arize.arize import ArizeLogger
from ..integrations.arize.arize_phoenix import ArizePhoenixLogger
from ..integrations.athina import AthinaLogger
from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger
from ..integrations.braintrust_logging import BraintrustLogger
from ..integrations.custom_prompt_management import CustomPromptManagement
from ..integrations.datadog.datadog import DataDogLogger
from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from ..integrations.dynamodb import DyanmoDBLogger
@ -107,7 +111,6 @@ from .exception_mapping_utils import _get_response_headers
from .initialize_dynamic_callback_params import (
initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params,
)
from .logging_utils import _assemble_complete_response_from_streaming_chunks
from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache
try:
@ -427,34 +430,58 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]:
for (
custom_logger_compatible_callback
) in litellm._known_custom_logger_compatible_callbacks:
if model.startswith(custom_logger_compatible_callback):
custom_logger = self.get_custom_logger_for_prompt_management(model)
if custom_logger:
model, messages, non_default_params = (
custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
)
self.messages = messages
return model, messages, non_default_params
def get_custom_logger_for_prompt_management(
self, model: str
) -> Optional[CustomLogger]:
"""
Get a custom logger for prompt management based on model name or available callbacks.
Args:
model: The model name to check for prompt management integration
Returns:
A CustomLogger instance if one is found, None otherwise
"""
# First check if model starts with a known custom logger compatible callback
for callback_name in litellm._known_custom_logger_compatible_callbacks:
if model.startswith(callback_name):
custom_logger = _init_custom_logger_compatible_class(
logging_integration=custom_logger_compatible_callback,
logging_integration=callback_name,
internal_usage_cache=None,
llm_router=None,
)
if custom_logger is not None:
self.model_call_details["prompt_integration"] = model.split("/")[0]
return custom_logger
if custom_logger is None:
continue
old_name = model
# Then check for any registered CustomPromptManagement loggers
prompt_management_loggers = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=CustomPromptManagement
)
)
model, messages, non_default_params = (
custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
)
self.model_call_details["prompt_integration"] = old_name.split("/")[0]
self.messages = messages
if prompt_management_loggers:
logger = prompt_management_loggers[0]
self.model_call_details["prompt_integration"] = logger.__class__.__name__
return logger
return model, messages, non_default_params
return None
def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
if data is None:
@ -716,25 +743,9 @@ class Logging(LiteLLMLoggingBaseClass):
Masks the headers of the request sent from LiteLLM
"""
sensitive_keywords = [
"authorization",
"token",
"key",
"secret",
]
return {
k: (
(v[:-44] + "*" * 44)
if (isinstance(v, str) and len(v) > 44)
else "*****"
)
for k, v in headers.items()
if not ignore_sensitive_headers
or not any(
sensitive_keyword in k.lower()
for sensitive_keyword in sensitive_keywords
)
}
return _get_masked_values(
headers, ignore_sensitive_values=ignore_sensitive_headers
)
def post_call(
self, original_response, input=None, api_key=None, additional_args={}
@ -851,6 +862,8 @@ class Logging(LiteLLMLoggingBaseClass):
RerankResponse,
Batch,
FineTuningJob,
ResponsesAPIResponse,
ResponseCompletedEvent,
],
cache_hit: Optional[bool] = None,
) -> Optional[float]:
@ -1000,7 +1013,7 @@ class Logging(LiteLLMLoggingBaseClass):
standard_logging_object is None
and result is not None
and self.stream is not True
): # handle streaming separately
):
if (
isinstance(result, ModelResponse)
or isinstance(result, ModelResponseStream)
@ -1012,6 +1025,7 @@ class Logging(LiteLLMLoggingBaseClass):
or isinstance(result, RerankResponse)
or isinstance(result, FineTuningJob)
or isinstance(result, LiteLLMBatch)
or isinstance(result, ResponsesAPIResponse)
):
## HIDDEN PARAMS ##
hidden_params = getattr(result, "_hidden_params", {})
@ -1111,7 +1125,7 @@ class Logging(LiteLLMLoggingBaseClass):
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
] = None
if "complete_streaming_response" in self.model_call_details:
return # break out of this.
@ -1633,7 +1647,7 @@ class Logging(LiteLLMLoggingBaseClass):
if "async_complete_streaming_response" in self.model_call_details:
return # break out of this.
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
] = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
@ -2343,28 +2357,24 @@ class Logging(LiteLLMLoggingBaseClass):
def _get_assembled_streaming_response(
self,
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
result: Union[
ModelResponse,
TextCompletionResponse,
ModelResponseStream,
ResponseCompletedEvent,
Any,
],
start_time: datetime.datetime,
end_time: datetime.datetime,
is_async: bool,
streaming_chunks: List[Any],
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]:
if isinstance(result, ModelResponse):
return result
elif isinstance(result, TextCompletionResponse):
return result
elif isinstance(result, ModelResponseStream):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=streaming_chunks,
is_async=is_async,
)
return complete_streaming_response
elif isinstance(result, ResponseCompletedEvent):
return result.response
return None
def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
@ -2399,6 +2409,58 @@ class Logging(LiteLLMLoggingBaseClass):
return result
def _get_masked_values(
sensitive_object: dict,
ignore_sensitive_values: bool = False,
mask_all_values: bool = False,
unmasked_length: int = 4,
number_of_asterisks: Optional[int] = 4,
) -> dict:
"""
Internal debugging helper function
Masks the headers of the request sent from LiteLLM
Args:
masked_length: Optional length for the masked portion (number of *). If set, will use exactly this many *
regardless of original string length. The total length will be unmasked_length + masked_length.
"""
sensitive_keywords = [
"authorization",
"token",
"key",
"secret",
]
return {
k: (
(
v[: unmasked_length // 2]
+ "*" * number_of_asterisks
+ v[-unmasked_length // 2 :]
)
if (
isinstance(v, str)
and len(v) > unmasked_length
and number_of_asterisks is not None
)
else (
(
v[: unmasked_length // 2]
+ "*" * (len(v) - unmasked_length)
+ v[-unmasked_length // 2 :]
)
if (isinstance(v, str) and len(v) > unmasked_length)
else "*****"
)
)
for k, v in sensitive_object.items()
if not ignore_sensitive_values
or not any(
sensitive_keyword in k.lower() for sensitive_keyword in sensitive_keywords
)
}
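A worked example of the masking behaviour with the defaults (unmasked_length=4, number_of_asterisks=4):

_get_masked_values({"Authorization": "Bearer sk-1234567890"})
# -> {"Authorization": "Be****90"}  (first 2 and last 2 chars kept, 4 asterisks between)

_get_masked_values({"Authorization": "Bearer sk-1234567890"}, ignore_sensitive_values=True)
# -> {}  (keys containing "authorization", "token", "key" or "secret" are dropped entirely)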
def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""
Globally sets the callback client
@ -2621,13 +2683,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
)
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
isinstance(callback, ArizeLogger)
and callback.callback_name == "arize"
):
return callback # type: ignore
_otel_logger = OpenTelemetry(config=otel_config, callback_name="arize")
_in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore
_arize_otel_logger = ArizeLogger(config=otel_config, callback_name="arize")
_in_memory_loggers.append(_arize_otel_logger)
return _arize_otel_logger # type: ignore
elif logging_integration == "arize_phoenix":
from litellm.integrations.opentelemetry import (
OpenTelemetry,
@ -2860,15 +2922,13 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
if isinstance(callback, OpenTelemetry):
return callback
elif logging_integration == "arize":
from litellm.integrations.opentelemetry import OpenTelemetry
if "ARIZE_SPACE_KEY" not in os.environ:
raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
if "ARIZE_API_KEY" not in os.environ:
raise ValueError("ARIZE_API_KEY not found in environment variables")
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
isinstance(callback, ArizeLogger)
and callback.callback_name == "arize"
):
return callback
@ -3111,6 +3171,12 @@ class StandardLoggingPayloadSetup:
elif isinstance(usage, Usage):
return usage
elif isinstance(usage, dict):
if ResponseAPILoggingUtils._is_response_api_usage(usage):
return (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
usage
)
)
return Usage(**usage)
raise ValueError(f"usage is required, got={usage} of type {type(usage)}")
@ -3215,6 +3281,7 @@ class StandardLoggingPayloadSetup:
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
if hidden_params is not None:
for key in StandardLoggingHiddenParams.__annotations__.keys():
@ -3329,6 +3396,7 @@ def get_standard_logging_object_payload(
response_cost=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
)
@ -3614,6 +3682,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
# Convert numeric values to appropriate types

View file

@ -44,6 +44,7 @@ class ResponseMetadata:
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
),
"litellm_model_name": model,
}
self._update_hidden_params(new_params)

View file

@ -1,4 +1,4 @@
from typing import Callable, List, Set, Union
from typing import Callable, List, Set, Type, Union
import litellm
from litellm._logging import verbose_logger
@ -86,21 +86,20 @@ class LoggingCallbackManager:
callback=callback, parent_list=litellm._async_failure_callback
)
def remove_callback_from_list_by_object(
self, callback_list, obj
):
def remove_callback_from_list_by_object(self, callback_list, obj):
"""
Remove callbacks that are methods of a particular object (e.g., router cleanup)
"""
if not isinstance(callback_list, list): # Not list -> do nothing
if not isinstance(callback_list, list): # Not list -> do nothing
return
remove_list=[c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]
remove_list = [
c for c in callback_list if hasattr(c, "__self__") and c.__self__ == obj
]
for c in remove_list:
callback_list.remove(c)
def _add_string_callback_to_list(
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
):
@ -254,3 +253,11 @@ class LoggingCallbackManager:
):
matched_callbacks.add(callback)
return matched_callbacks
def get_custom_loggers_for_type(
self, callback_type: Type[CustomLogger]
) -> List[CustomLogger]:
"""
Get all custom loggers that are instances of the given class type
"""
return [c for c in self._get_all_callbacks() if isinstance(c, callback_type)]

View file

@ -77,6 +77,10 @@ def _assemble_complete_response_from_streaming_chunks(
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = None
if isinstance(result, ModelResponse):
return result
if result.choices[0].finish_reason is not None: # if it's the last chunk
streaming_chunks.append(result)
try:

View file

@ -77,6 +77,16 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
return texts
def get_str_from_messages(messages: List[AllMessageValues]) -> str:
"""
Converts a list of messages to a string
"""
text = ""
for message in messages:
text += convert_content_list_to_str(message=message)
return text
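For example, assuming convert_content_list_to_str returns the raw text with no separator:

get_str_from_messages(
    [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
    ]
)
# -> "You are terse.Hi"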
def is_non_content_values_set(message: AllMessageValues) -> bool:
ignore_keys = ["content", "role", "name"]
return any(

View file

@ -166,148 +166,108 @@ def convert_to_ollama_image(openai_image_url: str):
)
def _handle_ollama_system_message(
messages: list, prompt: str, msg_i: int
) -> Tuple[str, int]:
system_content_str = ""
## MERGE CONSECUTIVE SYSTEM CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "system":
msg_content = convert_content_list_to_str(messages[msg_i])
system_content_str += msg_content
msg_i += 1
return system_content_str, msg_i
def ollama_pt(
model, messages
model: str, messages: list
) -> Union[
str, OllamaVisionModelObject
]: # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt(
role_dict={
"system": {"pre_message": "### System:\n", "post_message": "\n"},
"user": {
"pre_message": "### User:\n",
"post_message": "\n",
},
"assistant": {
"pre_message": "### Response:\n",
"post_message": "\n",
},
},
final_prompt_value="### Response:",
messages=messages,
user_message_types = {"user", "tool", "function"}
msg_i = 0
images = []
prompt = ""
while msg_i < len(messages):
init_msg_i = msg_i
user_content_str = ""
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "image_url":
if isinstance(m["image_url"], str):
images.append(m["image_url"])
elif isinstance(m["image_url"], dict):
images.append(m["image_url"]["url"])
elif m.get("type", "") == "text":
user_content_str += m["text"]
else:
# Tool message content will always be a string
user_content_str += msg_content
msg_i += 1
if user_content_str:
prompt += f"### User:\n{user_content_str}\n\n"
system_content_str, msg_i = _handle_ollama_system_message(
messages, prompt, msg_i
)
else:
user_message_types = {"user", "tool", "function"}
msg_i = 0
images = []
prompt = ""
while msg_i < len(messages):
init_msg_i = msg_i
user_content_str = ""
## MERGE CONSECUTIVE USER CONTENT ##
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "image_url":
if isinstance(m["image_url"], str):
images.append(m["image_url"])
elif isinstance(m["image_url"], dict):
images.append(m["image_url"]["url"])
elif m.get("type", "") == "text":
user_content_str += m["text"]
else:
# Tool message content will always be a string
user_content_str += msg_content
if system_content_str:
prompt += f"### System:\n{system_content_str}\n\n"
msg_i += 1
assistant_content_str = ""
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_content_str += convert_content_list_to_str(messages[msg_i])
msg_i += 1
if user_content_str:
prompt += f"### User:\n{user_content_str}\n\n"
tool_calls = messages[msg_i].get("tool_calls")
ollama_tool_calls = []
if tool_calls:
for call in tool_calls:
call_id: str = call["id"]
function_name: str = call["function"]["name"]
arguments = json.loads(call["function"]["arguments"])
assistant_content_str = ""
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "text":
assistant_content_str += m["text"]
elif isinstance(msg_content, str):
# Tool message content will always be a string
assistant_content_str += msg_content
tool_calls = messages[msg_i].get("tool_calls")
ollama_tool_calls = []
if tool_calls:
for call in tool_calls:
call_id: str = call["id"]
function_name: str = call["function"]["name"]
arguments = json.loads(call["function"]["arguments"])
ollama_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
if ollama_tool_calls:
assistant_content_str += (
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
ollama_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
msg_i += 1
if assistant_content_str:
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
if msg_i == init_msg_i: # prevent infinite loops
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
model=model,
llm_provider="ollama",
if ollama_tool_calls:
assistant_content_str += (
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
)
# prompt = ""
# images = []
# for message in messages:
# if isinstance(message["content"], str):
# prompt += message["content"]
# elif isinstance(message["content"], list):
# # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
# for element in message["content"]:
# if isinstance(element, dict):
# if element["type"] == "text":
# prompt += element["text"]
# elif element["type"] == "image_url":
# base64_image = convert_to_ollama_image(
# element["image_url"]["url"]
# )
# images.append(base64_image)
# if "tool_calls" in message:
# tool_calls = []
msg_i += 1
# for call in message["tool_calls"]:
# call_id: str = call["id"]
# function_name: str = call["function"]["name"]
# arguments = json.loads(call["function"]["arguments"])
if assistant_content_str:
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
# tool_calls.append(
# {
# "id": call_id,
# "type": "function",
# "function": {"name": function_name, "arguments": arguments},
# }
# )
if msg_i == init_msg_i: # prevent infinite loops
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
model=model,
llm_provider="ollama",
)
# prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
response_dict: OllamaVisionModelObject = {
"prompt": prompt,
"images": images,
}
# elif "tool_call_id" in message:
# prompt += f"### User:\n{message['content']}\n\n"
return {"prompt": prompt, "images": images}
return prompt
return response_dict
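A hedged sketch of the call shape after this refactor; the model name and messages are placeholders, and the exact ordering of the merged sections follows the loop above:

result = ollama_pt(
    model="llama3",
    messages=[
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ],
)
# result is an OllamaVisionModelObject-style dict:
# "prompt" holds the merged "### User:", "### System:" and "### Assistant:" sections,
# "images" collects any image_url entries found in the user content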
def mistral_instruct_pt(messages):

View file

@ -13,6 +13,7 @@ from litellm.types.utils import (
Function,
FunctionCall,
ModelResponse,
ModelResponseStream,
PromptTokensDetails,
Usage,
)
@ -319,8 +320,12 @@ class ChunkProcessor:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk["usage"]
elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
elif (
isinstance(chunk, ModelResponse)
or isinstance(chunk, ModelResponseStream)
) and hasattr(chunk, "_hidden_params"):
usage_chunk = chunk._hidden_params.get("usage", None)
if usage_chunk is not None:
usage_chunk_dict = self._usage_chunk_calculation_helper(usage_chunk)
if (

View file

@ -898,6 +898,8 @@ class CustomStreamWrapper:
return model_response
# Default - return StopIteration
if hasattr(model_response, "usage"):
self.chunks.append(model_response)
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0:
@ -1470,6 +1472,24 @@ class CustomStreamWrapper:
"""
self.logging_loop = loop
def cache_streaming_response(self, processed_chunk, cache_hit: bool):
"""
Caches the streaming response
"""
if not cache_hit and self.logging_obj._llm_caching_handler is not None:
self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
processed_chunk
)
async def async_cache_streaming_response(self, processed_chunk, cache_hit: bool):
"""
Caches the streaming response
"""
if not cache_hit and self.logging_obj._llm_caching_handler is not None:
await self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
processed_chunk
)
def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
"""
Runs success logging in a thread and adds the response to the cache
@ -1501,12 +1521,6 @@ class CustomStreamWrapper:
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
## Sync store in cache
if self.logging_obj._llm_caching_handler is not None:
self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
processed_chunk
)
def finish_reason_handler(self):
model_response = self.model_response_creator()
_finish_reason = self.received_finish_reason or self.intermittent_finish_reason
@ -1553,10 +1567,11 @@ class CustomStreamWrapper:
if response is None:
continue
## LOGGING
threading.Thread(
target=self.run_success_logging_and_cache_storage,
args=(response, cache_hit),
).start() # log response
executor.submit(
self.run_success_logging_and_cache_storage,
response,
cache_hit,
) # log response
choice = response.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
@ -1600,13 +1615,27 @@ class CustomStreamWrapper:
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
self.cache_streaming_response(
processed_chunk=complete_streaming_response.model_copy(
deep=True
),
cache_hit=cache_hit,
)
executor.submit(
self.logging_obj.success_handler,
complete_streaming_response.model_copy(deep=True),
None,
None,
cache_hit,
)
else:
executor.submit(
self.logging_obj.success_handler,
response,
None,
None,
cache_hit,
)
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response
@ -1618,10 +1647,11 @@ class CustomStreamWrapper:
usage = calculate_total_usage(chunks=self.chunks)
processed_chunk._hidden_params["usage"] = usage
## LOGGING
threading.Thread(
target=self.run_success_logging_and_cache_storage,
args=(processed_chunk, cache_hit),
).start() # log response
executor.submit(
self.run_success_logging_and_cache_storage,
processed_chunk,
cache_hit,
) # log response
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
@ -1690,13 +1720,6 @@ class CustomStreamWrapper:
if processed_chunk is None:
continue
if self.logging_obj._llm_caching_handler is not None:
asyncio.create_task(
self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
processed_chunk=cast(ModelResponse, processed_chunk),
)
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
@ -1767,6 +1790,14 @@ class CustomStreamWrapper:
"usage",
getattr(complete_streaming_response, "usage"),
)
asyncio.create_task(
self.async_cache_streaming_response(
processed_chunk=complete_streaming_response.model_copy(
deep=True
),
cache_hit=cache_hit,
)
)
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response

View file

@ -29,6 +29,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
api_base: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""

View file

@ -1,4 +1,4 @@
from typing import Coroutine, Iterable, Literal, Optional, Union
from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
@ -18,10 +18,10 @@ from ...types.llms.openai import (
SyncCursorPage,
Thread,
)
from ..base import BaseLLM
from .common_utils import BaseAzureLLM
class AzureAssistantsAPI(BaseLLM):
class AzureAssistantsAPI(BaseAzureLLM):
def __init__(self) -> None:
super().__init__()
@ -34,18 +34,18 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AzureOpenAI(**data) # type: ignore
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
is_async=False,
)
azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_openai_client = client
@ -60,18 +60,19 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncAzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AsyncAzureOpenAI(**data)
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
is_async=True,
)
azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
else:
azure_openai_client = client
@ -89,6 +90,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[Assistant]:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -98,6 +100,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.list()
@ -146,6 +149,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_assistants=None,
litellm_params: Optional[dict] = None,
):
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
@ -156,6 +160,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -165,6 +170,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
api_version=api_version,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.list()
@ -184,6 +190,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -193,6 +200,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
@ -222,6 +230,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
a_add_message: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, OpenAIMessage]:
...
@ -238,6 +247,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
a_add_message: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
...
@ -255,6 +265,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
a_add_message: Optional[bool] = None,
litellm_params: Optional[dict] = None,
):
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
@ -267,6 +278,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -300,6 +312,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -309,6 +322,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
@ -329,6 +343,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_messages: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@ -344,6 +359,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_messages: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> SyncCursorPage[OpenAIMessage]:
...
@ -360,6 +376,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_messages=None,
litellm_params: Optional[dict] = None,
):
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
@ -371,6 +388,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -380,6 +398,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
@ -399,6 +418,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -408,6 +428,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@ -435,6 +456,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncAzureOpenAI],
acreate_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@ -451,6 +473,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AzureOpenAI],
acreate_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@ -468,6 +491,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
litellm_params: Optional[dict] = None,
):
"""
Here's an example:
@ -490,6 +514,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
messages=messages,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -499,6 +524,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@ -521,6 +547,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -530,6 +557,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
@ -550,6 +578,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@ -565,6 +594,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@ -581,6 +611,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_thread=None,
litellm_params: Optional[dict] = None,
):
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
@ -592,6 +623,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -601,6 +633,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
@ -618,7 +651,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
@ -629,6 +662,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Run:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -638,6 +672,7 @@ class AzureAssistantsAPI(BaseLLM):
api_version=api_version,
azure_ad_token=azure_ad_token,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
@ -645,7 +680,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@ -659,12 +694,13 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
data: Dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
@ -684,12 +720,13 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
data: Dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
@ -711,7 +748,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
@ -733,7 +770,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
@ -756,7 +793,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
metadata: Optional[Dict],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
@ -769,6 +806,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
litellm_params: Optional[dict] = None,
):
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
@ -780,6 +818,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
return self.async_run_thread_stream(
client=azure_client,
@ -791,13 +830,14 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
stream=stream,
tools=tools,
@ -808,6 +848,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -817,6 +858,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
if stream is not None and stream is True:
@ -830,6 +872,7 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
@ -837,7 +880,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@ -855,6 +898,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
create_assistant_data: dict,
litellm_params: Optional[dict] = None,
) -> Assistant:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -864,6 +908,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.create(
@ -882,6 +927,7 @@ class AzureAssistantsAPI(BaseLLM):
create_assistant_data: dict,
client=None,
async_create_assistants=None,
litellm_params: Optional[dict] = None,
):
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
@ -893,6 +939,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -902,6 +949,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
@ -918,6 +966,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
assistant_id: str,
litellm_params: Optional[dict] = None,
):
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -927,6 +976,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.delete(
@ -945,6 +995,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
async_delete_assistants: Optional[bool] = None,
client=None,
litellm_params: Optional[dict] = None,
):
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
@ -956,6 +1007,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
assistant_id=assistant_id,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -965,6 +1017,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)

View file

@ -1,20 +1,20 @@
import uuid
from typing import Any, Optional
from typing import Any, Coroutine, Optional, Union
from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.types.utils import FileTypes
from litellm.utils import TranscriptionResponse, convert_to_model_response_object
from .azure import (
AzureChatCompletion,
get_azure_ad_token_from_oidc,
select_azure_base_url_or_endpoint,
from litellm.utils import (
TranscriptionResponse,
convert_to_model_response_object,
extract_duration_from_srt_or_vtt,
)
from .azure import AzureChatCompletion
from .common_utils import AzureOpenAIError
class AzureAudioTranscription(AzureChatCompletion):
def audio_transcriptions(
@ -32,32 +32,12 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None,
azure_ad_token: Optional[str] = None,
atranscription: bool = False,
) -> TranscriptionResponse:
litellm_params: Optional[dict] = None,
) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
@@ -65,14 +45,26 @@ class AzureAudioTranscription(AzureChatCompletion):
api_key=api_key,
api_base=api_base,
client=client,
azure_client_params=azure_client_params,
max_retries=max_retries,
logging_obj=logging_obj,
model=model,
litellm_params=litellm_params,
)
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=False,
client=client,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
if client is None:
azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
@@ -109,25 +101,34 @@ class AzureAudioTranscription(AzureChatCompletion):
async def async_audio_transcriptions(
self,
audio_file: FileTypes,
model: str,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
azure_client_params: dict,
logging_obj: Any,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
):
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse:
response = None
try:
if client is None:
async_azure_client = AsyncAzureOpenAI(
**azure_client_params,
http_client=litellm.aclient_session,
async_azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=True,
client=client,
litellm_params=litellm_params,
)
if not isinstance(async_azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="async_azure_client is not an instance of AsyncAzureOpenAI",
)
else:
async_azure_client = client
## LOGGING
logging_obj.pre_call(
@@ -156,6 +157,8 @@ class AzureAudioTranscription(AzureChatCompletion):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
duration = extract_duration_from_srt_or_vtt(response)
stringified_response["duration"] = duration
## LOGGING
logging_obj.post_call(
@@ -178,7 +181,12 @@ class AzureAudioTranscription(AzureChatCompletion):
model_response_object=model_response,
hidden_params=hidden_params,
response_type="audio_transcription",
) # type: ignore
)
if not isinstance(response, TranscriptionResponse):
raise AzureOpenAIError(
status_code=500,
message="response is not an instance of TranscriptionResponse",
)
return response
except Exception as e:
## LOGGING
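
Note: besides routing client creation through `get_azure_openai_client`, this file now attaches a `duration` to transcription responses via `litellm.utils.extract_duration_from_srt_or_vtt`. A rough, self-contained approximation of that idea (not litellm's actual implementation): take the last timestamp in the SRT/VTT payload.

```python
import re
from typing import Optional


# Illustrative approximation only; litellm.utils.extract_duration_from_srt_or_vtt
# may behave differently.
def approx_duration_from_srt_or_vtt(payload: str) -> Optional[float]:
    # Timestamps look like 00:01:02.345 (VTT) or 00:01:02,345 (SRT).
    stamps = re.findall(r"(\d{2}):(\d{2}):(\d{2})[.,](\d{3})", payload)
    if not stamps:
        return None
    hours, minutes, seconds, millis = (int(x) for x in stamps[-1])
    return hours * 3600 + minutes * 60 + seconds + millis / 1000.0


print(approx_duration_from_srt_or_vtt("1\n00:00:01,000 --> 00:00:04,200\nhello"))  # 4.2
```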

View file

@@ -1,16 +1,15 @@
import asyncio
import json
import os
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
import httpx # type: ignore
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -25,15 +24,18 @@ from litellm.types.utils import (
from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
get_secret,
modify_url,
)
from ...types.llms.openai import HttpxBinaryResponseContent
from ..base import BaseLLM
from .common_utils import AzureOpenAIError, process_azure_headers
azure_ad_cache = DualCache()
from .common_utils import (
AzureOpenAIError,
BaseAzureLLM,
get_azure_ad_token_from_oidc,
process_azure_headers,
select_azure_base_url_or_endpoint,
)
class AzureOpenAIAssistantsAPIConfig:
@@ -98,93 +100,6 @@ class AzureOpenAIAssistantsAPIConfig:
return optional_params
def select_azure_base_url_or_endpoint(azure_client_params: dict):
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token
client = litellm.module_level_client
req_token = client.post(
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
if azure_ad_token_access_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token access_token not returned"
)
if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
def _check_dynamic_azure_params(
azure_client_params: dict,
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
@@ -206,7 +121,7 @@ def _check_dynamic_azure_params(
return False
class AzureChatCompletion(BaseLLM):
class AzureChatCompletion(BaseAzureLLM, BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -226,52 +141,6 @@ class AzureChatCompletion(BaseLLM):
return headers
def _get_sync_azure_client(
self,
api_version: Optional[str],
api_base: Optional[str],
api_key: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
model: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
client: Optional[Any],
client_type: Literal["sync", "async"],
):
# init AzureOpenAI Client
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
elif client_type == "async":
azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
return azure_client
def make_sync_azure_openai_chat_completion_request(
self,
azure_client: AzureOpenAI,
@@ -294,11 +163,13 @@ class AzureChatCompletion(BaseLLM):
except Exception as e:
raise e
@track_llm_api_timing()
async def make_azure_openai_chat_completion_request(
self,
azure_client: AsyncAzureOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
):
"""
Helper to:
@@ -360,37 +231,18 @@ class AzureChatCompletion(BaseLLM):
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
azure_client_params = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token
)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
client = self._init_azure_client_for_cloudflare_ai_gateway(
api_base=api_base,
model=model,
api_version=api_version,
max_retries=max_retries,
timeout=timeout,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
acompletion=acompletion,
client=client,
)
data = {"model": None, "messages": messages, **optional_params}
else:
@@ -417,6 +269,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
litellm_params=litellm_params,
)
else:
return self.acompletion(
@@ -434,6 +287,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj=logging_obj,
max_retries=max_retries,
convert_tool_call_to_json_mode=json_mode,
litellm_params=litellm_params,
)
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
@@ -449,6 +303,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
litellm_params=litellm_params,
)
else:
## LOGGING
@@ -470,43 +325,15 @@ class AzureChatCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
client=client,
_is_async=False,
litellm_params=litellm_params,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if (
client is None
or not isinstance(client, AzureOpenAI)
or dynamic_params
):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault(
"api-version", api_version
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
@@ -566,36 +393,22 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
litellm_params: Optional[dict] = {},
):
response = None
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.aclient_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# setting Azure client
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
client=client,
_is_async=True,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -615,6 +428,7 @@ class AzureChatCompletion(BaseLLM):
azure_client=azure_client,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
logging_obj.model_call_details["response_headers"] = headers
@@ -680,6 +494,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
litellm_params: Optional[dict] = {},
):
# init AzureOpenAI Client
azure_client_params = {
@@ -702,10 +517,20 @@ class AzureChatCompletion(BaseLLM):
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
client=client,
_is_async=False,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -747,32 +572,21 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
litellm_params: Optional[dict] = {},
):
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.aclient_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
client=client,
_is_async=True,
litellm_params=litellm_params,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -792,6 +606,7 @@ class AzureChatCompletion(BaseLLM):
azure_client=azure_client,
data=data,
timeout=timeout,
logging_obj=logging_obj,
)
logging_obj.model_call_details["response_headers"] = headers
@@ -822,21 +637,36 @@ class AzureChatCompletion(BaseLLM):
async def aembedding(
self,
model: str,
data: dict,
model_response: EmbeddingResponse,
azure_client_params: dict,
input: list,
logging_obj: LiteLLMLoggingObj,
api_base: str,
api_key: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[AsyncAzureOpenAI] = None,
timeout=None,
):
timeout: Optional[Union[float, httpx.Timeout]] = None,
max_retries: Optional[int] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
litellm_params: Optional[dict] = {},
) -> EmbeddingResponse:
response = None
try:
if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else:
openai_aclient = client
openai_aclient = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=True,
client=client,
litellm_params=litellm_params,
)
if not isinstance(openai_aclient, AsyncAzureOpenAI):
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
)
@@ -850,13 +680,19 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
embedding_response = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=process_azure_headers(headers),
response_type="embedding",
)
if not isinstance(embedding_response, EmbeddingResponse):
raise AzureOpenAIError(
status_code=500,
message="embedding_response is not an instance of EmbeddingResponse",
)
return embedding_response
except Exception as e:
## LOGGING
logging_obj.post_call(
@@ -884,7 +720,8 @@ class AzureChatCompletion(BaseLLM):
client=None,
aembedding=None,
headers: Optional[dict] = None,
) -> EmbeddingResponse:
litellm_params: Optional[dict] = None,
) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]:
if headers:
optional_params["extra_headers"] = headers
if self._client_session is None:
@@ -893,35 +730,6 @@ class AzureChatCompletion(BaseLLM):
data = {"model": model, "input": input, **optional_params}
if max_retries is None:
max_retries = litellm.DEFAULT_MAX_RETRIES
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if aembedding:
azure_client_params["http_client"] = litellm.aclient_session
else:
azure_client_params["http_client"] = litellm.client_session
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
## LOGGING
logging_obj.pre_call(
input=input,
@@ -933,20 +741,33 @@ class AzureChatCompletion(BaseLLM):
)
if aembedding is True:
return self.aembedding( # type: ignore
return self.aembedding(
data=data,
input=input,
model=model,
logging_obj=logging_obj,
api_key=api_key,
model_response=model_response,
azure_client_params=azure_client_params,
timeout=timeout,
client=client,
litellm_params=litellm_params,
api_base=api_base,
)
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=False,
client=client,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
## COMPLETION CALL
raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
headers = dict(raw_response.headers)
@@ -1281,6 +1102,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
client=None,
aimg_generation=None,
litellm_params: Optional[dict] = None,
) -> ImageResponse:
try:
if model and len(model) > 0:
@@ -1305,25 +1127,14 @@ class AzureChatCompletion(BaseLLM):
)
# init AzureOpenAI Client
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model or "",
api_version=api_version,
api_base=api_base,
is_async=False,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
@@ -1386,6 +1197,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None,
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2)
@@ -1404,19 +1216,17 @@ class AzureChatCompletion(BaseLLM):
max_retries=max_retries,
timeout=timeout,
client=client,
litellm_params=litellm_params,
) # type: ignore
azure_client: AzureOpenAI = self._get_sync_azure_client(
azure_client: AzureOpenAI = self.get_azure_openai_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,
_is_async=False,
client=client,
client_type="sync",
litellm_params=litellm_params,
) # type: ignore
response = azure_client.audio.speech.create(
@@ -1441,19 +1251,17 @@ class AzureChatCompletion(BaseLLM):
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
api_base=api_base,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,
_is_async=True,
client=client,
client_type="async",
litellm_params=litellm_params,
) # type: ignore
azure_response = await azure_client.audio.speech.create(
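
Note: throughout this file, the per-call `azure_client_params` construction is replaced by `get_azure_openai_client` / `initialize_azure_sdk_client` on `BaseAzureLLM` (added later in this diff), plus an isinstance check on the returned client. For reference, what the removed inline code amounted to was a direct SDK construction roughly like the sketch below; all values are placeholders.

```python
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI

# Placeholder settings; the refactor builds this same kind of client centrally.
common_kwargs = dict(
    api_key="placeholder-key",
    azure_endpoint="https://my-resource.openai.azure.com",
    azure_deployment="my-deployment",
    api_version="2024-02-01",
    max_retries=2,
    timeout=httpx.Timeout(600.0),
)

sync_client = AzureOpenAI(**common_kwargs)        # what the sync paths built inline
async_client = AsyncAzureOpenAI(**common_kwargs)  # what the async paths built inline
```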

View file

@@ -6,7 +6,6 @@ from typing import Any, Coroutine, Optional, Union, cast
import httpx
import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
@@ -16,8 +15,10 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import LiteLLMBatch
from ..common_utils import BaseAzureLLM
class AzureBatchesAPI:
class AzureBatchesAPI(BaseAzureLLM):
"""
Azure methods to support for batches
- create_batch()
@@ -29,38 +30,6 @@ class AzureBatchesAPI:
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
@@ -79,16 +48,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@@ -125,16 +94,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@@ -173,16 +142,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@@ -212,16 +181,16 @@ class AzureBatchesAPI:
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
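
Note: `AzureBatchesAPI` now inherits `get_azure_openai_client` from `BaseAzureLLM` instead of carrying its own copy, and otherwise defers to the OpenAI SDK's batches surface. A hedged sketch of that underlying surface; credentials, file id and the endpoint string are placeholders (the exact endpoint value expected by an Azure batch deployment may differ):

```python
from openai import AzureOpenAI

azure_client = AzureOpenAI(
    api_key="placeholder-key",
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    api_version="2024-07-01-preview",                       # placeholder
)

# Create a batch from a previously uploaded JSONL file (placeholder id).
batch = azure_client.batches.create(
    input_file_id="file-placeholder",
    endpoint="/v1/chat/completions",  # OpenAI-style value; assumption for Azure
    completion_window="24h",
)
print(batch.id, batch.status)
```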

View file

@@ -99,6 +99,8 @@ class AzureOpenAIConfig(BaseConfig):
"extra_headers",
"parallel_tool_calls",
"prediction",
"modalities",
"audio",
]
def _is_response_format_supported_model(self, model: str) -> bool:
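
Note: with `modalities` and `audio` added to the supported OpenAI params for Azure, they can be forwarded from a normal completion call. A hedged call-site sketch; the deployment name, endpoint, version and key are placeholders:

```python
import litellm

response = litellm.completion(
    model="azure/my-audio-deployment",                # placeholder deployment name
    messages=[{"role": "user", "content": "Say hello out loud."}],
    modalities=["text", "audio"],
    audio={"voice": "alloy", "format": "wav"},        # OpenAI-style audio settings
    api_base="https://my-resource.openai.azure.com",  # placeholder
    api_version="2025-01-01-preview",                 # placeholder
    api_key="placeholder-key",
)
```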

View file

@@ -4,50 +4,69 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
Written separately to handle faking streaming for o1 and o3 models.
"""
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.types.utils import ModelResponse
from ...openai.openai import OpenAIChatCompletion
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
def _get_openai_client(
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
def completion(
self,
is_async: bool,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
litellm_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
dynamic_params: Optional[bool] = None,
azure_ad_token: Optional[str] = None,
acompletion: bool = False,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
) -> Optional[
Union[
OpenAI,
AsyncOpenAI,
AzureOpenAI,
AsyncAzureOpenAI,
]
]:
# Override to use Azure-specific client initialization
if not isinstance(client, AzureOpenAI) and not isinstance(
client, AsyncAzureOpenAI
):
client = None
return get_azure_openai_client(
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
client = self.get_azure_openai_client(
litellm_params=litellm_params,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=is_async,
_is_async=acompletion,
)
return super().completion(
model_response=model_response,
timeout=timeout,
optional_params=optional_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
model=model,
messages=messages,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
api_version=api_version,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
acompletion=acompletion,
logger_fn=logger_fn,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
client=client,
organization=organization,
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)

View file

@ -1,13 +1,22 @@
from typing import Callable, Optional, Union
import json
import os
from typing import Any, Callable, Dict, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.common_utils import BaseOpenAILLM
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.main import get_secret_str
azure_ad_cache = DualCache()
class AzureOpenAIError(BaseLLMException):
def __init__(
@@ -29,39 +38,6 @@ class AzureOpenAIError(BaseLLMException):
)
def get_azure_openai_client(
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
@@ -180,3 +156,271 @@ def get_azure_ad_token_from_username_password(
verbose_logger.debug("token_provider %s", token_provider)
return token_provider
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret_str(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token
client = litellm.module_level_client
req_token = client.post(
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
if azure_ad_token_access_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token access_token not returned"
)
if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
def select_azure_base_url_or_endpoint(azure_client_params: dict):
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
class BaseAzureLLM(BaseOpenAILLM):
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
_is_async: bool = False,
model: Optional[str] = None,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
client_initialization_params: dict = locals()
if client is None:
cached_client = self.get_cached_openai_client(
client_initialization_params=client_initialization_params,
client_type="azure",
)
if cached_client:
if isinstance(cached_client, AzureOpenAI) or isinstance(
cached_client, AsyncAzureOpenAI
):
return cached_client
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name=model,
api_version=api_version,
is_async=_is_async,
)
if _is_async is True:
openai_client = AsyncAzureOpenAI(**azure_client_params)
else:
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
openai_client = client
if api_version is not None and isinstance(
openai_client._custom_query, dict
):
# set api_version to version passed by user
openai_client._custom_query.setdefault("api-version", api_version)
# save client in-memory cache
self.set_cached_openai_client(
openai_client=openai_client,
client_initialization_params=client_initialization_params,
client_type="azure",
)
return openai_client
def initialize_azure_sdk_client(
self,
litellm_params: dict,
api_key: Optional[str],
api_base: Optional[str],
model_name: Optional[str],
api_version: Optional[str],
is_async: bool,
) -> dict:
azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")
tenant_id = litellm_params.get("tenant_id")
client_id = litellm_params.get("client_id")
client_secret = litellm_params.get("client_secret")
azure_username = litellm_params.get("azure_username")
azure_password = litellm_params.get("azure_password")
max_retries = litellm_params.get("max_retries")
timeout = litellm_params.get("timeout")
if not api_key and tenant_id and client_id and client_secret:
verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
)
if azure_username and azure_password and client_id:
azure_ad_token_provider = get_azure_ad_token_from_username_password(
azure_username=azure_username,
azure_password=azure_password,
client_id=client_id,
)
if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif (
not api_key
and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
try:
azure_ad_token_provider = get_azure_ad_token_provider()
except ValueError:
verbose_logger.debug("Azure AD Token Provider could not be used.")
if api_version is None:
api_version = os.getenv(
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
)
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider,
}
# init http client + SSL Verification settings
if is_async is True:
azure_client_params["http_client"] = self._get_async_http_client()
else:
azure_client_params["http_client"] = self._get_sync_http_client()
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if timeout is not None:
azure_client_params["timeout"] = timeout
if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
return azure_client_params
def _init_azure_client_for_cloudflare_ai_gateway(
self,
api_base: str,
model: str,
api_version: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable[[], str]],
acompletion: bool,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
else:
client = AzureOpenAI(**azure_client_params) # type: ignore
return client
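
Note: `select_azure_base_url_or_endpoint` (moved here unchanged) decides whether the configured endpoint is handed to the SDK as `azure_endpoint` or as a full `base_url`. A runnable illustration of its behavior, using placeholder URLs and a condensed copy of the helper above:

```python
def select_azure_base_url_or_endpoint(azure_client_params: dict) -> dict:
    # Endpoints that already contain "/openai/deployments" are really full base URLs.
    azure_endpoint = azure_client_params.get("azure_endpoint", None)
    if azure_endpoint is not None and "/openai/deployments" in azure_endpoint:
        azure_client_params["base_url"] = azure_endpoint
        azure_client_params.pop("azure_endpoint")
    return azure_client_params


print(select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com"}
))
# -> {'azure_endpoint': 'https://my-resource.openai.azure.com'}

print(select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com/openai/deployments/my-deployment"}
))
# -> {'base_url': 'https://my-resource.openai.azure.com/openai/deployments/my-deployment'}
```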

View file

@@ -2,30 +2,16 @@ from typing import Any, Callable, Optional
from openai import AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
from ...base import BaseLLM
from ...openai.completion.transformation import OpenAITextCompletionConfig
from ..common_utils import AzureOpenAIError
from ..common_utils import AzureOpenAIError, BaseAzureLLM
openai_text_completion_config = OpenAITextCompletionConfig()
def select_azure_base_url_or_endpoint(azure_client_params: dict):
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
class AzureTextCompletion(BaseLLM):
class AzureTextCompletion(BaseAzureLLM):
def __init__(self) -> None:
super().__init__()
@@ -60,7 +46,6 @@ class AzureTextCompletion(BaseLLM):
headers: Optional[dict] = None,
client=None,
):
super().completion()
try:
if model is None or messages is None:
raise AzureOpenAIError(
@@ -76,27 +61,18 @@ class AzureTextCompletion(BaseLLM):
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
azure_client_params = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
client = self._init_azure_client_for_cloudflare_ai_gateway(
api_key=api_key,
api_version=api_version,
api_base=api_base,
model=model,
client=client,
max_retries=max_retries,
timeout=timeout,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
acompletion=acompletion,
)
data = {"model": None, "prompt": prompt, **optional_params}
else:
@@ -118,6 +94,7 @@ class AzureTextCompletion(BaseLLM):
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
litellm_params=litellm_params,
)
else:
return self.acompletion(
@@ -132,6 +109,7 @@ class AzureTextCompletion(BaseLLM):
client=client,
logging_obj=logging_obj,
max_retries=max_retries,
litellm_params=litellm_params,
)
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
@@ -165,33 +143,21 @@ class AzureTextCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
"azure_ad_token_provider": azure_ad_token_provider,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
litellm_params=litellm_params,
_is_async=False,
model=model,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault(
"api-version", api_version
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
raw_response = azure_client.completions.with_raw_response.create(
**data, timeout=timeout
@@ -240,36 +206,27 @@ class AzureTextCompletion(BaseLLM):
max_retries: int,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
litellm_params: dict = {},
):
response = None
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=True,
client=client,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AsyncAzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
@@ -312,6 +269,7 @@ class AzureTextCompletion(BaseLLM):
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
litellm_params: dict = {},
):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
@@ -319,28 +277,21 @@ class AzureTextCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=False,
client=client,
litellm_params=litellm_params,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
@@ -375,33 +326,24 @@ class AzureTextCompletion(BaseLLM):
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
litellm_params: dict = {},
):
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=True,
client=client,
litellm_params=litellm_params,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
if not isinstance(azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AsyncAzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=data["prompt"],
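
Note: the text-completion handler gets the same treatment: every path now builds its client through `get_azure_openai_client` and type-checks the result. Nothing changes at the litellm call-site; a hedged sketch with placeholder deployment and credentials:

```python
import litellm

response = litellm.text_completion(
    model="azure/my-instruct-deployment",             # placeholder deployment name
    prompt="Write a haiku about retries.",
    api_base="https://my-resource.openai.azure.com",  # placeholder
    api_version="2024-02-01",                         # placeholder
    api_key="placeholder-key",
)
print(response.choices[0].text)
```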

View file

@@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.file_deleted import FileDeleted
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import *
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIFilesAPI(BaseLLM):
class AzureOpenAIFilesAPI(BaseAzureLLM):
"""
AzureOpenAI methods to support for batches
- create_file()
@@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
@@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
organization=None,
client=client,
_is_async=_is_async,
)
@@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None,
api_version=api_version,
client=client,
_is_async=_is_async,
@@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
organization: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
@@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
purpose: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None, # openai param
api_version=api_version,
client=client,
_is_async=_is_async,

View file

@@ -3,11 +3,11 @@ from typing import Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.llms.azure.files.handler import get_azure_openai_client
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
"""
AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
"""
@@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,
@@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None
return get_azure_openai_client(
return self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,

View file

@@ -16,10 +16,23 @@ from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base
from litellm.utils import _add_path_to_api_base, supports_tool_choice
class AzureAIStudioConfig(OpenAIConfig):
def get_supported_openai_params(self, model: str) -> List:
model_supports_tool_choice = True # azure ai supports this by default
if not supports_tool_choice(model=f"azure_ai/{model}"):
model_supports_tool_choice = False
supported_params = super().get_supported_openai_params(model)
if not model_supports_tool_choice:
filtered_supported_params = []
for param in supported_params:
if param != "tool_choice":
filtered_supported_params.append(param)
return filtered_supported_params
return supported_params
def validate_environment(
self,
headers: dict,
@@ -54,6 +67,7 @@ class AzureAIStudioConfig(OpenAIConfig):
api_base: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
@@ -79,12 +93,14 @@ class AzureAIStudioConfig(OpenAIConfig):
original_url = httpx.URL(api_base)
# Extract api_version or use default
api_version = cast(Optional[str], optional_params.get("api_version"))
api_version = cast(Optional[str], litellm_params.get("api_version"))
# Check if 'api-version' is already present
if "api-version" not in original_url.params and api_version:
# Add api_version to optional_params
original_url.params["api-version"] = api_version
# Create a new dictionary with existing params
query_params = dict(original_url.params)
# Add api_version if needed
if "api-version" not in query_params and api_version:
query_params["api-version"] = api_version
# Add the path to the base URL
if "services.ai.azure.com" in api_base:
@@ -96,8 +112,7 @@ class AzureAIStudioConfig(OpenAIConfig):
api_base=api_base, ending_path="/chat/completions"
)
# Convert optional_params to query parameters
query_params = original_url.params
# Use the new query_params dictionary
final_url = httpx.URL(new_url).copy_with(params=query_params)
return str(final_url)
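
Note: the rewritten query handling above copies any existing query params off `api_base`, adds `api-version` only when it is missing, and re-attaches the params after the path is appended. A runnable sketch of that flow; URLs and version are placeholders, and the literal `/chat/completions` suffix stands in for `_add_path_to_api_base`:

```python
import httpx

api_base = "https://my-project.services.ai.azure.com/models"  # placeholder
api_version = "2024-05-01-preview"                            # placeholder

original_url = httpx.URL(api_base)
query_params = dict(original_url.params)           # keep whatever is already there
if "api-version" not in query_params and api_version:
    query_params["api-version"] = api_version      # add only when missing

new_url = f"{api_base}/chat/completions"           # stand-in for _add_path_to_api_base
final_url = httpx.URL(new_url).copy_with(params=query_params)
print(final_url)
# https://my-project.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview
```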

Some files were not shown because too many files have changed in this diff.