Merge branch 'main' into custom_validation_docs
|
@ -49,7 +49,7 @@ jobs:
|
|||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install openai==1.66.1
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
|
@ -71,7 +71,7 @@ jobs:
|
|||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
pip install "pytest-xdist==3.6.1"
|
||||
pip install "websockets==10.4"
|
||||
pip install "websockets==13.1.0"
|
||||
pip uninstall posthog -y
|
||||
- save_cache:
|
||||
paths:
|
||||
|
@ -168,7 +168,7 @@ jobs:
|
|||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install openai==1.66.1
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
|
@ -189,6 +189,7 @@ jobs:
|
|||
pip install "diskcache==5.6.1"
|
||||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
pip install "websockets==13.1.0"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
|
@ -267,7 +268,7 @@ jobs:
|
|||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install openai==1.66.1
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
|
@ -288,6 +289,7 @@ jobs:
|
|||
pip install "diskcache==5.6.1"
|
||||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
pip install "websockets==13.1.0"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
|
@ -511,7 +513,7 @@ jobs:
|
|||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install openai==1.66.1
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
|
@ -678,6 +680,48 @@ jobs:
|
|||
paths:
|
||||
- llm_translation_coverage.xml
|
||||
- llm_translation_coverage
|
||||
llm_responses_api_testing:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "respx==0.21.1"
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/llm_responses_api_testing --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml llm_responses_api_coverage.xml
|
||||
mv .coverage llm_responses_api_coverage
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- llm_responses_api_coverage.xml
|
||||
- llm_responses_api_coverage
|
||||
litellm_mapped_tests:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
|
@ -1234,7 +1278,7 @@ jobs:
|
|||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "openai==1.66.1"
|
||||
- run:
|
||||
name: Install Grype
|
||||
command: |
|
||||
|
@ -1309,13 +1353,13 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
||||
no_output_timeout: 120m
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
e2e_openai_misc_endpoints:
|
||||
e2e_openai_endpoints:
|
||||
machine:
|
||||
image: ubuntu-2204:2023.10.1
|
||||
resource_class: xlarge
|
||||
|
@ -1370,7 +1414,7 @@ jobs:
|
|||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "openai==1.66.1"
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Build Docker image
|
||||
|
@ -1432,7 +1476,7 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -s -vv tests/openai_misc_endpoints_tests --junitxml=test-results/junit.xml --durations=5
|
||||
python -m pytest -s -vv tests/openai_endpoints_tests --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
|
||||
# Store test results
|
||||
|
@ -1492,7 +1536,7 @@ jobs:
|
|||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "openai==1.66.1"
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||
|
@ -1921,7 +1965,7 @@ jobs:
|
|||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install aiohttp
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "openai==1.66.1"
|
||||
pip install "assemblyai==0.37.0"
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
|
@ -2068,7 +2112,7 @@ jobs:
|
|||
python -m venv venv
|
||||
. venv/bin/activate
|
||||
pip install coverage
|
||||
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage litellm_proxy_security_tests_coverage
|
||||
coverage combine llm_translation_coverage llm_responses_api_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage litellm_proxy_security_tests_coverage
|
||||
coverage xml
|
||||
- codecov/upload:
|
||||
file: ./coverage.xml
|
||||
|
@ -2197,7 +2241,7 @@ jobs:
|
|||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install aiohttp
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "openai==1.66.1"
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "pytest==7.3.1"
|
||||
|
@ -2387,7 +2431,7 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- e2e_openai_misc_endpoints:
|
||||
- e2e_openai_endpoints:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
|
@ -2429,6 +2473,12 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- llm_responses_api_testing:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- litellm_mapped_tests:
|
||||
filters:
|
||||
branches:
|
||||
|
@ -2468,6 +2518,7 @@ workflows:
|
|||
- upload-coverage:
|
||||
requires:
|
||||
- llm_translation_testing
|
||||
- llm_responses_api_testing
|
||||
- litellm_mapped_tests
|
||||
- batches_testing
|
||||
- litellm_utils_testing
|
||||
|
@ -2522,10 +2573,11 @@ workflows:
|
|||
requires:
|
||||
- local_testing
|
||||
- build_and_test
|
||||
- e2e_openai_misc_endpoints
|
||||
- e2e_openai_endpoints
|
||||
- load_testing
|
||||
- test_bad_database_url
|
||||
- llm_translation_testing
|
||||
- llm_responses_api_testing
|
||||
- litellm_mapped_tests
|
||||
- batches_testing
|
||||
- litellm_utils_testing
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# used by CI/CD testing
|
||||
openai==1.54.0
|
||||
openai==1.66.1
|
||||
python-dotenv
|
||||
tiktoken
|
||||
importlib_metadata
|
||||
|
|
27
.github/workflows/ghcr_deploy.yml
vendored
|
@ -80,7 +80,6 @@ jobs:
|
|||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
#
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
@ -112,7 +111,11 @@ jobs:
|
|||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||
tags: |
|
||||
${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
|
||||
${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm:main-stable', env.REGISTRY) || '' }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
|
@ -151,8 +154,12 @@ jobs:
|
|||
context: .
|
||||
file: ./docker/Dockerfile.database
|
||||
push: true
|
||||
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
labels: ${{ steps.meta-database.outputs.labels }}
|
||||
tags: |
|
||||
${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
|
||||
${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-database:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-database:main-stable', env.REGISTRY) || '' }}
|
||||
labels: ${{ steps.meta-database.outputs.labels }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
build-and-push-image-non_root:
|
||||
|
@ -190,7 +197,11 @@ jobs:
|
|||
context: .
|
||||
file: ./docker/Dockerfile.non_root
|
||||
push: true
|
||||
tags: ${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
tags: |
|
||||
${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
|
||||
${{ steps.meta-non_root.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-non_root:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-non_root:main-stable', env.REGISTRY) || '' }}
|
||||
labels: ${{ steps.meta-non_root.outputs.labels }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
|
@ -229,7 +240,11 @@ jobs:
|
|||
context: .
|
||||
file: ./litellm-js/spend-logs/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
tags: |
|
||||
${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }},
|
||||
${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-spend_logs:main-{1}', env.REGISTRY, github.event.inputs.tag) || '' }},
|
||||
${{ github.event.inputs.release_type == 'stable' && format('{0}/berriai/litellm-spend_logs:main-stable', env.REGISTRY) || '' }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
build-and-push-helm-chart:
|
||||
|
|
4
.gitignore
vendored
|
@ -79,3 +79,7 @@ litellm/proxy/_experimental/out/model_hub.html
|
|||
litellm/proxy/application.log
|
||||
tests/llm_translation/vertex_test_account.json
|
||||
tests/llm_translation/test_vertex_key.json
|
||||
litellm/proxy/migrations/0_init/migration.sql
|
||||
litellm/proxy/db/migrations/0_init/migration.sql
|
||||
litellm/proxy/db/migrations/*
|
||||
litellm/proxy/migrations/*
|
15
Makefile
|
@ -1,7 +1,7 @@
|
|||
# LiteLLM Makefile
|
||||
# Simple Makefile for running tests and basic development tasks
|
||||
|
||||
.PHONY: help test test-unit test-integration
|
||||
.PHONY: help test test-unit test-integration lint format
|
||||
|
||||
# Default target
|
||||
help:
|
||||
|
@ -9,6 +9,14 @@ help:
|
|||
@echo " make test - Run all tests"
|
||||
@echo " make test-unit - Run unit tests"
|
||||
@echo " make test-integration - Run integration tests"
|
||||
@echo " make test-unit-helm - Run helm unit tests"
|
||||
|
||||
install-dev:
|
||||
poetry install --with dev
|
||||
|
||||
lint: install-dev
|
||||
poetry run pip install types-requests types-setuptools types-redis types-PyYAML
|
||||
cd litellm && poetry run mypy . --ignore-missing-imports
|
||||
|
||||
# Testing
|
||||
test:
|
||||
|
@ -18,4 +26,7 @@ test-unit:
|
|||
poetry run pytest tests/litellm/
|
||||
|
||||
test-integration:
|
||||
poetry run pytest tests/ -k "not litellm"
|
||||
poetry run pytest tests/ -k "not litellm"
|
||||
|
||||
test-unit-helm:
|
||||
helm unittest -f 'tests/*.yaml' deploy/charts/litellm-helm
|
|
@ -18,7 +18,7 @@ type: application
|
|||
# This is the chart version. This version number should be incremented each time you make changes
|
||||
# to the chart and its templates, including the app version.
|
||||
# Versions are expected to follow Semantic Versioning (https://semver.org/)
|
||||
version: 0.4.1
|
||||
version: 0.4.2
|
||||
|
||||
# This is the version number of the application being deployed. This version number should be
|
||||
# incremented each time you make changes to the application. Versions are not expected to
|
||||
|
|
|
@ -22,6 +22,8 @@ If `db.useStackgresOperator` is used (not yet implemented):
|
|||
| Name | Description | Value |
|
||||
| ---------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
||||
| `replicaCount` | The number of LiteLLM Proxy pods to be deployed | `1` |
|
||||
| `masterkeySecretName` | The name of the Kubernetes Secret that contains the Master API Key for LiteLLM. If not specified, use the generated secret name. | N/A |
|
||||
| `masterkeySecretKey` | The key within the Kubernetes Secret that contains the Master API Key for LiteLLM. If not specified, use `masterkey` as the key. | N/A |
|
||||
| `masterkey` | The Master API Key for LiteLLM. If not specified, a random key is generated. | N/A |
|
||||
| `environmentSecrets` | An optional array of Secret object names. The keys and values in these secrets will be presented to the LiteLLM proxy pod as environment variables. See below for an example Secret object. | `[]` |
|
||||
| `environmentConfigMaps` | An optional array of ConfigMap object names. The keys and values in these configmaps will be presented to the LiteLLM proxy pod as environment variables. See below for an example Secret object. | `[]` |
|
||||
|
|
|
@ -78,8 +78,8 @@ spec:
|
|||
- name: PROXY_MASTER_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ include "litellm.fullname" . }}-masterkey
|
||||
key: masterkey
|
||||
name: {{ .Values.masterkeySecretName | default (printf "%s-masterkey" (include "litellm.fullname" .)) }}
|
||||
key: {{ .Values.masterkeySecretKey | default "masterkey" }}
|
||||
{{- if .Values.redis.enabled }}
|
||||
- name: REDIS_HOST
|
||||
value: {{ include "litellm.redis.serviceName" . }}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{{- if not .Values.masterkeySecretName }}
|
||||
{{ $masterkey := (.Values.masterkey | default (randAlphaNum 17)) }}
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
|
@ -5,4 +6,5 @@ metadata:
|
|||
name: {{ include "litellm.fullname" . }}-masterkey
|
||||
data:
|
||||
masterkey: {{ $masterkey | b64enc }}
|
||||
type: Opaque
|
||||
type: Opaque
|
||||
{{- end }}
|
||||
|
|
|
@ -52,3 +52,31 @@ tests:
|
|||
- equal:
|
||||
path: spec.template.spec.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms[0].matchExpressions[0].values[0]
|
||||
value: antarctica-east1
|
||||
- it: should work without masterkeySecretName or masterkeySecretKey
|
||||
template: deployment.yaml
|
||||
set:
|
||||
masterkeySecretName: ""
|
||||
masterkeySecretKey: ""
|
||||
asserts:
|
||||
- contains:
|
||||
path: spec.template.spec.containers[0].env
|
||||
content:
|
||||
name: PROXY_MASTER_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: RELEASE-NAME-litellm-masterkey
|
||||
key: masterkey
|
||||
- it: should work with masterkeySecretName and masterkeySecretKey
|
||||
template: deployment.yaml
|
||||
set:
|
||||
masterkeySecretName: my-secret
|
||||
masterkeySecretKey: my-key
|
||||
asserts:
|
||||
- contains:
|
||||
path: spec.template.spec.containers[0].env
|
||||
content:
|
||||
name: PROXY_MASTER_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: my-secret
|
||||
key: my-key
|
||||
|
|
18
deploy/charts/litellm-helm/tests/masterkey-secret_tests.yaml
Normal file
|
@ -0,0 +1,18 @@
|
|||
suite: test masterkey secret
|
||||
templates:
|
||||
- secret-masterkey.yaml
|
||||
tests:
|
||||
- it: should create a secret if masterkeySecretName is not set
|
||||
template: secret-masterkey.yaml
|
||||
set:
|
||||
masterkeySecretName: ""
|
||||
asserts:
|
||||
- isKind:
|
||||
of: Secret
|
||||
- it: should not create a secret if masterkeySecretName is set
|
||||
template: secret-masterkey.yaml
|
||||
set:
|
||||
masterkeySecretName: my-secret
|
||||
asserts:
|
||||
- hasDocuments:
|
||||
count: 0
|
|
@ -75,6 +75,12 @@ ingress:
|
|||
|
||||
# masterkey: changeit
|
||||
|
||||
# if set, use this secret for the master key; otherwise, autogenerate a new one
|
||||
masterkeySecretName: ""
|
||||
|
||||
# if set, use this secret key for the master key; otherwise, use the default key
|
||||
masterkeySecretKey: ""
|
||||
|
||||
# The elements within proxy_config are rendered as config.yaml for the proxy
|
||||
# Examples: https://github.com/BerriAI/litellm/tree/main/litellm/proxy/example_config_yaml
|
||||
# Reference: https://docs.litellm.ai/docs/proxy/configs
|
||||
|
|
|
@ -20,10 +20,18 @@ services:
|
|||
STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
|
||||
env_file:
|
||||
- .env # Load local .env file
|
||||
depends_on:
|
||||
- db # Indicates that this service depends on the 'db' service, ensuring 'db' starts first
|
||||
healthcheck: # Defines the health check configuration for the container
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:4000/health/liveliness || exit 1" ] # Command to execute for health check
|
||||
interval: 30s # Perform health check every 30 seconds
|
||||
timeout: 10s # Health check command times out after 10 seconds
|
||||
retries: 3 # Retry up to 3 times if health check fails
|
||||
start_period: 40s # Wait 40 seconds after container start before beginning health checks
|
||||
|
||||
|
||||
db:
|
||||
image: postgres
|
||||
image: postgres:16
|
||||
restart: always
|
||||
environment:
|
||||
POSTGRES_DB: litellm
|
||||
|
@ -31,6 +39,8 @@ services:
|
|||
POSTGRES_PASSWORD: dbpassword9090
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data # Persists Postgres data across container restarts
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -d litellm -U llmproxy"]
|
||||
interval: 1s
|
||||
|
@ -53,6 +63,8 @@ services:
|
|||
volumes:
|
||||
prometheus_data:
|
||||
driver: local
|
||||
postgres_data:
|
||||
name: litellm_postgres_data # Named volume for Postgres data persistence
|
||||
|
||||
|
||||
# ...rest of your docker-compose config if any
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [BETA] `/v1/messages`
|
||||
# /v1/messages [BETA]
|
||||
|
||||
LiteLLM provides a BETA endpoint in the spec of Anthropic's `/v1/messages` endpoint.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Assistants API
|
||||
# /assistants
|
||||
|
||||
Covers Threads, Messages, Assistants.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [BETA] Batches API
|
||||
# /batches
|
||||
|
||||
Covers Batches, Files
|
||||
|
||||
|
|
|
@ -3,7 +3,13 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
# Prompt Caching
|
||||
|
||||
For OpenAI + Anthropic + Deepseek, LiteLLM follows the OpenAI prompt caching usage object format:
|
||||
Supported Providers:
|
||||
- OpenAI (`openai/`)
|
||||
- Anthropic API (`anthropic/`)
|
||||
- Bedrock (`bedrock/`, `bedrock/invoke/`, `bedrock/converse`) ([All models bedrock supports prompt caching on](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html))
|
||||
- Deepseek API (`deepseek/`)
|
||||
|
||||
For the supported providers, LiteLLM follows the OpenAI prompt caching usage object format:
|
||||
|
||||
```bash
|
||||
"usage": {
|
||||
|
@ -499,4 +505,4 @@ curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
This checks our maintained [model info/cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
This checks our maintained [model info/cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Embeddings
|
||||
# /embeddings
|
||||
|
||||
## Quick Start
|
||||
```python
|
||||
|
|
|
@ -34,9 +34,9 @@ You can use our cloud product where we setup a dedicated instance for you.
|
|||
|
||||
Professional Support can assist with LLM/Provider integrations, deployment, upgrade management, and LLM Provider troubleshooting. We can’t solve your own infrastructure-related issues but we will guide you to fix them.
|
||||
|
||||
- 1 hour for Sev0 issues
|
||||
- 6 hours for Sev1
|
||||
- 24h for Sev2-Sev3 between 7am – 7pm PT (Monday through Saturday)
|
||||
- 1 hour for Sev0 issues - 100% production traffic is failing
|
||||
- 6 hours for Sev1 - <100% production traffic is failing
|
||||
- 24h for Sev2-Sev3 between 7am – 7pm PT (Monday through Saturday) - setup issues e.g. Redis working on our end, but not on your infrastructure.
|
||||
- 72h SLA for patching vulnerabilities in the software.
|
||||
|
||||
**We can offer custom SLAs** based on your needs and the severity of the issue
|
||||
|
|
|
@ -8,7 +8,7 @@ Here are the core requirements for any PR submitted to LiteLLM
|
|||
- [ ] Add testing, **Adding at least 1 test is a hard requirement** - [see details](#2-adding-testing-to-your-pr)
|
||||
- [ ] Ensure your PR passes the following tests:
|
||||
- [ ] [Unit Tests](#3-running-unit-tests)
|
||||
- [ ] Formatting / Linting Tests
|
||||
- [ ] [Formatting / Linting Tests](#35-running-linting-tests)
|
||||
- [ ] Keep scope as isolated as possible. As a general rule, your changes should address 1 specific problem at a time
|
||||
|
||||
|
||||
|
@ -56,6 +56,16 @@ run the following command on the root of the litellm directory
|
|||
make test-unit
|
||||
```
|
||||
|
||||
## 3.5 Running Linting Tests
|
||||
|
||||
run the following command on the root of the litellm directory
|
||||
|
||||
```shell
|
||||
make lint
|
||||
```
|
||||
|
||||
LiteLLM uses mypy for linting. On ci/cd we also run `black` for formatting.
|
||||
|
||||
## 4. Submit a PR with your changes!
|
||||
|
||||
- push your fork to your GitHub repo
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import TabItem from '@theme/TabItem';
|
||||
import Tabs from '@theme/Tabs';
|
||||
|
||||
# Files API
|
||||
# /files
|
||||
|
||||
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [Beta] Fine-tuning API
|
||||
# /fine_tuning
|
||||
|
||||
|
||||
:::info
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Moderation
|
||||
# /moderations
|
||||
|
||||
|
||||
### Usage
|
||||
|
|
|
@ -79,6 +79,7 @@ aws_session_name: Optional[str],
|
|||
aws_profile_name: Optional[str],
|
||||
aws_role_name: Optional[str],
|
||||
aws_web_identity_token: Optional[str],
|
||||
aws_bedrock_runtime_endpoint: Optional[str],
|
||||
```
|
||||
|
||||
### 2. Start the proxy
|
||||
|
@ -1262,6 +1263,473 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
|||
</Tabs>
|
||||
|
||||
|
||||
## Bedrock Imported Models (Deepseek, Deepseek R1)
|
||||
|
||||
### Deepseek R1
|
||||
|
||||
This is a separate route, as the chat template is different.
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
|
||||
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/deepseek_r1/{your-model-arn}
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: DeepSeek-R1-Distill-Llama-70B
|
||||
litellm_params:
|
||||
model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
|
||||
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Deepseek (not R1)
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Provider Route | `bedrock/llama/{model_arn}` |
|
||||
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
|
||||
|
||||
|
||||
|
||||
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: DeepSeek-R1-Distill-Llama-70B
|
||||
litellm_params:
|
||||
model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
|
||||
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
## Provisioned throughput models
|
||||
To use provisioned throughput Bedrock models pass
|
||||
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
|
||||
- `model_id=provisioned-model-arn`
|
||||
|
||||
Completion
|
||||
```python
|
||||
import litellm
|
||||
response = litellm.completion(
|
||||
model="bedrock/anthropic.claude-instant-v1",
|
||||
model_id="provisioned-model-arn",
|
||||
messages=[{"content": "Hello, how are you?", "role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
Embedding
|
||||
```python
|
||||
import litellm
|
||||
response = litellm.embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
model_id="provisioned-model-arn",
|
||||
input=["hi"],
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
## Supported AWS Bedrock Models
|
||||
Here's an example of using a bedrock model with LiteLLM. For a complete list, refer to the [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
|
||||
| Model Name | Command |
|
||||
|----------------------------|------------------------------------------------------------------|
|
||||
| Anthropic Claude-V3.5 Sonnet | `completion(model='bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 sonnet | `completion(model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 Haiku | `completion(model='bedrock/anthropic.claude-3-haiku-20240307-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 Opus | `completion(model='bedrock/anthropic.claude-3-opus-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Amazon Titan Express | `completion(model='bedrock/amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Cohere Command | `completion(model='bedrock/cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 Jamba-Instruct | `completion(model='bedrock/ai21.jamba-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 Chat 13b | `completion(model='bedrock/meta.llama2-13b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 Chat 70b | `completion(model='bedrock/meta.llama2-70b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Mistral 7B Instruct | `completion(model='bedrock/mistral.mistral-7b-instruct-v0:2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Mixtral 8x7B Instruct | `completion(model='bedrock/mistral.mixtral-8x7b-instruct-v0:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
|
||||
## Bedrock Embedding
|
||||
|
||||
### API keys
|
||||
This can be set as env variables or passed as **params to litellm.embedding()**
|
||||
```python
|
||||
import os
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = "" # Access key
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = "" # Secret access key
|
||||
os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
|
||||
```
|
||||
|
||||
### Usage
|
||||
```python
|
||||
from litellm import embedding
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Supported AWS Bedrock Embedding Models
|
||||
|
||||
| Model Name | Usage | Supported Additional OpenAI params |
|
||||
|----------------------|---------------------------------------------|-----|
|
||||
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py#L59) |
|
||||
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py#L53)
|
||||
| Titan Multimodal Embeddings | `embedding(model="bedrock/amazon.titan-embed-image-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py#L28) |
|
||||
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
|
||||
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
|
||||
|
||||
### Advanced - [Drop Unsupported Params](https://docs.litellm.ai/docs/completion/drop_params#openai-proxy-usage)
|
||||
|
||||
### Advanced - [Pass model/provider-specific Params](https://docs.litellm.ai/docs/completion/provider_specific_params#proxy-usage)
|
||||
|
||||
## Image Generation
|
||||
Use this for stable diffusion, and amazon nova canvas on bedrock
|
||||
|
||||
|
||||
### Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import image_generation
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = image_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="bedrock/stability.stable-diffusion-xl-v0",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
**Set optional params**
|
||||
```python
|
||||
import os
|
||||
from litellm import image_generation
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = image_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="bedrock/stability.stable-diffusion-xl-v0",
|
||||
### OPENAI-COMPATIBLE ###
|
||||
size="128x512", # width=128, height=512
|
||||
### PROVIDER-SPECIFIC ### see `AmazonStabilityConfig` in bedrock.py for all params
|
||||
seed=30
|
||||
)
|
||||
print(f"response: {response}")
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: amazon.nova-canvas-v1:0
|
||||
litellm_params:
|
||||
model: bedrock/amazon.nova-canvas-v1:0
|
||||
aws_region_name: "us-east-1"
|
||||
aws_secret_access_key: my-key # OPTIONAL - all boto3 auth params supported
|
||||
aws_secret_access_id: my-id # OPTIONAL - all boto3 auth params supported
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/images/generations' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer $LITELLM_VIRTUAL_KEY' \
|
||||
-d '{
|
||||
"model": "amazon.nova-canvas-v1:0",
|
||||
"prompt": "A cute baby sea otter"
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Supported AWS Bedrock Image Generation Models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
|
||||
|
||||
|
||||
## Rerank API
|
||||
|
||||
Use Bedrock's Rerank API in the Cohere `/rerank` format.
|
||||
|
||||
Supported Cohere Rerank Params
|
||||
- `model` - the foundation model ARN
|
||||
- `query` - the query to rerank against
|
||||
- `documents` - the list of documents to rerank
|
||||
- `top_n` - the number of results to return
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="SDK" value="sdk">
|
||||
|
||||
```python
|
||||
from litellm import rerank
|
||||
import os
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = rerank(
|
||||
model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", # provide the model ARN - get this here https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=2,
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="PROXY" value="proxy">
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: bedrock-rerank
|
||||
litellm_params:
|
||||
model: bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0
|
||||
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
```
|
||||
|
||||
2. Start proxy server
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/rerank \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "bedrock-rerank",
|
||||
"query": "What is the capital of the United States?",
|
||||
"documents": [
|
||||
"Carson City is the capital city of the American state of Nevada.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||
"Washington, D.C. is the capital of the United States.",
|
||||
"Capital punishment has existed in the United States since before it was a country."
|
||||
],
|
||||
"top_n": 3
|
||||
|
||||
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Bedrock Application Inference Profile
|
||||
|
||||
Use Bedrock Application Inference Profile to track costs for projects on AWS.
|
||||
|
||||
You can either pass it in the model name - `model="bedrock/arn:...` or as a separate `model_id="arn:..` param.
|
||||
|
||||
### Set via `model_id`
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="SDK" value="sdk">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = completion(
|
||||
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
model_id="arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0",
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="PROXY" value="proxy">
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-claude-3-5-sonnet
|
||||
litellm_params:
|
||||
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
|
||||
# You have to set the ARN application inference profile in the model_id parameter
|
||||
model_id: arn:aws:bedrock:eu-central-1:000000000000:application-inference-profile/a0a0a0a0a0a0
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer $LITELLM_API_KEY' \
|
||||
-d '{
|
||||
"model": "anthropic-claude-3-5-sonnet",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "List 5 important events in the XIX century"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Boto3 - Authentication
|
||||
|
||||
### Passing credentials as parameters - Completion()
|
||||
|
@ -1497,7 +1965,7 @@ response = completion(
|
|||
aws_bedrock_client=bedrock,
|
||||
)
|
||||
```
|
||||
## Calling via Internal Proxy
|
||||
## Calling via Internal Proxy (not bedrock url compatible)
|
||||
|
||||
Use the `bedrock/converse_like/model` endpoint to call bedrock converse model via your internal proxy.
|
||||
|
||||
|
@ -1563,359 +2031,3 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
|||
```bash
|
||||
https://some-api-url/models
|
||||
```
|
||||
|
||||
## Bedrock Imported Models (Deepseek, Deepseek R1)
|
||||
|
||||
### Deepseek R1
|
||||
|
||||
This is a separate route, as the chat template is different.
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
|
||||
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/deepseek_r1/{your-model-arn}
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: DeepSeek-R1-Distill-Llama-70B
|
||||
litellm_params:
|
||||
model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
|
||||
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Deepseek (not R1)
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Provider Route | `bedrock/llama/{model_arn}` |
|
||||
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
|
||||
|
||||
|
||||
|
||||
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: DeepSeek-R1-Distill-Llama-70B
|
||||
litellm_params:
|
||||
model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
|
||||
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
## Provisioned throughput models
|
||||
To use provisioned throughput Bedrock models pass
|
||||
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
|
||||
- `model_id=provisioned-model-arn`
|
||||
|
||||
Completion
|
||||
```python
|
||||
import litellm
|
||||
response = litellm.completion(
|
||||
model="bedrock/anthropic.claude-instant-v1",
|
||||
model_id="provisioned-model-arn",
|
||||
messages=[{"content": "Hello, how are you?", "role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
Embedding
|
||||
```python
|
||||
import litellm
|
||||
response = litellm.embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
model_id="provisioned-model-arn",
|
||||
input=["hi"],
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
## Supported AWS Bedrock Models
|
||||
Here's an example of using a bedrock model with LiteLLM. For a complete list, refer to the [model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
|
||||
| Model Name | Command |
|
||||
|----------------------------|------------------------------------------------------------------|
|
||||
| Anthropic Claude-V3.5 Sonnet | `completion(model='bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 sonnet | `completion(model='bedrock/anthropic.claude-3-sonnet-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 Haiku | `completion(model='bedrock/anthropic.claude-3-haiku-20240307-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V3 Opus | `completion(model='bedrock/anthropic.claude-3-opus-20240229-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
|
||||
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Amazon Titan Express | `completion(model='bedrock/amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Cohere Command | `completion(model='bedrock/cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| AI21 Jamba-Instruct | `completion(model='bedrock/ai21.jamba-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 Chat 13b | `completion(model='bedrock/meta.llama2-13b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Meta Llama 2 Chat 70b | `completion(model='bedrock/meta.llama2-70b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Mistral 7B Instruct | `completion(model='bedrock/mistral.mistral-7b-instruct-v0:2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
| Mixtral 8x7B Instruct | `completion(model='bedrock/mistral.mixtral-8x7b-instruct-v0:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
|
||||
|
||||
## Bedrock Embedding
|
||||
|
||||
### API keys
|
||||
This can be set as env variables or passed as **params to litellm.embedding()**
|
||||
```python
|
||||
import os
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = "" # Access key
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = "" # Secret access key
|
||||
os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
|
||||
```
|
||||
|
||||
### Usage
|
||||
```python
|
||||
from litellm import embedding
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Supported AWS Bedrock Embedding Models
|
||||
|
||||
| Model Name | Usage | Supported Additional OpenAI params |
|
||||
|----------------------|---------------------------------------------|-----|
|
||||
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py#L59) |
|
||||
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py#L53)
|
||||
| Titan Multimodal Embeddings | `embedding(model="bedrock/amazon.titan-embed-image-v1", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py#L28) |
|
||||
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
|
||||
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | [here](https://github.com/BerriAI/litellm/blob/f5905e100068e7a4d61441d7453d7cf5609c2121/litellm/llms/bedrock/embed/cohere_transformation.py#L18)
|
||||
|
||||
### Advanced - [Drop Unsupported Params](https://docs.litellm.ai/docs/completion/drop_params#openai-proxy-usage)
|
||||
|
||||
### Advanced - [Pass model/provider-specific Params](https://docs.litellm.ai/docs/completion/provider_specific_params#proxy-usage)
|
||||
|
||||
## Image Generation
|
||||
Use this for stable diffusion on bedrock
|
||||
|
||||
|
||||
### Usage
|
||||
```python
|
||||
import os
|
||||
from litellm import image_generation
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = image_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="bedrock/stability.stable-diffusion-xl-v0",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
**Set optional params**
|
||||
```python
|
||||
import os
|
||||
from litellm import image_generation
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = image_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="bedrock/stability.stable-diffusion-xl-v0",
|
||||
### OPENAI-COMPATIBLE ###
|
||||
size="128x512", # width=128, height=512
|
||||
### PROVIDER-SPECIFIC ### see `AmazonStabilityConfig` in bedrock.py for all params
|
||||
seed=30
|
||||
)
|
||||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
## Supported AWS Bedrock Image Generation Models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
|
||||
|
||||
|
||||
## Rerank API
|
||||
|
||||
Use Bedrock's Rerank API in the Cohere `/rerank` format.
|
||||
|
||||
Supported Cohere Rerank Params
|
||||
- `model` - the foundation model ARN
|
||||
- `query` - the query to rerank against
|
||||
- `documents` - the list of documents to rerank
|
||||
- `top_n` - the number of results to return
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="SDK" value="sdk">
|
||||
|
||||
```python
|
||||
from litellm import rerank
|
||||
import os
|
||||
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
||||
os.environ["AWS_REGION_NAME"] = ""
|
||||
|
||||
response = rerank(
|
||||
model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", # provide the model ARN - get this here https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=2,
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="PROXY" value="proxy">
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: bedrock-rerank
|
||||
litellm_params:
|
||||
model: bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0
|
||||
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
```
|
||||
|
||||
2. Start proxy server
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/rerank \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "bedrock-rerank",
|
||||
"query": "What is the capital of the United States?",
|
||||
"documents": [
|
||||
"Carson City is the capital city of the American state of Nevada.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||
"Washington, D.C. is the capital of the United States.",
|
||||
"Capital punishment has existed in the United States since before it was a country."
|
||||
],
|
||||
"top_n": 3
|
||||
|
||||
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
|
|||
# litellm proxy call
|
||||
response = completion(
|
||||
model="litellm_proxy/your-model-name",
|
||||
messages,
|
||||
messages=messages,
|
||||
api_base = "your-litellm-proxy-url",
|
||||
api_key = "your-litellm-proxy-api-key"
|
||||
)
|
||||
|
@ -76,7 +76,7 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
|
|||
# openai call
|
||||
response = completion(
|
||||
model="litellm_proxy/your-model-name",
|
||||
messages,
|
||||
messages=messages,
|
||||
api_base = "your-litellm-proxy-url",
|
||||
stream=True
|
||||
)
|
||||
|
|
90
docs/my-website/docs/providers/snowflake.md
Normal file
|
@ -0,0 +1,90 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
||||
# Snowflake
|
||||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | The Snowflake Cortex LLM REST API lets you access the COMPLETE function via HTTP POST requests|
|
||||
| Provider Route on LiteLLM | `snowflake/` |
|
||||
| Link to Provider Doc | [Snowflake ↗](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api) |
|
||||
| Base URL | [https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete/](https://{account-id}.snowflakecomputing.com/api/v2/cortex/inference:complete) |
|
||||
| Supported OpenAI Endpoints | `/chat/completions`, `/completions` |
|
||||
|
||||
|
||||
|
||||
Currently, Snowflake's REST API does not have an endpoint for `snowflake-arctic-embed` embedding models. If you want to use these embedding models with Litellm, you can call them through our Hugging Face provider.
|
||||
|
||||
Find the Arctic Embed models [here](https://huggingface.co/collections/Snowflake/arctic-embed-661fd57d50fab5fc314e4c18) on Hugging Face.
|
||||
|
||||
## Supported OpenAI Parameters
|
||||
```
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"response_format"
|
||||
```
|
||||
|
||||
## API KEYS
|
||||
|
||||
Snowflake does have API keys. Instead, you access the Snowflake API with your JWT token and account identifier.
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
|
||||
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
|
||||
```
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
## set ENV variables
|
||||
os.environ["SNOWFLAKE_JWT"] = "YOUR JWT"
|
||||
os.environ["SNOWFLAKE_ACCOUNT_ID"] = "YOUR ACCOUNT IDENTIFIER"
|
||||
|
||||
# Snowflake call
|
||||
response = completion(
|
||||
model="snowflake/mistral-7b",
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Usage with LiteLLM Proxy
|
||||
|
||||
#### 1. Required env variables
|
||||
```bash
|
||||
export SNOWFLAKE_JWT=""
|
||||
export SNOWFLAKE_ACCOUNT_ID = ""
|
||||
```
|
||||
|
||||
#### 2. Start the proxy~
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: mistral-7b
|
||||
litellm_params:
|
||||
model: snowflake/mistral-7b
|
||||
api_key: YOUR_API_KEY
|
||||
api_base: https://YOUR-ACCOUNT-ID.snowflakecomputing.com/api/v2/cortex/inference:complete
|
||||
|
||||
```
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
#### 3. Test it
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "snowflake/mistral-7b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
'
|
||||
```
|
|
@ -10,17 +10,13 @@ Role-based access control (RBAC) is based on Organizations, Teams and Internal U
|
|||
|
||||
## Roles
|
||||
|
||||
**Admin Roles**
|
||||
- `proxy_admin`: admin over the platform
|
||||
- `proxy_admin_viewer`: can login, view all keys, view all spend. **Cannot** create keys/delete keys/add new users
|
||||
|
||||
**Organization Roles**
|
||||
- `org_admin`: admin over the organization. Can create teams and users within their organization
|
||||
|
||||
**Internal User Roles**
|
||||
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
|
||||
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
|
||||
|
||||
| Role Type | Role Name | Permissions |
|
||||
|-----------|-----------|-------------|
|
||||
| **Admin** | `proxy_admin` | Admin over the platform |
|
||||
| | `proxy_admin_viewer` | Can login, view all keys, view all spend. **Cannot** create keys/delete keys/add new users |
|
||||
| **Organization** | `org_admin` | Admin over the organization. Can create teams and users within their organization |
|
||||
| **Internal User** | `internal_user` | Can login, view/create/delete their own keys, view their spend. **Cannot** add new users |
|
||||
| | `internal_user_viewer` | Can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users |
|
||||
|
||||
## Onboarding Organizations
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ general_settings:
|
|||
| use_x_forwarded_for | str | If true, uses the X-Forwarded-For header to get the client IP address |
|
||||
| service_account_settings | List[Dict[str, Any]] | Set `service_account_settings` if you want to create settings that only apply to service account keys (Doc on service accounts)[./service_accounts.md] |
|
||||
| image_generation_model | str | The default model to use for image generation - ignores model set in request |
|
||||
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
|
||||
| store_model_in_db | boolean | If true, enables storing model + credential information in the DB. |
|
||||
| store_prompts_in_spend_logs | boolean | If true, allows prompts and responses to be stored in the spend logs table. |
|
||||
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
|
||||
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
|
||||
|
@ -499,9 +499,11 @@ router_settings:
|
|||
| SMTP_USERNAME | Username for SMTP authentication (do not set if SMTP does not require auth)
|
||||
| SPEND_LOGS_URL | URL for retrieving spend logs
|
||||
| SSL_CERTIFICATE | Path to the SSL certificate file
|
||||
| SSL_SECURITY_LEVEL | [BETA] Security level for SSL/TLS connections. E.g. `DEFAULT@SECLEVEL=1`
|
||||
| SSL_VERIFY | Flag to enable or disable SSL certificate verification
|
||||
| SUPABASE_KEY | API key for Supabase service
|
||||
| SUPABASE_URL | Base URL for Supabase instance
|
||||
| STORE_MODEL_IN_DB | If true, enables storing model + credential information in the DB.
|
||||
| TEST_EMAIL_ADDRESS | Email address used for testing purposes
|
||||
| UI_LOGO_PATH | Path to the logo image used in the UI
|
||||
| UI_PASSWORD | Password for accessing the UI
|
||||
|
@ -513,4 +515,3 @@ router_settings:
|
|||
| UPSTREAM_LANGFUSE_SECRET_KEY | Secret key for upstream Langfuse authentication
|
||||
| USE_AWS_KMS | Flag to enable AWS Key Management Service for encryption
|
||||
| WEBHOOK_URL | URL for receiving webhooks from external services
|
||||
|
||||
|
|
|
@ -448,6 +448,34 @@ model_list:
|
|||
|
||||
s/o to [@David Manouchehri](https://www.linkedin.com/in/davidmanouchehri/) for helping with this.
|
||||
|
||||
### Centralized Credential Management
|
||||
|
||||
Define credentials once and reuse them across multiple models. This helps with:
|
||||
- Secret rotation
|
||||
- Reducing config duplication
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: azure/gpt-4o
|
||||
litellm_credential_name: default_azure_credential # Reference credential below
|
||||
|
||||
credential_list:
|
||||
- credential_name: default_azure_credential
|
||||
credential_values:
|
||||
api_key: os.environ/AZURE_API_KEY # Load from environment
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_version: "2023-05-15"
|
||||
credential_info:
|
||||
description: "Production credentials for EU region"
|
||||
```
|
||||
|
||||
#### Key Parameters
|
||||
- `credential_name`: Unique identifier for the credential set
|
||||
- `credential_values`: Key-value pairs of credentials/secrets (supports `os.environ/` syntax)
|
||||
- `credential_info`: Key-value pairs of user provided credentials information. No key-value pairs are required, but the dictionary must exist.
|
||||
|
||||
### Load API Keys from Secret Managers (Azure Vault, etc)
|
||||
|
||||
[**Using Secret Managers with LiteLLM Proxy**](../secret)
|
||||
|
@ -641,4 +669,4 @@ docker run --name litellm-proxy \
|
|||
ghcr.io/berriai/litellm-database:main-latest
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</Tabs>
|
||||
|
|
194
docs/my-website/docs/proxy/custom_prompt_management.md
Normal file
|
@ -0,0 +1,194 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Custom Prompt Management
|
||||
|
||||
Connect LiteLLM to your prompt management system with custom hooks.
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
<Image
|
||||
img={require('../../img/custom_prompt_management.png')}
|
||||
style={{width: '100%', display: 'block', margin: '2rem auto'}}
|
||||
/>
|
||||
|
||||
|
||||
## How it works
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Create Your Custom Prompt Manager
|
||||
|
||||
Create a class that inherits from `CustomPromptManagement` to handle prompt retrieval and formatting:
|
||||
|
||||
**Example Implementation**
|
||||
|
||||
Create a new file called `custom_prompt.py` and add this code. The key method here is `get_chat_completion_prompt` you can implement custom logic to retrieve and format prompts based on the `prompt_id` and `prompt_variables`.
|
||||
|
||||
```python
|
||||
from typing import List, Tuple, Optional
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
class MyCustomPromptManagement(CustomPromptManagement):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Retrieve and format prompts based on prompt_id.
|
||||
|
||||
Returns:
|
||||
- model: The model to use
|
||||
- messages: The formatted messages
|
||||
- non_default_params: Optional parameters like temperature
|
||||
"""
|
||||
# Example matching the diagram: Add system message for prompt_id "1234"
|
||||
if prompt_id == "1234":
|
||||
# Prepend system message while preserving existing messages
|
||||
new_messages = [
|
||||
{"role": "system", "content": "Be a good Bot!"},
|
||||
] + messages
|
||||
return model, new_messages, non_default_params
|
||||
|
||||
# Default: Return original messages if no prompt_id match
|
||||
return model, messages, non_default_params
|
||||
|
||||
prompt_management = MyCustomPromptManagement()
|
||||
```
|
||||
|
||||
### 2. Configure Your Prompt Manager in LiteLLM `config.yaml`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/gpt-4
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
callbacks: custom_prompt.prompt_management # sets litellm.callbacks = [prompt_management]
|
||||
```
|
||||
|
||||
### 3. Start LiteLLM Gateway
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="docker" label="Docker Run">
|
||||
|
||||
Mount your `custom_logger.py` on the LiteLLM Docker container.
|
||||
|
||||
```shell
|
||||
docker run -d \
|
||||
-p 4000:4000 \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
--name my-app \
|
||||
-v $(pwd)/my_config.yaml:/app/config.yaml \
|
||||
-v $(pwd)/custom_logger.py:/app/custom_logger.py \
|
||||
my-app:latest \
|
||||
--config /app/config.yaml \
|
||||
--port 4000 \
|
||||
--detailed_debug \
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="py" label="litellm pip">
|
||||
|
||||
```shell
|
||||
litellm --config config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### 4. Test Your Custom Prompt Manager
|
||||
|
||||
When you pass `prompt_id="1234"`, the custom prompt manager will add a system message "Be a good Bot!" to your conversation:
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="sk-1234",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemini-1.5-pro",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
prompt_id="1234"
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
chat = ChatOpenAI(
|
||||
model="gpt-4",
|
||||
openai_api_key="sk-1234",
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
extra_body={
|
||||
"prompt_id": "1234"
|
||||
}
|
||||
)
|
||||
|
||||
messages = []
|
||||
response = chat(messages)
|
||||
|
||||
print(response.content)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl -X POST http://0.0.0.0:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "gemini-1.5-pro",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"prompt_id": "1234"
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
The request will be transformed from:
|
||||
```json
|
||||
{
|
||||
"model": "gemini-1.5-pro",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"prompt_id": "1234"
|
||||
}
|
||||
```
|
||||
|
||||
To:
|
||||
```json
|
||||
{
|
||||
"model": "gemini-1.5-pro",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Be a good Bot!"},
|
||||
{"role": "user", "content": "hi"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ guardrails:
|
|||
- guardrail_name: aim-protected-app
|
||||
litellm_params:
|
||||
guardrail: aim
|
||||
mode: pre_call # 'during_call' is also available
|
||||
mode: [pre_call, post_call] # "During_call" is also available
|
||||
api_key: os.environ/AIM_API_KEY
|
||||
api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
|
||||
```
|
||||
|
|
|
@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [BETA] Prompt Management
|
||||
# Prompt Management
|
||||
|
||||
:::info
|
||||
|
||||
|
@ -12,9 +12,10 @@ This feature is currently in beta, and might change unexpectedly. We expect this
|
|||
|
||||
Run experiments or change the specific model (e.g. from gpt-4o to gpt4o-mini finetune) from your prompt management tool (e.g. Langfuse) instead of making changes in the application.
|
||||
|
||||
Supported Integrations:
|
||||
- [Langfuse](https://langfuse.com/docs/prompts/get-started)
|
||||
- [Humanloop](../observability/humanloop)
|
||||
| Supported Integrations | Link |
|
||||
|------------------------|------|
|
||||
| Langfuse | [Get Started](https://langfuse.com/docs/prompts/get-started) |
|
||||
| Humanloop | [Get Started](../observability/humanloop) |
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
|
|
@ -102,7 +102,19 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Advanced - Set Accepted JWT Scope Names
|
||||
## Advanced
|
||||
|
||||
### Multiple OIDC providers
|
||||
|
||||
Use this if you want LiteLLM to validate your JWT against multiple OIDC providers (e.g. Google Cloud, GitHub Auth)
|
||||
|
||||
Set `JWT_PUBLIC_KEY_URL` in your environment to a comma-separated list of URLs for your OIDC providers.
|
||||
|
||||
```bash
|
||||
export JWT_PUBLIC_KEY_URL="https://demo.duendesoftware.com/.well-known/openid-configuration/jwks,https://accounts.google.com/.well-known/openid-configuration/jwks"
|
||||
```
|
||||
|
||||
### Set Accepted JWT Scope Names
|
||||
|
||||
Change the string in JWT 'scopes', that litellm evaluates to see if a user has admin access.
|
||||
|
||||
|
@ -114,7 +126,7 @@ general_settings:
|
|||
admin_jwt_scope: "litellm-proxy-admin"
|
||||
```
|
||||
|
||||
## Tracking End-Users / Internal Users / Team / Org
|
||||
### Tracking End-Users / Internal Users / Team / Org
|
||||
|
||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||
|
||||
|
@ -156,7 +168,7 @@ scope: ["litellm-proxy-admin",...]
|
|||
scope: "litellm-proxy-admin ..."
|
||||
```
|
||||
|
||||
## Control model access with Teams
|
||||
### Control model access with Teams
|
||||
|
||||
|
||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||
|
@ -207,7 +219,7 @@ OIDC Auth for API: [**See Walkthrough**](https://www.loom.com/share/00fe2deab59a
|
|||
- If all checks pass, allow the request
|
||||
|
||||
|
||||
## Advanced - Custom Validate
|
||||
### Custom JWT Validate
|
||||
|
||||
This section allows you to add custom logic to intercept and perform validation of the JWT token.
|
||||
|
||||
|
@ -215,7 +227,7 @@ This can occur when there is additional logic that is needed to execute against
|
|||
|
||||
> _Note_: You can expect the JWT will have ran the typical decrypting of the public key, token decoding, and expiration time checks before executing the custom validation function.
|
||||
|
||||
### 1. Setup custom validate function
|
||||
#### 1. Setup custom validate function
|
||||
|
||||
```python
|
||||
from typing import Any, Literal
|
||||
|
@ -236,7 +248,7 @@ def my_custom_validate(token: dict[str, Any]) -> Literal[True]:
|
|||
return True
|
||||
```
|
||||
|
||||
### 2. Setup config.yaml
|
||||
#### 2. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
|
@ -249,7 +261,7 @@ general_settings:
|
|||
custom_validate: custom_validate.my_custom_validate # 👈 custom validate function
|
||||
```
|
||||
|
||||
### 3. Test the flow
|
||||
#### 3. Test the flow
|
||||
|
||||
**Expected JWT**
|
||||
|
||||
|
@ -271,7 +283,7 @@ general_settings:
|
|||
}
|
||||
```
|
||||
|
||||
## Advanced - Allowed Routes
|
||||
### Allowed Routes
|
||||
|
||||
Configure which routes a JWT can access via the config.
|
||||
|
||||
|
@ -303,7 +315,7 @@ general_settings:
|
|||
team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes
|
||||
```
|
||||
|
||||
## Advanced - Caching Public Keys
|
||||
### Caching Public Keys
|
||||
|
||||
Control how long public keys are cached for (in seconds).
|
||||
|
||||
|
@ -317,7 +329,7 @@ general_settings:
|
|||
public_key_ttl: 600 # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
## Advanced - Custom JWT Field
|
||||
### Custom JWT Field
|
||||
|
||||
Set a custom field in which the team_id exists. By default, the 'client_id' field is checked.
|
||||
|
||||
|
@ -329,14 +341,7 @@ general_settings:
|
|||
team_id_jwt_field: "client_id" # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
## All Params
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/b204f0c01c703317d812a1553363ab0cb989d5b6/litellm/proxy/_types.py#L95)
|
||||
|
||||
|
||||
|
||||
|
||||
## Advanced - Block Teams
|
||||
### Block Teams
|
||||
|
||||
To block all requests for a certain team id, use `/team/block`
|
||||
|
||||
|
@ -363,7 +368,7 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
|
|||
```
|
||||
|
||||
|
||||
## Advanced - Upsert Users + Allowed Email Domains
|
||||
### Upsert Users + Allowed Email Domains
|
||||
|
||||
Allow users who belong to a specific email domain, automatic access to the proxy.
|
||||
|
||||
|
@ -501,3 +506,7 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
|||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## All JWT Params
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/b204f0c01c703317d812a1553363ab0cb989d5b6/litellm/proxy/_types.py#L95)
|
||||
|
|
55
docs/my-website/docs/proxy/ui_credentials.md
Normal file
|
@ -0,0 +1,55 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Adding LLM Credentials
|
||||
|
||||
You can add LLM provider credentials on the UI. Once you add credentials you can re-use them when adding new models
|
||||
|
||||
## Add a credential + model
|
||||
|
||||
### 1. Navigate to LLM Credentials page
|
||||
|
||||
Go to Models -> LLM Credentials -> Add Credential
|
||||
|
||||
<Image img={require('../../img/ui_cred_add.png')} />
|
||||
|
||||
### 2. Add credentials
|
||||
|
||||
Select your LLM provider, enter your API Key and click "Add Credential"
|
||||
|
||||
**Note: Credentials are based on the provider, if you select Vertex AI then you will see `Vertex Project`, `Vertex Location` and `Vertex Credentials` fields**
|
||||
|
||||
<Image img={require('../../img/ui_add_cred_2.png')} />
|
||||
|
||||
|
||||
### 3. Use credentials when adding a model
|
||||
|
||||
Go to Add Model -> Existing Credentials -> Select your credential in the dropdown
|
||||
|
||||
<Image img={require('../../img/ui_cred_3.png')} />
|
||||
|
||||
|
||||
## Create a Credential from an existing model
|
||||
|
||||
Use this if you have already created a model and want to store the model credentials for future use
|
||||
|
||||
### 1. Select model to create a credential from
|
||||
|
||||
Go to Models -> Select your model -> Credential -> Create Credential
|
||||
|
||||
<Image img={require('../../img/ui_cred_4.png')} />
|
||||
|
||||
### 2. Use new credential when adding a model
|
||||
|
||||
Go to Add Model -> Existing Credentials -> Select your credential in the dropdown
|
||||
|
||||
<Image img={require('../../img/use_model_cred.png')} />
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
|
||||
How are credentials stored?
|
||||
Credentials in the DB are encrypted/decrypted using `LITELLM_SALT_KEY`, if set. If not, then they are encrypted using `LITELLM_MASTER_KEY`. These keys should be kept secret and not shared with others.
|
||||
|
||||
|
55
docs/my-website/docs/proxy/ui_logs.md
Normal file
|
@ -0,0 +1,55 @@
|
|||
|
||||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# UI Logs Page
|
||||
|
||||
View Spend, Token Usage, Key, Team Name for Each Request to LiteLLM
|
||||
|
||||
|
||||
<Image img={require('../../img/ui_request_logs.png')}/>
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
| Log Type | Tracked by Default |
|
||||
|----------|-------------------|
|
||||
| Success Logs | ✅ Yes |
|
||||
| Error Logs | ✅ Yes |
|
||||
| Request/Response Content Stored | ❌ No by Default, **opt in with `store_prompts_in_spend_logs`** |
|
||||
|
||||
|
||||
|
||||
**By default LiteLLM does not track the request and response content.**
|
||||
|
||||
## Tracking - Request / Response Content in Logs Page
|
||||
|
||||
If you want to view request and response content on LiteLLM Logs, you need to opt in with this setting
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
store_prompts_in_spend_logs: true
|
||||
```
|
||||
|
||||
<Image img={require('../../img/ui_request_logs_content.png')}/>
|
||||
|
||||
|
||||
## Stop storing Error Logs in DB
|
||||
|
||||
If you do not want to store error logs in DB, you can opt out with this setting
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_error_logs: True # Only disable writing error logs to DB, regular spend logs will still be written unless `disable_spend_logs: True`
|
||||
```
|
||||
|
||||
## Stop storing Spend Logs in DB
|
||||
|
||||
If you do not want to store spend logs in DB, you can opt out with this setting
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: True # Disable writing spend logs to DB
|
||||
```
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Realtime Endpoints
|
||||
# /realtime
|
||||
|
||||
Use this to loadbalance across Azure + OpenAI.
|
||||
|
||||
|
|
|
@ -3,11 +3,20 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
# 'Thinking' / 'Reasoning Content'
|
||||
|
||||
:::info
|
||||
|
||||
Requires LiteLLM v1.63.0+
|
||||
|
||||
:::
|
||||
|
||||
Supported Providers:
|
||||
- Deepseek (`deepseek/`)
|
||||
- Anthropic API (`anthropic/`)
|
||||
- Bedrock (Anthropic + Deepseek) (`bedrock/`)
|
||||
- Vertex AI (Anthropic) (`vertexai/`)
|
||||
- OpenRouter (`openrouter/`)
|
||||
|
||||
LiteLLM will standardize the `reasoning_content` in the response and `thinking_blocks` in the assistant message.
|
||||
|
||||
```python
|
||||
"message": {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Rerank
|
||||
# /rerank
|
||||
|
||||
:::tip
|
||||
|
||||
|
|
117
docs/my-website/docs/response_api.md
Normal file
|
@ -0,0 +1,117 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# /responses [Beta]
|
||||
|
||||
LiteLLM provides a BETA endpoint in the spec of [OpenAI's `/responses` API](https://platform.openai.com/docs/api-reference/responses)
|
||||
|
||||
| Feature | Supported | Notes |
|
||||
|---------|-----------|--------|
|
||||
| Cost Tracking | ✅ | Works with all supported models |
|
||||
| Logging | ✅ | Works across all integrations |
|
||||
| End-user Tracking | ✅ | |
|
||||
| Streaming | ✅ | |
|
||||
| Fallbacks | ✅ | Works between supported models |
|
||||
| Loadbalancing | ✅ | Works between supported models |
|
||||
| Supported LiteLLM Versions | 1.63.8+ | |
|
||||
| Supported LLM providers | `openai` | |
|
||||
|
||||
## Usage
|
||||
|
||||
## Create a model response
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="litellm-sdk" label="LiteLLM SDK">
|
||||
|
||||
#### Non-streaming
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Non-streaming response
|
||||
response = litellm.responses(
|
||||
model="gpt-4o",
|
||||
input="Tell me a three sentence bedtime story about a unicorn.",
|
||||
max_output_tokens=100
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
#### Streaming
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Streaming response
|
||||
response = litellm.responses(
|
||||
model="gpt-4o",
|
||||
input="Tell me a three sentence bedtime story about a unicorn.",
|
||||
stream=True
|
||||
)
|
||||
|
||||
for event in response:
|
||||
print(event)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">
|
||||
|
||||
First, add this to your litellm proxy config.yaml:
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Start your LiteLLM proxy:
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
Then use the OpenAI SDK pointed to your proxy:
|
||||
|
||||
#### Non-streaming
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize client with your proxy URL
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:4000", # Your proxy URL
|
||||
api_key="your-api-key" # Your proxy API key
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
response = client.responses.create(
|
||||
model="gpt-4o",
|
||||
input="Tell me a three sentence bedtime story about a unicorn."
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
#### Streaming
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize client with your proxy URL
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:4000", # Your proxy URL
|
||||
api_key="your-api-key" # Your proxy API key
|
||||
)
|
||||
|
||||
# Streaming response
|
||||
response = client.responses.create(
|
||||
model="gpt-4o",
|
||||
input="Tell me a three sentence bedtime story about a unicorn.",
|
||||
stream=True
|
||||
)
|
||||
|
||||
for event in response:
|
||||
print(event)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
|
@ -830,7 +830,7 @@ asyncio.run(router_acompletion())
|
|||
|
||||
Set `weight` on a deployment to pick one deployment more often than others.
|
||||
|
||||
This works across **ALL** routing strategies.
|
||||
This works across **simple-shuffle** routing strategy (this is the default, if no routing strategy is selected).
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
|
|
@ -96,7 +96,7 @@ litellm --config /path/to/config.yaml
|
|||
```
|
||||
|
||||
|
||||
### Using K/V pairs in 1 AWS Secret
|
||||
#### Using K/V pairs in 1 AWS Secret
|
||||
|
||||
You can read multiple keys from a single AWS Secret using the `primary_secret_name` parameter:
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Text Completion
|
||||
# /completions
|
||||
|
||||
### Usage
|
||||
<Tabs>
|
||||
|
|
BIN
docs/my-website/img/custom_prompt_management.png
Normal file
After Width: | Height: | Size: 346 KiB |
BIN
docs/my-website/img/release_notes/credentials.jpg
Normal file
After Width: | Height: | Size: 371 KiB |
BIN
docs/my-website/img/release_notes/litellm_test_connection.gif
Normal file
After Width: | Height: | Size: 16 MiB |
BIN
docs/my-website/img/release_notes/responses_api.png
Normal file
After Width: | Height: | Size: 67 KiB |
BIN
docs/my-website/img/ui_add_cred_2.png
Normal file
After Width: | Height: | Size: 255 KiB |
BIN
docs/my-website/img/ui_cred_3.png
Normal file
After Width: | Height: | Size: 283 KiB |
BIN
docs/my-website/img/ui_cred_4.png
Normal file
After Width: | Height: | Size: 255 KiB |
BIN
docs/my-website/img/ui_cred_add.png
Normal file
After Width: | Height: | Size: 204 KiB |
BIN
docs/my-website/img/ui_request_logs.png
Normal file
After Width: | Height: | Size: 567 KiB |
BIN
docs/my-website/img/ui_request_logs_content.png
Normal file
After Width: | Height: | Size: 344 KiB |
BIN
docs/my-website/img/use_model_cred.png
Normal file
After Width: | Height: | Size: 282 KiB |
47
docs/my-website/package-lock.json
generated
|
@ -706,12 +706,13 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@babel/helpers": {
|
||||
"version": "7.26.0",
|
||||
"resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.26.0.tgz",
|
||||
"integrity": "sha512-tbhNuIxNcVb21pInl3ZSjksLCvgdZy9KwJ8brv993QtIVKJBBkYXz4q4ZbAv31GdnC+R90np23L5FbEBlthAEw==",
|
||||
"version": "7.26.10",
|
||||
"resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.26.10.tgz",
|
||||
"integrity": "sha512-UPYc3SauzZ3JGgj87GgZ89JVdC5dj0AoetR5Bw6wj4niittNyFh6+eOGonYvJ1ao6B8lEa3Q3klS7ADZ53bc5g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/template": "^7.25.9",
|
||||
"@babel/types": "^7.26.0"
|
||||
"@babel/template": "^7.26.9",
|
||||
"@babel/types": "^7.26.10"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
|
@ -796,11 +797,12 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@babel/parser": {
|
||||
"version": "7.26.3",
|
||||
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.3.tgz",
|
||||
"integrity": "sha512-WJ/CvmY8Mea8iDXo6a7RK2wbmJITT5fN3BEkRuFlxVyNx8jOKIIhmC4fSkTcPcf8JyavbBwIe6OpiCOBXt/IcA==",
|
||||
"version": "7.26.10",
|
||||
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.10.tgz",
|
||||
"integrity": "sha512-6aQR2zGE/QFi8JpDLjUZEPYOs7+mhKXm86VaKFiLP35JQwQb6bwUE+XbvkH0EptsYhbNBSUGaUBLKqxH1xSgsA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/types": "^7.26.3"
|
||||
"@babel/types": "^7.26.10"
|
||||
},
|
||||
"bin": {
|
||||
"parser": "bin/babel-parser.js"
|
||||
|
@ -2157,9 +2159,10 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@babel/runtime-corejs3": {
|
||||
"version": "7.26.0",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime-corejs3/-/runtime-corejs3-7.26.0.tgz",
|
||||
"integrity": "sha512-YXHu5lN8kJCb1LOb9PgV6pvak43X2h4HvRApcN5SdWeaItQOzfn1hgP6jasD6KWQyJDBxrVmA9o9OivlnNJK/w==",
|
||||
"version": "7.26.10",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime-corejs3/-/runtime-corejs3-7.26.10.tgz",
|
||||
"integrity": "sha512-uITFQYO68pMEYR46AHgQoyBg7KPPJDAbGn4jUTIRgCFJIp88MIBUianVOplhZDEec07bp9zIyr4Kp0FCyQzmWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"core-js-pure": "^3.30.2",
|
||||
"regenerator-runtime": "^0.14.0"
|
||||
|
@ -2169,13 +2172,14 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@babel/template": {
|
||||
"version": "7.25.9",
|
||||
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.25.9.tgz",
|
||||
"integrity": "sha512-9DGttpmPvIxBb/2uwpVo3dqJ+O6RooAFOS+lB+xDqoE2PVCE8nfoHMdZLpfCQRLwvohzXISPZcgxt80xLfsuwg==",
|
||||
"version": "7.26.9",
|
||||
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz",
|
||||
"integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.25.9",
|
||||
"@babel/parser": "^7.25.9",
|
||||
"@babel/types": "^7.25.9"
|
||||
"@babel/code-frame": "^7.26.2",
|
||||
"@babel/parser": "^7.26.9",
|
||||
"@babel/types": "^7.26.9"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
|
@ -2199,9 +2203,10 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@babel/types": {
|
||||
"version": "7.26.3",
|
||||
"resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.3.tgz",
|
||||
"integrity": "sha512-vN5p+1kl59GVKMvTHt55NzzmYVxprfJD+ql7U9NFIfKCBkYE55LYtS+WtPlaYOyzydrKI8Nezd+aZextrd+FMA==",
|
||||
"version": "7.26.10",
|
||||
"resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.10.tgz",
|
||||
"integrity": "sha512-emqcG3vHrpxUKTrxcblR36dcrcoRDvKmnL/dCL6ZsHaShW80qxCAcNhzQZrpeM765VzEos+xOi4s+r4IXzTwdQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/helper-string-parser": "^7.25.9",
|
||||
"@babel/helper-validator-identifier": "^7.25.9"
|
||||
|
|
180
docs/my-website/release_notes/v1.63.11-stable/index.md
Normal file
|
@ -0,0 +1,180 @@
|
|||
---
|
||||
title: v1.63.11-stable
|
||||
slug: v1.63.11-stable
|
||||
date: 2025-03-15T10:00:00
|
||||
authors:
|
||||
- name: Krrish Dholakia
|
||||
title: CEO, LiteLLM
|
||||
url: https://www.linkedin.com/in/krish-d/
|
||||
image_url: https://media.licdn.com/dms/image/v2/D4D03AQGrlsJ3aqpHmQ/profile-displayphoto-shrink_400_400/B4DZSAzgP7HYAg-/0/1737327772964?e=1743638400&v=beta&t=39KOXMUFedvukiWWVPHf3qI45fuQD7lNglICwN31DrI
|
||||
- name: Ishaan Jaffer
|
||||
title: CTO, LiteLLM
|
||||
url: https://www.linkedin.com/in/reffajnaahsi/
|
||||
image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg
|
||||
|
||||
tags: [credential management, thinking content, responses api, snowflake]
|
||||
hide_table_of_contents: false
|
||||
---
|
||||
|
||||
import Image from '@theme/IdealImage';
|
||||
|
||||
These are the changes since `v1.63.2-stable`.
|
||||
|
||||
This release is primarily focused on:
|
||||
- [Beta] Responses API Support
|
||||
- Snowflake Cortex Support, Amazon Nova Image Generation
|
||||
- UI - Credential Management, re-use credentials when adding new models
|
||||
- UI - Test Connection to LLM Provider before adding a model
|
||||
|
||||
:::info
|
||||
|
||||
This release will be live on 03/16/2025
|
||||
|
||||
:::
|
||||
|
||||
<!-- <Image img={require('../../img/release_notes/v16311_release.jpg')} /> -->
|
||||
|
||||
## Known Issues
|
||||
- 🚨 Known issue on Azure OpenAI - We don't recommend upgrading if you use Azure OpenAI. This version failed our Azure OpenAI load test
|
||||
|
||||
|
||||
## Docker Run LiteLLM Proxy
|
||||
|
||||
```
|
||||
docker run
|
||||
-e STORE_MODEL_IN_DB=True
|
||||
-p 4000:4000
|
||||
ghcr.io/berriai/litellm:main-v1.63.11-stable
|
||||
```
|
||||
|
||||
## Demo Instance
|
||||
|
||||
Here's a Demo Instance to test changes:
|
||||
- Instance: https://demo.litellm.ai/
|
||||
- Login Credentials:
|
||||
- Username: admin
|
||||
- Password: sk-1234
|
||||
|
||||
|
||||
|
||||
## New Models / Updated Models
|
||||
|
||||
- Image Generation support for Amazon Nova Canvas [Getting Started](https://docs.litellm.ai/docs/providers/bedrock#image-generation)
|
||||
- Add pricing for Jamba new models [PR](https://github.com/BerriAI/litellm/pull/9032/files)
|
||||
- Add pricing for Amazon EU models [PR](https://github.com/BerriAI/litellm/pull/9056/files)
|
||||
- Add Bedrock Deepseek R1 model pricing [PR](https://github.com/BerriAI/litellm/pull/9108/files)
|
||||
- Update Gemini pricing: Gemma 3, Flash 2 thinking update, LearnLM [PR](https://github.com/BerriAI/litellm/pull/9190/files)
|
||||
- Mark Cohere Embedding 3 models as Multimodal [PR](https://github.com/BerriAI/litellm/pull/9176/commits/c9a576ce4221fc6e50dc47cdf64ab62736c9da41)
|
||||
- Add Azure Data Zone pricing [PR](https://github.com/BerriAI/litellm/pull/9185/files#diff-19ad91c53996e178c1921cbacadf6f3bae20cfe062bd03ee6bfffb72f847ee37)
|
||||
- LiteLLM Tracks cost for `azure/eu` and `azure/us` models
|
||||
|
||||
|
||||
|
||||
## LLM Translation
|
||||
|
||||
<Image img={require('../../img/release_notes/responses_api.png')} />
|
||||
|
||||
1. **New Endpoints**
|
||||
- [Beta] POST `/responses` API. [Getting Started](https://docs.litellm.ai/docs/response_api)
|
||||
|
||||
2. **New LLM Providers**
|
||||
- Snowflake Cortex [Getting Started](https://docs.litellm.ai/docs/providers/snowflake)
|
||||
|
||||
3. **New LLM Features**
|
||||
|
||||
- Support OpenRouter `reasoning_content` on streaming [Getting Started](https://docs.litellm.ai/docs/reasoning_content)
|
||||
|
||||
4. **Bug Fixes**
|
||||
|
||||
- OpenAI: Return `code`, `param` and `type` on bad request error [More information on litellm exceptions](https://docs.litellm.ai/docs/exception_mapping)
|
||||
- Bedrock: Fix converse chunk parsing to only return empty dict on tool use [PR](https://github.com/BerriAI/litellm/pull/9166)
|
||||
- Bedrock: Support extra_headers [PR](https://github.com/BerriAI/litellm/pull/9113)
|
||||
- Azure: Fix Function Calling Bug & Update Default API Version to `2025-02-01-preview` [PR](https://github.com/BerriAI/litellm/pull/9191)
|
||||
- Azure: Fix AI services URL [PR](https://github.com/BerriAI/litellm/pull/9185)
|
||||
- Vertex AI: Handle HTTP 201 status code in response [PR](https://github.com/BerriAI/litellm/pull/9193)
|
||||
- Perplexity: Fix incorrect streaming response [PR](https://github.com/BerriAI/litellm/pull/9081)
|
||||
- Triton: Fix streaming completions bug [PR](https://github.com/BerriAI/litellm/pull/8386)
|
||||
- Deepgram: Support bytes.IO when handling audio files for transcription [PR](https://github.com/BerriAI/litellm/pull/9071)
|
||||
- Ollama: Fix "system" role has become unacceptable [PR](https://github.com/BerriAI/litellm/pull/9261)
|
||||
- All Providers (Streaming): Fix String `data:` stripped from entire content in streamed responses [PR](https://github.com/BerriAI/litellm/pull/9070)
|
||||
|
||||
|
||||
|
||||
## Spend Tracking Improvements
|
||||
|
||||
1. Support Bedrock converse cache token tracking [Getting Started](https://docs.litellm.ai/docs/completion/prompt_caching)
|
||||
2. Cost Tracking for Responses API [Getting Started](https://docs.litellm.ai/docs/response_api)
|
||||
3. Fix Azure Whisper cost tracking [Getting Started](https://docs.litellm.ai/docs/audio_transcription)
|
||||
|
||||
|
||||
## UI
|
||||
|
||||
### Re-Use Credentials on UI
|
||||
|
||||
You can now onboard LLM provider credentials on LiteLLM UI. Once these credentials are added you can re-use them when adding new models [Getting Started](https://docs.litellm.ai/docs/proxy/ui_credentials)
|
||||
|
||||
<Image img={require('../../img/release_notes/credentials.jpg')} />
|
||||
|
||||
|
||||
### Test Connections before adding models
|
||||
|
||||
Before adding a model you can test the connection to the LLM provider to verify you have setup your API Base + API Key correctly
|
||||
|
||||
<Image img={require('../../img/release_notes/litellm_test_connection.gif')} />
|
||||
|
||||
### General UI Improvements
|
||||
1. Add Models Page
|
||||
- Allow adding Cerebras, Sambanova, Perplexity, Fireworks, Openrouter, TogetherAI Models, Text-Completion OpenAI on Admin UI
|
||||
- Allow adding EU OpenAI models
|
||||
- Fix: Instantly show edit + deletes to models
|
||||
2. Keys Page
|
||||
- Fix: Instantly show newly created keys on Admin UI (don't require refresh)
|
||||
- Fix: Allow clicking into Top Keys when showing users Top API Key
|
||||
- Fix: Allow Filter Keys by Team Alias, Key Alias and Org
|
||||
- UI Improvements: Show 100 Keys Per Page, Use full height, increase width of key alias
|
||||
3. Users Page
|
||||
- Fix: Show correct count of internal user keys on Users Page
|
||||
- Fix: Metadata not updating in Team UI
|
||||
4. Logs Page
|
||||
- UI Improvements: Keep expanded log in focus on LiteLLM UI
|
||||
- UI Improvements: Minor improvements to logs page
|
||||
- Fix: Allow internal user to query their own logs
|
||||
- Allow switching off storing Error Logs in DB [Getting Started](https://docs.litellm.ai/docs/proxy/ui_logs)
|
||||
5. Sign In/Sign Out
|
||||
- Fix: Correctly use `PROXY_LOGOUT_URL` when set [Getting Started](https://docs.litellm.ai/docs/proxy/self_serve#setting-custom-logout-urls)
|
||||
|
||||
|
||||
## Security
|
||||
|
||||
1. Support for Rotating Master Keys [Getting Started](https://docs.litellm.ai/docs/proxy/master_key_rotations)
|
||||
2. Fix: Internal User Viewer Permissions, don't allow `internal_user_viewer` role to see `Test Key Page` or `Create Key Button` [More information on role based access controls](https://docs.litellm.ai/docs/proxy/access_control)
|
||||
3. Emit audit logs on All user + model Create/Update/Delete endpoints [Getting Started](https://docs.litellm.ai/docs/proxy/multiple_admins)
|
||||
4. JWT
|
||||
- Support multiple JWT OIDC providers [Getting Started](https://docs.litellm.ai/docs/proxy/token_auth)
|
||||
- Fix JWT access with Groups not working when team is assigned All Proxy Models access
|
||||
5. Using K/V pairs in 1 AWS Secret [Getting Started](https://docs.litellm.ai/docs/secret#using-kv-pairs-in-1-aws-secret)
|
||||
|
||||
|
||||
## Logging Integrations
|
||||
|
||||
1. Prometheus: Track Azure LLM API latency metric [Getting Started](https://docs.litellm.ai/docs/proxy/prometheus#request-latency-metrics)
|
||||
2. Athina: Added tags, user_feedback and model_options to additional_keys which can be sent to Athina [Getting Started](https://docs.litellm.ai/docs/observability/athina_integration)
|
||||
|
||||
|
||||
## Performance / Reliability improvements
|
||||
|
||||
1. Redis + litellm router - Fix Redis cluster mode for litellm router [PR](https://github.com/BerriAI/litellm/pull/9010)
|
||||
|
||||
|
||||
## General Improvements
|
||||
|
||||
1. OpenWebUI Integration - display `thinking` tokens
|
||||
- Guide on getting started with LiteLLM x OpenWebUI. [Getting Started](https://docs.litellm.ai/docs/tutorials/openweb_ui)
|
||||
- Display `thinking` tokens on OpenWebUI (Bedrock, Anthropic, Deepseek) [Getting Started](https://docs.litellm.ai/docs/tutorials/openweb_ui#render-thinking-content-on-openweb-ui)
|
||||
|
||||
<Image img={require('../../img/litellm_thinking_openweb.gif')} />
|
||||
|
||||
|
||||
## Complete Git Diff
|
||||
|
||||
[Here's the complete git diff](https://github.com/BerriAI/litellm/compare/v1.63.2-stable...v1.63.11-stable)
|
|
@ -101,7 +101,9 @@ const sidebars = {
|
|||
"proxy/admin_ui_sso",
|
||||
"proxy/self_serve",
|
||||
"proxy/public_teams",
|
||||
"proxy/custom_sso"
|
||||
"proxy/custom_sso",
|
||||
"proxy/ui_credentials",
|
||||
"proxy/ui_logs"
|
||||
],
|
||||
},
|
||||
{
|
||||
|
@ -231,6 +233,7 @@ const sidebars = {
|
|||
"providers/sambanova",
|
||||
"providers/custom_llm_server",
|
||||
"providers/petals",
|
||||
"providers/snowflake"
|
||||
],
|
||||
},
|
||||
{
|
||||
|
@ -273,7 +276,7 @@ const sidebars = {
|
|||
items: [
|
||||
{
|
||||
type: "category",
|
||||
label: "Chat",
|
||||
label: "/chat/completions",
|
||||
link: {
|
||||
type: "generated-index",
|
||||
title: "Chat Completions",
|
||||
|
@ -286,12 +289,13 @@ const sidebars = {
|
|||
"completion/usage",
|
||||
],
|
||||
},
|
||||
"response_api",
|
||||
"text_completion",
|
||||
"embedding/supported_embedding",
|
||||
"anthropic_unified",
|
||||
{
|
||||
type: "category",
|
||||
label: "Image",
|
||||
label: "/images",
|
||||
items: [
|
||||
"image_generation",
|
||||
"image_variations",
|
||||
|
@ -299,7 +303,7 @@ const sidebars = {
|
|||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "Audio",
|
||||
label: "/audio",
|
||||
"items": [
|
||||
"audio_transcription",
|
||||
"text_to_speech",
|
||||
|
@ -361,8 +365,12 @@ const sidebars = {
|
|||
],
|
||||
},
|
||||
{
|
||||
type: "doc",
|
||||
id: "proxy/prompt_management"
|
||||
type: "category",
|
||||
label: "[Beta] Prompt Management",
|
||||
items: [
|
||||
"proxy/prompt_management",
|
||||
"proxy/custom_prompt_management"
|
||||
],
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
|
|
|
@ -163,7 +163,7 @@ class AporiaGuardrail(CustomGuardrail):
|
|||
|
||||
pass
|
||||
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -173,6 +173,7 @@ class AporiaGuardrail(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
|
|
|
@ -94,6 +94,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -107,6 +107,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -126,6 +126,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -31,7 +31,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
|
|||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -41,6 +41,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
text = ""
|
||||
|
|
|
@ -8,12 +8,14 @@ import os
|
|||
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
|
||||
from litellm.caching.llm_caching_handler import LLMClientCache
|
||||
from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
|
||||
from litellm.types.utils import (
|
||||
ImageObject,
|
||||
BudgetConfig,
|
||||
all_litellm_params,
|
||||
all_litellm_params as _litellm_completion_params,
|
||||
CredentialItem,
|
||||
) # maintain backwards compatibility for root param
|
||||
from litellm._logging import (
|
||||
set_verbose,
|
||||
|
@ -180,6 +182,7 @@ cloudflare_api_key: Optional[str] = None
|
|||
baseten_key: Optional[str] = None
|
||||
aleph_alpha_key: Optional[str] = None
|
||||
nlp_cloud_key: Optional[str] = None
|
||||
snowflake_key: Optional[str] = None
|
||||
common_cloud_provider_auth_params: dict = {
|
||||
"params": ["project", "region_name", "token"],
|
||||
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
||||
|
@ -189,15 +192,17 @@ ssl_verify: Union[str, bool] = True
|
|||
ssl_certificate: Optional[str] = None
|
||||
disable_streaming_logging: bool = False
|
||||
disable_add_transform_inline_image_block: bool = False
|
||||
in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
|
||||
in_memory_llm_clients_cache: LLMClientCache = LLMClientCache()
|
||||
safe_memory_mode: bool = False
|
||||
enable_azure_ad_token_refresh: Optional[bool] = False
|
||||
### DEFAULT AZURE API VERSION ###
|
||||
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
|
||||
AZURE_DEFAULT_API_VERSION = "2025-02-01-preview" # this is updated to the latest
|
||||
### DEFAULT WATSONX API VERSION ###
|
||||
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
|
||||
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
||||
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
|
||||
### CREDENTIALS ###
|
||||
credential_list: List[CredentialItem] = []
|
||||
### GUARDRAILS ###
|
||||
llamaguard_model_name: Optional[str] = None
|
||||
openai_moderations_model_name: Optional[str] = None
|
||||
|
@ -412,6 +417,7 @@ cerebras_models: List = []
|
|||
galadriel_models: List = []
|
||||
sambanova_models: List = []
|
||||
assemblyai_models: List = []
|
||||
snowflake_models: List = []
|
||||
|
||||
|
||||
def is_bedrock_pricing_only_model(key: str) -> bool:
|
||||
|
@ -565,6 +571,8 @@ def add_known_models():
|
|||
assemblyai_models.append(key)
|
||||
elif value.get("litellm_provider") == "jina_ai":
|
||||
jina_ai_models.append(key)
|
||||
elif value.get("litellm_provider") == "snowflake":
|
||||
snowflake_models.append(key)
|
||||
|
||||
|
||||
add_known_models()
|
||||
|
@ -594,6 +602,7 @@ ollama_models = ["llama2"]
|
|||
|
||||
maritalk_models = ["maritalk"]
|
||||
|
||||
|
||||
model_list = (
|
||||
open_ai_chat_completion_models
|
||||
+ open_ai_text_completion_models
|
||||
|
@ -638,6 +647,7 @@ model_list = (
|
|||
+ azure_text_models
|
||||
+ assemblyai_models
|
||||
+ jina_ai_models
|
||||
+ snowflake_models
|
||||
)
|
||||
|
||||
model_list_set = set(model_list)
|
||||
|
@ -693,6 +703,7 @@ models_by_provider: dict = {
|
|||
"sambanova": sambanova_models,
|
||||
"assemblyai": assemblyai_models,
|
||||
"jina_ai": jina_ai_models,
|
||||
"snowflake": snowflake_models,
|
||||
}
|
||||
|
||||
# mapping for those models which have larger equivalents
|
||||
|
@ -809,6 +820,7 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
|
|||
from .llms.predibase.chat.transformation import PredibaseConfig
|
||||
from .llms.replicate.chat.transformation import ReplicateConfig
|
||||
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
|
||||
from .llms.snowflake.chat.transformation import SnowflakeConfig
|
||||
from .llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
|
||||
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
|
||||
|
@ -899,6 +911,7 @@ from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import
|
|||
|
||||
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
|
||||
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
|
||||
from .llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig
|
||||
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
||||
AmazonTitanMultimodalEmbeddingG1Config,
|
||||
|
@ -921,11 +934,14 @@ from .llms.groq.chat.transformation import GroqChatConfig
|
|||
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
||||
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||
from .llms.mistral.mistral_chat_transformation import MistralConfig
|
||||
from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from .llms.openai.chat.o_series_transformation import (
|
||||
OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility
|
||||
OpenAIOSeriesConfig,
|
||||
)
|
||||
|
||||
from .llms.snowflake.chat.transformation import SnowflakeConfig
|
||||
|
||||
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
||||
from .llms.openai.chat.gpt_transformation import (
|
||||
OpenAIGPTConfig,
|
||||
|
@ -1010,6 +1026,7 @@ from .batches.main import *
|
|||
from .batch_completion.main import * # type: ignore
|
||||
from .rerank_api.main import *
|
||||
from .llms.anthropic.experimental_pass_through.messages.handler import *
|
||||
from .responses.main import *
|
||||
from .realtime_api.main import _arealtime
|
||||
from .fine_tuning.main import *
|
||||
from .files.main import *
|
||||
|
|
|
@ -182,9 +182,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
|||
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"init_redis_cluster: startup nodes are being initialized."
|
||||
)
|
||||
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
args = _get_redis_cluster_kwargs()
|
||||
|
@ -307,7 +305,6 @@ def get_redis_async_client(
|
|||
return _init_async_redis_sentinel(redis_kwargs)
|
||||
|
||||
return async_redis.Redis(
|
||||
socket_timeout=5,
|
||||
**redis_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ import litellm
|
|||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import (
|
||||
exception_type,
|
||||
get_litellm_params,
|
||||
get_llm_provider,
|
||||
get_secret,
|
||||
supports_httpx_timeout,
|
||||
|
@ -86,6 +87,7 @@ def get_assistants(
|
|||
optional_params = GenericLiteLLMParams(
|
||||
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
||||
)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -169,6 +171,7 @@ def get_assistants(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
aget_assistants=aget_assistants, # type: ignore
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -270,6 +273,7 @@ def create_assistants(
|
|||
optional_params = GenericLiteLLMParams(
|
||||
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
||||
)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -371,6 +375,7 @@ def create_assistants(
|
|||
client=client,
|
||||
async_create_assistants=async_create_assistants,
|
||||
create_assistant_data=create_assistant_data,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -445,6 +450,8 @@ def delete_assistant(
|
|||
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
||||
)
|
||||
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
async_delete_assistants: Optional[bool] = kwargs.pop(
|
||||
"async_delete_assistants", None
|
||||
)
|
||||
|
@ -544,6 +551,7 @@ def delete_assistant(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
async_delete_assistants=async_delete_assistants,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -639,6 +647,7 @@ def create_thread(
|
|||
"""
|
||||
acreate_thread = kwargs.get("acreate_thread", None)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -731,6 +740,7 @@ def create_thread(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
acreate_thread=acreate_thread,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -795,7 +805,7 @@ def get_thread(
|
|||
"""Get the thread object, given a thread_id"""
|
||||
aget_thread = kwargs.pop("aget_thread", None)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
@ -884,6 +894,7 @@ def get_thread(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
aget_thread=aget_thread,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -972,6 +983,7 @@ def add_message(
|
|||
_message_data = MessageData(
|
||||
role=role, content=content, attachments=attachments, metadata=metadata
|
||||
)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
message_data = get_optional_params_add_message(
|
||||
|
@ -1068,6 +1080,7 @@ def add_message(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
a_add_message=a_add_message,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -1139,6 +1152,7 @@ def get_messages(
|
|||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
aget_messages = kwargs.pop("aget_messages", None)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -1225,6 +1239,7 @@ def get_messages(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
aget_messages=aget_messages,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -1337,6 +1352,7 @@ def run_thread(
|
|||
"""Run a given thread + assistant."""
|
||||
arun_thread = kwargs.pop("arun_thread", None)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -1437,6 +1453,7 @@ def run_thread(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
arun_thread=arun_thread,
|
||||
litellm_params=litellm_params_dict,
|
||||
) # type: ignore
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
|
|
@ -111,6 +111,7 @@ def create_batch(
|
|||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
_is_async = kwargs.pop("acreate_batch", False) is True
|
||||
litellm_params = get_litellm_params(**kwargs)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -217,6 +218,7 @@ def create_batch(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_batch_data=_create_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
|
@ -320,15 +322,12 @@ def retrieve_batch(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_call_id=kwargs.get("litellm_call_id", None),
|
||||
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
|
@ -424,6 +423,7 @@ def retrieve_batch(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
retrieve_batch_data=_retrieve_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
|
@ -526,6 +526,10 @@ def list_batches(
|
|||
try:
|
||||
# set API KEY
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
|
@ -603,6 +607,7 @@ def list_batches(
|
|||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
@ -678,6 +683,10 @@ def cancel_batch(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
@ -765,6 +774,7 @@ def cancel_batch(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
cancel_batch_data=_cancel_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
|
|
|
@ -790,6 +790,7 @@ class LLMCachingHandler:
|
|||
- Else append the chunk to self.async_streaming_chunks
|
||||
|
||||
"""
|
||||
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
] = _assemble_complete_response_from_streaming_chunks(
|
||||
|
@ -800,7 +801,6 @@ class LLMCachingHandler:
|
|||
streaming_chunks=self.async_streaming_chunks,
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
# if a complete_streaming_response is assembled, add it to the cache
|
||||
if complete_streaming_response is not None:
|
||||
await self.async_set_cache(
|
||||
|
|
40
litellm/caching/llm_caching_handler.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
|
||||
|
||||
def update_cache_key_with_event_loop(self, key):
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
If none, use the key as is.
|
||||
"""
|
||||
try:
|
||||
event_loop = asyncio.get_event_loop()
|
||||
stringified_event_loop = str(id(event_loop))
|
||||
return f"{key}-{stringified_event_loop}"
|
||||
except Exception: # handle no current event loop
|
||||
return key
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return super().set_cache(key, value, **kwargs)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return await super().async_set_cache(key, value, **kwargs)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return super().get_cache(key, **kwargs)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return await super().async_get_cache(key, **kwargs)
|
|
@ -54,6 +54,7 @@ class RedisCache(BaseCache):
|
|||
redis_flush_size: Optional[int] = 100,
|
||||
namespace: Optional[str] = None,
|
||||
startup_nodes: Optional[List] = None, # for redis-cluster
|
||||
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
|
@ -70,6 +71,9 @@ class RedisCache(BaseCache):
|
|||
redis_kwargs["password"] = password
|
||||
if startup_nodes is not None:
|
||||
redis_kwargs["startup_nodes"] = startup_nodes
|
||||
if socket_timeout is not None:
|
||||
redis_kwargs["socket_timeout"] = socket_timeout
|
||||
|
||||
### HEALTH MONITORING OBJECT ###
|
||||
if kwargs.get("service_logger_obj", None) is not None and isinstance(
|
||||
kwargs["service_logger_obj"], ServiceLogging
|
||||
|
@ -556,6 +560,7 @@ class RedisCache(BaseCache):
|
|||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
|
|
|
@ -7,6 +7,7 @@ DEFAULT_MAX_RETRIES = 2
|
|||
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||
)
|
||||
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||
DEFAULT_COOLDOWN_TIME_SECONDS = 5
|
||||
DEFAULT_REPLICATE_POLLING_RETRIES = 5
|
||||
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
|
||||
|
@ -18,6 +19,7 @@ SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests
|
|||
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
||||
#### Networking settings ####
|
||||
request_timeout: float = 6000 # time in seconds
|
||||
STREAM_SSE_DONE_STRING: str = "[DONE]"
|
||||
|
||||
LITELLM_CHAT_PROVIDERS = [
|
||||
"openai",
|
||||
|
|
|
@ -44,7 +44,12 @@ from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_ro
|
|||
from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
||||
cost_calculator as vertex_ai_image_cost_calculator,
|
||||
)
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.responses.utils import ResponseAPILoggingUtils
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent,
|
||||
ResponseAPIUsage,
|
||||
ResponsesAPIResponse,
|
||||
)
|
||||
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import (
|
||||
CallTypesLiteral,
|
||||
|
@ -464,6 +469,13 @@ def _get_usage_object(
|
|||
return usage_obj
|
||||
|
||||
|
||||
def _is_known_usage_objects(usage_obj):
|
||||
"""Returns True if the usage obj is a known Usage type"""
|
||||
return isinstance(usage_obj, litellm.Usage) or isinstance(
|
||||
usage_obj, ResponseAPIUsage
|
||||
)
|
||||
|
||||
|
||||
def _infer_call_type(
|
||||
call_type: Optional[CallTypesLiteral], completion_response: Any
|
||||
) -> Optional[CallTypesLiteral]:
|
||||
|
@ -573,9 +585,7 @@ def completion_cost( # noqa: PLR0915
|
|||
base_model=base_model,
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"completion_response _select_model_name_for_cost_calc: {model}"
|
||||
)
|
||||
verbose_logger.info(f"selected model name for cost calculation: {model}")
|
||||
|
||||
if completion_response is not None and (
|
||||
isinstance(completion_response, BaseModel)
|
||||
|
@ -587,8 +597,8 @@ def completion_cost( # noqa: PLR0915
|
|||
)
|
||||
else:
|
||||
usage_obj = getattr(completion_response, "usage", {})
|
||||
if isinstance(usage_obj, BaseModel) and not isinstance(
|
||||
usage_obj, litellm.Usage
|
||||
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
|
||||
usage_obj=usage_obj
|
||||
):
|
||||
setattr(
|
||||
completion_response,
|
||||
|
@ -601,6 +611,14 @@ def completion_cost( # noqa: PLR0915
|
|||
_usage = usage_obj.model_dump()
|
||||
else:
|
||||
_usage = usage_obj
|
||||
|
||||
if ResponseAPILoggingUtils._is_response_api_usage(_usage):
|
||||
_usage = (
|
||||
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||
_usage
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# get input/output tokens from completion_response
|
||||
prompt_tokens = _usage.get("prompt_tokens", 0)
|
||||
completion_tokens = _usage.get("completion_tokens", 0)
|
||||
|
@ -790,6 +808,23 @@ def completion_cost( # noqa: PLR0915
|
|||
raise e
|
||||
|
||||
|
||||
def get_response_cost_from_hidden_params(
|
||||
hidden_params: Union[dict, BaseModel]
|
||||
) -> Optional[float]:
|
||||
if isinstance(hidden_params, BaseModel):
|
||||
_hidden_params_dict = hidden_params.model_dump()
|
||||
else:
|
||||
_hidden_params_dict = hidden_params
|
||||
|
||||
additional_headers = _hidden_params_dict.get("additional_headers", {})
|
||||
if additional_headers and "x-litellm-response-cost" in additional_headers:
|
||||
response_cost = additional_headers["x-litellm-response-cost"]
|
||||
if response_cost is None:
|
||||
return None
|
||||
return float(additional_headers["x-litellm-response-cost"])
|
||||
return None
|
||||
|
||||
|
||||
def response_cost_calculator(
|
||||
response_object: Union[
|
||||
ModelResponse,
|
||||
|
@ -799,6 +834,7 @@ def response_cost_calculator(
|
|||
TextCompletionResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
RerankResponse,
|
||||
ResponsesAPIResponse,
|
||||
],
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str],
|
||||
|
@ -825,7 +861,7 @@ def response_cost_calculator(
|
|||
base_model: Optional[str] = None,
|
||||
custom_pricing: Optional[bool] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[float]:
|
||||
) -> float:
|
||||
"""
|
||||
Returns
|
||||
- float or None: cost of response
|
||||
|
@ -837,6 +873,14 @@ def response_cost_calculator(
|
|||
else:
|
||||
if isinstance(response_object, BaseModel):
|
||||
response_object._hidden_params["optional_params"] = optional_params
|
||||
|
||||
if hasattr(response_object, "_hidden_params"):
|
||||
provider_response_cost = get_response_cost_from_hidden_params(
|
||||
response_object._hidden_params
|
||||
)
|
||||
if provider_response_cost is not None:
|
||||
return provider_response_cost
|
||||
|
||||
response_cost = completion_cost(
|
||||
completion_response=response_object,
|
||||
model=model,
|
||||
|
|
|
@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
|
|||
HttpxBinaryResponseContent,
|
||||
)
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
from litellm.utils import get_litellm_params, supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
|
@ -546,6 +546,7 @@ def create_file(
|
|||
try:
|
||||
_is_async = kwargs.pop("acreate_file", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
|
@ -630,6 +631,7 @@ def create_file(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
|
|
|
@ -1,31 +1,37 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
def set_attributes(span: Span, kwargs, response_obj):
|
||||
from openinference.semconv.trace import (
|
||||
from litellm.integrations._types.open_inference import (
|
||||
MessageAttributes,
|
||||
OpenInferenceSpanKindValues,
|
||||
SpanAttributes,
|
||||
)
|
||||
|
||||
try:
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
|
||||
#############################################
|
||||
############ LLM CALL METADATA ##############
|
||||
#############################################
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
span.set_attribute(SpanAttributes.METADATA, str(metadata))
|
||||
|
||||
if standard_logging_payload and (
|
||||
metadata := standard_logging_payload["metadata"]
|
||||
):
|
||||
span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata))
|
||||
|
||||
#############################################
|
||||
########## LLM Request Attributes ###########
|
||||
|
@ -62,13 +68,12 @@ def set_attributes(span: Span, kwargs, response_obj):
|
|||
msg.get("content", ""),
|
||||
)
|
||||
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_payload and (model_params := standard_logging_payload["model_parameters"]):
|
||||
if standard_logging_payload and (
|
||||
model_params := standard_logging_payload["model_parameters"]
|
||||
):
|
||||
# The Generative AI Provider: Azure, OpenAI, etc.
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_INVOCATION_PARAMETERS, json.dumps(model_params)
|
||||
SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
|
||||
)
|
||||
|
||||
if model_params.get("user"):
|
||||
|
@ -80,7 +85,7 @@ def set_attributes(span: Span, kwargs, response_obj):
|
|||
########## LLM Response Attributes ##########
|
||||
# https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
|
||||
#############################################
|
||||
if hasattr(response_obj, 'get'):
|
||||
if hasattr(response_obj, "get"):
|
||||
for choice in response_obj.get("choices", []):
|
||||
response_message = choice.get("message", {})
|
||||
span.set_attribute(
|
||||
|
|
|
@ -3,31 +3,38 @@ arize AI is OTEL compatible
|
|||
|
||||
this file has Arize ai specific helper functions
|
||||
"""
|
||||
import os
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
from litellm.types.integrations.arize import ArizeConfig
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
Protocol = _Protocol
|
||||
Span = _Span
|
||||
else:
|
||||
Protocol = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
|
||||
class ArizeLogger:
|
||||
class ArizeLogger(OpenTelemetry):
|
||||
|
||||
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
|
||||
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def set_arize_attributes(span: Span, kwargs, response_obj):
|
||||
_utils.set_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_arize_config() -> ArizeConfig:
|
||||
|
@ -43,11 +50,6 @@ class ArizeLogger:
|
|||
space_key = os.environ.get("ARIZE_SPACE_KEY")
|
||||
api_key = os.environ.get("ARIZE_API_KEY")
|
||||
|
||||
if not space_key:
|
||||
raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
|
||||
if not api_key:
|
||||
raise ValueError("ARIZE_API_KEY not found in environment variables")
|
||||
|
||||
grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
|
||||
http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
|
||||
|
||||
|
@ -55,13 +57,13 @@ class ArizeLogger:
|
|||
protocol: Protocol = "otlp_grpc"
|
||||
|
||||
if grpc_endpoint:
|
||||
protocol="otlp_grpc"
|
||||
endpoint=grpc_endpoint
|
||||
protocol = "otlp_grpc"
|
||||
endpoint = grpc_endpoint
|
||||
elif http_endpoint:
|
||||
protocol="otlp_http"
|
||||
endpoint=http_endpoint
|
||||
protocol = "otlp_http"
|
||||
endpoint = http_endpoint
|
||||
else:
|
||||
protocol="otlp_grpc"
|
||||
protocol = "otlp_grpc"
|
||||
endpoint = "https://otlp.arize.com/v1"
|
||||
|
||||
return ArizeConfig(
|
||||
|
@ -71,4 +73,33 @@ class ArizeLogger:
|
|||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
async def async_service_success_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[datetime, float]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
|
||||
pass
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Optional[str] = "",
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
|
||||
pass
|
||||
|
||||
def create_litellm_proxy_request_started_span(
|
||||
self,
|
||||
start_time: datetime,
|
||||
headers: dict,
|
||||
):
|
||||
"""Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,16 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -14,6 +23,7 @@ from litellm.types.utils import (
|
|||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
@ -239,6 +249,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
) -> Any:
|
||||
pass
|
||||
|
@ -250,6 +261,15 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
) -> Any:
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Any,
|
||||
request_data: dict,
|
||||
) -> AsyncGenerator[ModelResponseStream, None]:
|
||||
async for item in response:
|
||||
yield item
|
||||
|
||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||
|
||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
|
|
49
litellm/integrations/custom_prompt_management.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prompt_management_base import (
|
||||
PromptManagementBase,
|
||||
PromptManagementClient,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
|
||||
class CustomPromptManagement(CustomLogger, PromptManagementBase):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Returns:
|
||||
- model: str - the model to use (can be pulled from prompt management tool)
|
||||
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
|
||||
- non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
|
||||
"""
|
||||
return model, messages, non_default_params
|
||||
|
||||
@property
|
||||
def integration_name(self) -> str:
|
||||
return "custom-prompt-management"
|
||||
|
||||
def should_run_prompt_management(
|
||||
self,
|
||||
prompt_id: str,
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> PromptManagementClient:
|
||||
raise NotImplementedError(
|
||||
"Custom prompt management does not support compile prompt helper"
|
||||
)
|
|
@ -10,6 +10,7 @@ from litellm.types.services import ServiceLoggerPayload
|
|||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
|
@ -311,6 +312,8 @@ class OpenTelemetry(CustomLogger):
|
|||
)
|
||||
_parent_context, parent_otel_span = self._get_span_context(kwargs)
|
||||
|
||||
self._add_dynamic_span_processor_if_needed(kwargs)
|
||||
|
||||
# Span 1: Requst sent to litellm SDK
|
||||
span = self.tracer.start_span(
|
||||
name=self._get_span_name(kwargs),
|
||||
|
@ -341,6 +344,45 @@ class OpenTelemetry(CustomLogger):
|
|||
if parent_otel_span is not None:
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
def _add_dynamic_span_processor_if_needed(self, kwargs):
|
||||
"""
|
||||
Helper method to add a span processor with dynamic headers if needed.
|
||||
|
||||
This allows for per-request configuration of telemetry exporters by
|
||||
extracting headers from standard_callback_dynamic_params.
|
||||
"""
|
||||
from opentelemetry import trace
|
||||
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params")
|
||||
)
|
||||
if not standard_callback_dynamic_params:
|
||||
return
|
||||
|
||||
# Extract headers from dynamic params
|
||||
dynamic_headers = {}
|
||||
|
||||
# Handle Arize headers
|
||||
if standard_callback_dynamic_params.get("arize_space_key"):
|
||||
dynamic_headers["space_key"] = standard_callback_dynamic_params.get(
|
||||
"arize_space_key"
|
||||
)
|
||||
if standard_callback_dynamic_params.get("arize_api_key"):
|
||||
dynamic_headers["api_key"] = standard_callback_dynamic_params.get(
|
||||
"arize_api_key"
|
||||
)
|
||||
|
||||
# Only create a span processor if we have headers to use
|
||||
if len(dynamic_headers) > 0:
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
provider = trace.get_tracer_provider()
|
||||
if isinstance(provider, TracerProvider):
|
||||
span_processor = self._get_span_processor(
|
||||
dynamic_headers=dynamic_headers
|
||||
)
|
||||
provider.add_span_processor(span_processor)
|
||||
|
||||
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
|
@ -443,14 +485,12 @@ class OpenTelemetry(CustomLogger):
|
|||
self, span: Span, kwargs, response_obj: Optional[Any]
|
||||
):
|
||||
try:
|
||||
if self.callback_name == "arize":
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
elif self.callback_name == "arize_phoenix":
|
||||
if self.callback_name == "arize_phoenix":
|
||||
from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
|
||||
|
||||
ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj)
|
||||
ArizePhoenixLogger.set_arize_phoenix_attributes(
|
||||
span, kwargs, response_obj
|
||||
)
|
||||
return
|
||||
elif self.callback_name == "langtrace":
|
||||
from litellm.integrations.langtrace import LangtraceAttributes
|
||||
|
@ -779,7 +819,7 @@ class OpenTelemetry(CustomLogger):
|
|||
carrier = {"traceparent": traceparent}
|
||||
return TraceContextTextMapPropagator().extract(carrier=carrier), None
|
||||
|
||||
def _get_span_processor(self):
|
||||
def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
|
||||
OTLPSpanExporter as OTLPSpanExporterGRPC,
|
||||
)
|
||||
|
@ -799,10 +839,9 @@ class OpenTelemetry(CustomLogger):
|
|||
self.OTEL_ENDPOINT,
|
||||
self.OTEL_HEADERS,
|
||||
)
|
||||
_split_otel_headers = {}
|
||||
if self.OTEL_HEADERS is not None and isinstance(self.OTEL_HEADERS, str):
|
||||
_split_otel_headers = self.OTEL_HEADERS.split("=")
|
||||
_split_otel_headers = {_split_otel_headers[0]: _split_otel_headers[1]}
|
||||
_split_otel_headers = OpenTelemetry._get_headers_dictionary(
|
||||
headers=dynamic_headers or self.OTEL_HEADERS
|
||||
)
|
||||
|
||||
if isinstance(self.OTEL_EXPORTER, SpanExporter):
|
||||
verbose_logger.debug(
|
||||
|
@ -844,6 +883,25 @@ class OpenTelemetry(CustomLogger):
|
|||
)
|
||||
return BatchSpanProcessor(ConsoleSpanExporter())
|
||||
|
||||
@staticmethod
|
||||
def _get_headers_dictionary(headers: Optional[Union[str, dict]]) -> Dict[str, str]:
|
||||
"""
|
||||
Convert a string or dictionary of headers into a dictionary of headers.
|
||||
"""
|
||||
_split_otel_headers: Dict[str, str] = {}
|
||||
if headers:
|
||||
if isinstance(headers, str):
|
||||
# when passed HEADERS="x-honeycomb-team=B85YgLm96******"
|
||||
# Split only on first '=' occurrence
|
||||
parts = headers.split("=", 1)
|
||||
if len(parts) == 2:
|
||||
_split_otel_headers = {parts[0]: parts[1]}
|
||||
else:
|
||||
_split_otel_headers = {}
|
||||
elif isinstance(headers, dict):
|
||||
_split_otel_headers = headers
|
||||
return _split_otel_headers
|
||||
|
||||
async def async_management_endpoint_success_hook(
|
||||
self,
|
||||
logging_payload: ManagementEndpointLoggingPayload,
|
||||
|
@ -948,3 +1006,18 @@ class OpenTelemetry(CustomLogger):
|
|||
)
|
||||
management_endpoint_span.set_status(Status(StatusCode.ERROR))
|
||||
management_endpoint_span.end(end_time=_end_time_ns)
|
||||
|
||||
def create_litellm_proxy_request_started_span(
|
||||
self,
|
||||
start_time: datetime,
|
||||
headers: dict,
|
||||
) -> Optional[Span]:
|
||||
"""
|
||||
Create a span for the received proxy server request.
|
||||
"""
|
||||
return self.tracer.start_span(
|
||||
name="Received Proxy Server Request",
|
||||
start_time=self._to_ns(start_time),
|
||||
context=self.get_traceparent_from_header(headers=headers),
|
||||
kind=self.span_kind.SERVER,
|
||||
)
|
||||
|
|
34
litellm/litellm_core_utils/credential_accessor.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
"""Utils for accessing credentials."""
|
||||
|
||||
from typing import List
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import CredentialItem
|
||||
|
||||
|
||||
class CredentialAccessor:
|
||||
@staticmethod
|
||||
def get_credential_values(credential_name: str) -> dict:
|
||||
"""Safe accessor for credentials."""
|
||||
if not litellm.credential_list:
|
||||
return {}
|
||||
for credential in litellm.credential_list:
|
||||
if credential.credential_name == credential_name:
|
||||
return credential.credential_values.copy()
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def upsert_credentials(credentials: List[CredentialItem]):
|
||||
"""Add a credential to the list of credentials."""
|
||||
|
||||
credential_names = [cred.credential_name for cred in litellm.credential_list]
|
||||
|
||||
for credential in credentials:
|
||||
if credential.credential_name in credential_names:
|
||||
# Find and replace the existing credential in the list
|
||||
for i, existing_cred in enumerate(litellm.credential_list):
|
||||
if existing_cred.credential_name == credential.credential_name:
|
||||
litellm.credential_list[i] = credential
|
||||
break
|
||||
else:
|
||||
litellm.credential_list.append(credential)
|
|
@ -127,7 +127,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
completion_kwargs={},
|
||||
extra_kwargs={},
|
||||
):
|
||||
|
||||
"""Maps an LLM Provider Exception to OpenAI Exception Format"""
|
||||
if any(
|
||||
isinstance(original_exception, exc_type)
|
||||
for exc_type in litellm.LITELLM_EXCEPTION_TYPES
|
||||
|
|
|
@ -58,6 +58,8 @@ def get_litellm_params(
|
|||
async_call: Optional[bool] = None,
|
||||
ssl_verify: Optional[bool] = None,
|
||||
merge_reasoning_content_in_choices: Optional[bool] = None,
|
||||
api_version: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
litellm_params = {
|
||||
|
@ -99,5 +101,14 @@ def get_litellm_params(
|
|||
"async_call": async_call,
|
||||
"ssl_verify": ssl_verify,
|
||||
"merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": kwargs.get("azure_ad_token"),
|
||||
"tenant_id": kwargs.get("tenant_id"),
|
||||
"client_id": kwargs.get("client_id"),
|
||||
"client_secret": kwargs.get("client_secret"),
|
||||
"azure_username": kwargs.get("azure_username"),
|
||||
"azure_password": kwargs.get("azure_password"),
|
||||
"max_retries": max_retries,
|
||||
"timeout": kwargs.get("timeout"),
|
||||
}
|
||||
return litellm_params
|
||||
|
|
|
@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
|
|||
model, custom_llm_provider
|
||||
)
|
||||
|
||||
if custom_llm_provider:
|
||||
if (
|
||||
model.split("/")[0] == custom_llm_provider
|
||||
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
|
||||
model = model.replace("{}/".format(custom_llm_provider), "")
|
||||
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
if custom_llm_provider and (
|
||||
model.split("/")[0] != custom_llm_provider
|
||||
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
|
||||
model = custom_llm_provider + "/" + model
|
||||
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
dynamic_api_key = get_secret_str(api_key)
|
||||
# check if llm provider part of model name
|
||||
|
||||
if (
|
||||
model.split("/", 1)[0] in litellm.provider_list
|
||||
and model.split("/", 1)[0] not in litellm.model_list_set
|
||||
|
@ -571,6 +569,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
|||
or "https://api.galadriel.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
||||
elif custom_llm_provider == "snowflake":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("SNOWFLAKE_API_BASE")
|
||||
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
|
||||
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
|
|
|
@ -29,6 +29,7 @@ from litellm.batches.batch_utils import _handle_completed_batch
|
|||
from litellm.caching.caching import DualCache, InMemoryCache
|
||||
from litellm.caching.caching_handler import LLMCachingHandler
|
||||
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.mlflow import MlflowLogger
|
||||
|
@ -39,11 +40,14 @@ from litellm.litellm_core_utils.redact_messages import (
|
|||
redact_message_input_output_from_custom_logger,
|
||||
redact_message_input_output_from_logging,
|
||||
)
|
||||
from litellm.responses.utils import ResponseAPILoggingUtils
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
Batch,
|
||||
FineTuningJob,
|
||||
HttpxBinaryResponseContent,
|
||||
ResponseCompletedEvent,
|
||||
ResponsesAPIResponse,
|
||||
)
|
||||
from litellm.types.rerank import RerankResponse
|
||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||
|
@ -73,11 +77,11 @@ from litellm.types.utils import (
|
|||
from litellm.utils import _get_base_model_from_metadata, executor, print_verbose
|
||||
|
||||
from ..integrations.argilla import ArgillaLogger
|
||||
from ..integrations.arize.arize import ArizeLogger
|
||||
from ..integrations.arize.arize_phoenix import ArizePhoenixLogger
|
||||
from ..integrations.athina import AthinaLogger
|
||||
from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger
|
||||
from ..integrations.braintrust_logging import BraintrustLogger
|
||||
from ..integrations.custom_prompt_management import CustomPromptManagement
|
||||
from ..integrations.datadog.datadog import DataDogLogger
|
||||
from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
|
||||
from ..integrations.dynamodb import DyanmoDBLogger
|
||||
|
@ -107,7 +111,6 @@ from .exception_mapping_utils import _get_response_headers
|
|||
from .initialize_dynamic_callback_params import (
|
||||
initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params,
|
||||
)
|
||||
from .logging_utils import _assemble_complete_response_from_streaming_chunks
|
||||
from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache
|
||||
|
||||
try:
|
||||
|
@ -427,34 +430,58 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
prompt_variables: Optional[dict],
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
|
||||
for (
|
||||
custom_logger_compatible_callback
|
||||
) in litellm._known_custom_logger_compatible_callbacks:
|
||||
if model.startswith(custom_logger_compatible_callback):
|
||||
custom_logger = self.get_custom_logger_for_prompt_management(model)
|
||||
if custom_logger:
|
||||
model, messages, non_default_params = (
|
||||
custom_logger.get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=self.standard_callback_dynamic_params,
|
||||
)
|
||||
)
|
||||
self.messages = messages
|
||||
return model, messages, non_default_params
|
||||
|
||||
def get_custom_logger_for_prompt_management(
|
||||
self, model: str
|
||||
) -> Optional[CustomLogger]:
|
||||
"""
|
||||
Get a custom logger for prompt management based on model name or available callbacks.
|
||||
|
||||
Args:
|
||||
model: The model name to check for prompt management integration
|
||||
|
||||
Returns:
|
||||
A CustomLogger instance if one is found, None otherwise
|
||||
"""
|
||||
# First check if model starts with a known custom logger compatible callback
|
||||
for callback_name in litellm._known_custom_logger_compatible_callbacks:
|
||||
if model.startswith(callback_name):
|
||||
custom_logger = _init_custom_logger_compatible_class(
|
||||
logging_integration=custom_logger_compatible_callback,
|
||||
logging_integration=callback_name,
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
if custom_logger is not None:
|
||||
self.model_call_details["prompt_integration"] = model.split("/")[0]
|
||||
return custom_logger
|
||||
|
||||
if custom_logger is None:
|
||||
continue
|
||||
old_name = model
|
||||
# Then check for any registered CustomPromptManagement loggers
|
||||
prompt_management_loggers = (
|
||||
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=CustomPromptManagement
|
||||
)
|
||||
)
|
||||
|
||||
model, messages, non_default_params = (
|
||||
custom_logger.get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=self.standard_callback_dynamic_params,
|
||||
)
|
||||
)
|
||||
self.model_call_details["prompt_integration"] = old_name.split("/")[0]
|
||||
self.messages = messages
|
||||
if prompt_management_loggers:
|
||||
logger = prompt_management_loggers[0]
|
||||
self.model_call_details["prompt_integration"] = logger.__class__.__name__
|
||||
return logger
|
||||
|
||||
return model, messages, non_default_params
|
||||
return None
|
||||
|
||||
def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
|
||||
if data is None:
|
||||
|
@ -716,25 +743,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
Masks the headers of the request sent from LiteLLM
|
||||
"""
|
||||
sensitive_keywords = [
|
||||
"authorization",
|
||||
"token",
|
||||
"key",
|
||||
"secret",
|
||||
]
|
||||
return {
|
||||
k: (
|
||||
(v[:-44] + "*" * 44)
|
||||
if (isinstance(v, str) and len(v) > 44)
|
||||
else "*****"
|
||||
)
|
||||
for k, v in headers.items()
|
||||
if not ignore_sensitive_headers
|
||||
or not any(
|
||||
sensitive_keyword in k.lower()
|
||||
for sensitive_keyword in sensitive_keywords
|
||||
)
|
||||
}
|
||||
return _get_masked_values(
|
||||
headers, ignore_sensitive_values=ignore_sensitive_headers
|
||||
)
|
||||
|
||||
def post_call(
|
||||
self, original_response, input=None, api_key=None, additional_args={}
|
||||
|
@ -851,6 +862,8 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
RerankResponse,
|
||||
Batch,
|
||||
FineTuningJob,
|
||||
ResponsesAPIResponse,
|
||||
ResponseCompletedEvent,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
) -> Optional[float]:
|
||||
|
@ -1000,7 +1013,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
standard_logging_object is None
|
||||
and result is not None
|
||||
and self.stream is not True
|
||||
): # handle streaming separately
|
||||
):
|
||||
if (
|
||||
isinstance(result, ModelResponse)
|
||||
or isinstance(result, ModelResponseStream)
|
||||
|
@ -1012,6 +1025,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
or isinstance(result, RerankResponse)
|
||||
or isinstance(result, FineTuningJob)
|
||||
or isinstance(result, LiteLLMBatch)
|
||||
or isinstance(result, ResponsesAPIResponse)
|
||||
):
|
||||
## HIDDEN PARAMS ##
|
||||
hidden_params = getattr(result, "_hidden_params", {})
|
||||
|
@ -1111,7 +1125,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
## BUILD COMPLETE STREAMED RESPONSE
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
|
||||
] = None
|
||||
if "complete_streaming_response" in self.model_call_details:
|
||||
return # break out of this.
|
||||
|
@ -1633,7 +1647,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
if "async_complete_streaming_response" in self.model_call_details:
|
||||
return # break out of this.
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
|
||||
] = self._get_assembled_streaming_response(
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
|
@ -2343,28 +2357,24 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
def _get_assembled_streaming_response(
|
||||
self,
|
||||
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
|
||||
result: Union[
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ModelResponseStream,
|
||||
ResponseCompletedEvent,
|
||||
Any,
|
||||
],
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
is_async: bool,
|
||||
streaming_chunks: List[Any],
|
||||
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
|
||||
) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]:
|
||||
if isinstance(result, ModelResponse):
|
||||
return result
|
||||
elif isinstance(result, TextCompletionResponse):
|
||||
return result
|
||||
elif isinstance(result, ModelResponseStream):
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
] = _assemble_complete_response_from_streaming_chunks(
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
request_kwargs=self.model_call_details,
|
||||
streaming_chunks=streaming_chunks,
|
||||
is_async=is_async,
|
||||
)
|
||||
return complete_streaming_response
|
||||
elif isinstance(result, ResponseCompletedEvent):
|
||||
return result.response
|
||||
return None
|
||||
|
||||
def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
|
||||
|
@ -2399,6 +2409,58 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
return result
|
||||
|
||||
|
||||
def _get_masked_values(
|
||||
sensitive_object: dict,
|
||||
ignore_sensitive_values: bool = False,
|
||||
mask_all_values: bool = False,
|
||||
unmasked_length: int = 4,
|
||||
number_of_asterisks: Optional[int] = 4,
|
||||
) -> dict:
|
||||
"""
|
||||
Internal debugging helper function
|
||||
|
||||
Masks the headers of the request sent from LiteLLM
|
||||
|
||||
Args:
|
||||
masked_length: Optional length for the masked portion (number of *). If set, will use exactly this many *
|
||||
regardless of original string length. The total length will be unmasked_length + masked_length.
|
||||
"""
|
||||
sensitive_keywords = [
|
||||
"authorization",
|
||||
"token",
|
||||
"key",
|
||||
"secret",
|
||||
]
|
||||
return {
|
||||
k: (
|
||||
(
|
||||
v[: unmasked_length // 2]
|
||||
+ "*" * number_of_asterisks
|
||||
+ v[-unmasked_length // 2 :]
|
||||
)
|
||||
if (
|
||||
isinstance(v, str)
|
||||
and len(v) > unmasked_length
|
||||
and number_of_asterisks is not None
|
||||
)
|
||||
else (
|
||||
(
|
||||
v[: unmasked_length // 2]
|
||||
+ "*" * (len(v) - unmasked_length)
|
||||
+ v[-unmasked_length // 2 :]
|
||||
)
|
||||
if (isinstance(v, str) and len(v) > unmasked_length)
|
||||
else "*****"
|
||||
)
|
||||
)
|
||||
for k, v in sensitive_object.items()
|
||||
if not ignore_sensitive_values
|
||||
or not any(
|
||||
sensitive_keyword in k.lower() for sensitive_keyword in sensitive_keywords
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
|
||||
"""
|
||||
Globally sets the callback client
|
||||
|
@ -2621,13 +2683,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
)
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
isinstance(callback, ArizeLogger)
|
||||
and callback.callback_name == "arize"
|
||||
):
|
||||
return callback # type: ignore
|
||||
_otel_logger = OpenTelemetry(config=otel_config, callback_name="arize")
|
||||
_in_memory_loggers.append(_otel_logger)
|
||||
return _otel_logger # type: ignore
|
||||
_arize_otel_logger = ArizeLogger(config=otel_config, callback_name="arize")
|
||||
_in_memory_loggers.append(_arize_otel_logger)
|
||||
return _arize_otel_logger # type: ignore
|
||||
elif logging_integration == "arize_phoenix":
|
||||
from litellm.integrations.opentelemetry import (
|
||||
OpenTelemetry,
|
||||
|
@ -2860,15 +2922,13 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
|
|||
if isinstance(callback, OpenTelemetry):
|
||||
return callback
|
||||
elif logging_integration == "arize":
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
|
||||
if "ARIZE_SPACE_KEY" not in os.environ:
|
||||
raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
|
||||
if "ARIZE_API_KEY" not in os.environ:
|
||||
raise ValueError("ARIZE_API_KEY not found in environment variables")
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
isinstance(callback, ArizeLogger)
|
||||
and callback.callback_name == "arize"
|
||||
):
|
||||
return callback
|
||||
|
@ -3111,6 +3171,12 @@ class StandardLoggingPayloadSetup:
|
|||
elif isinstance(usage, Usage):
|
||||
return usage
|
||||
elif isinstance(usage, dict):
|
||||
if ResponseAPILoggingUtils._is_response_api_usage(usage):
|
||||
return (
|
||||
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||
usage
|
||||
)
|
||||
)
|
||||
return Usage(**usage)
|
||||
|
||||
raise ValueError(f"usage is required, got={usage} of type {type(usage)}")
|
||||
|
@ -3215,6 +3281,7 @@ class StandardLoggingPayloadSetup:
|
|||
additional_headers=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
if hidden_params is not None:
|
||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
|
@ -3329,6 +3396,7 @@ def get_standard_logging_object_payload(
|
|||
response_cost=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -3614,6 +3682,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
|||
additional_headers=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
|
||||
# Convert numeric values to appropriate types
|
||||
|
|
|
@ -44,6 +44,7 @@ class ResponseMetadata:
|
|||
"additional_headers": process_response_headers(
|
||||
self._get_value_from_hidden_params("additional_headers") or {}
|
||||
),
|
||||
"litellm_model_name": model,
|
||||
}
|
||||
self._update_hidden_params(new_params)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, List, Set, Union
|
||||
from typing import Callable, List, Set, Type, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
@ -86,21 +86,20 @@ class LoggingCallbackManager:
|
|||
callback=callback, parent_list=litellm._async_failure_callback
|
||||
)
|
||||
|
||||
def remove_callback_from_list_by_object(
|
||||
self, callback_list, obj
|
||||
):
|
||||
def remove_callback_from_list_by_object(self, callback_list, obj):
|
||||
"""
|
||||
Remove callbacks that are methods of a particular object (e.g., router cleanup)
|
||||
"""
|
||||
if not isinstance(callback_list, list): # Not list -> do nothing
|
||||
if not isinstance(callback_list, list): # Not list -> do nothing
|
||||
return
|
||||
|
||||
remove_list=[c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]
|
||||
|
||||
remove_list = [
|
||||
c for c in callback_list if hasattr(c, "__self__") and c.__self__ == obj
|
||||
]
|
||||
|
||||
for c in remove_list:
|
||||
callback_list.remove(c)
|
||||
|
||||
|
||||
def _add_string_callback_to_list(
|
||||
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||
):
|
||||
|
@ -254,3 +253,11 @@ class LoggingCallbackManager:
|
|||
):
|
||||
matched_callbacks.add(callback)
|
||||
return matched_callbacks
|
||||
|
||||
def get_custom_loggers_for_type(
|
||||
self, callback_type: Type[CustomLogger]
|
||||
) -> List[CustomLogger]:
|
||||
"""
|
||||
Get all custom loggers that are instances of the given class type
|
||||
"""
|
||||
return [c for c in self._get_all_callbacks() if isinstance(c, callback_type)]
|
||||
|
|
|
@ -77,6 +77,10 @@ def _assemble_complete_response_from_streaming_chunks(
|
|||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
] = None
|
||||
|
||||
if isinstance(result, ModelResponse):
|
||||
return result
|
||||
|
||||
if result.choices[0].finish_reason is not None: # if it's the last chunk
|
||||
streaming_chunks.append(result)
|
||||
try:
|
||||
|
|
|
@ -77,6 +77,16 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
|
|||
return texts
|
||||
|
||||
|
||||
def get_str_from_messages(messages: List[AllMessageValues]) -> str:
|
||||
"""
|
||||
Converts a list of messages to a string
|
||||
"""
|
||||
text = ""
|
||||
for message in messages:
|
||||
text += convert_content_list_to_str(message=message)
|
||||
return text
|
||||
|
||||
|
||||
def is_non_content_values_set(message: AllMessageValues) -> bool:
|
||||
ignore_keys = ["content", "role", "name"]
|
||||
return any(
|
||||
|
|
|
@ -166,148 +166,108 @@ def convert_to_ollama_image(openai_image_url: str):
|
|||
)
|
||||
|
||||
|
||||
def _handle_ollama_system_message(
|
||||
messages: list, prompt: str, msg_i: int
|
||||
) -> Tuple[str, int]:
|
||||
system_content_str = ""
|
||||
## MERGE CONSECUTIVE SYSTEM CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "system":
|
||||
msg_content = convert_content_list_to_str(messages[msg_i])
|
||||
system_content_str += msg_content
|
||||
|
||||
msg_i += 1
|
||||
|
||||
return system_content_str, msg_i
|
||||
|
||||
|
||||
def ollama_pt(
|
||||
model, messages
|
||||
model: str, messages: list
|
||||
) -> Union[
|
||||
str, OllamaVisionModelObject
|
||||
]: # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
if "instruct" in model:
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
"system": {"pre_message": "### System:\n", "post_message": "\n"},
|
||||
"user": {
|
||||
"pre_message": "### User:\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "### Response:\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
},
|
||||
final_prompt_value="### Response:",
|
||||
messages=messages,
|
||||
user_message_types = {"user", "tool", "function"}
|
||||
msg_i = 0
|
||||
images = []
|
||||
prompt = ""
|
||||
while msg_i < len(messages):
|
||||
init_msg_i = msg_i
|
||||
user_content_str = ""
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||
msg_content = messages[msg_i].get("content")
|
||||
if msg_content:
|
||||
if isinstance(msg_content, list):
|
||||
for m in msg_content:
|
||||
if m.get("type", "") == "image_url":
|
||||
if isinstance(m["image_url"], str):
|
||||
images.append(m["image_url"])
|
||||
elif isinstance(m["image_url"], dict):
|
||||
images.append(m["image_url"]["url"])
|
||||
elif m.get("type", "") == "text":
|
||||
user_content_str += m["text"]
|
||||
else:
|
||||
# Tool message content will always be a string
|
||||
user_content_str += msg_content
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content_str:
|
||||
prompt += f"### User:\n{user_content_str}\n\n"
|
||||
|
||||
system_content_str, msg_i = _handle_ollama_system_message(
|
||||
messages, prompt, msg_i
|
||||
)
|
||||
else:
|
||||
user_message_types = {"user", "tool", "function"}
|
||||
msg_i = 0
|
||||
images = []
|
||||
prompt = ""
|
||||
while msg_i < len(messages):
|
||||
init_msg_i = msg_i
|
||||
user_content_str = ""
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while (
|
||||
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
|
||||
):
|
||||
msg_content = messages[msg_i].get("content")
|
||||
if msg_content:
|
||||
if isinstance(msg_content, list):
|
||||
for m in msg_content:
|
||||
if m.get("type", "") == "image_url":
|
||||
if isinstance(m["image_url"], str):
|
||||
images.append(m["image_url"])
|
||||
elif isinstance(m["image_url"], dict):
|
||||
images.append(m["image_url"]["url"])
|
||||
elif m.get("type", "") == "text":
|
||||
user_content_str += m["text"]
|
||||
else:
|
||||
# Tool message content will always be a string
|
||||
user_content_str += msg_content
|
||||
if system_content_str:
|
||||
prompt += f"### System:\n{system_content_str}\n\n"
|
||||
|
||||
msg_i += 1
|
||||
assistant_content_str = ""
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
assistant_content_str += convert_content_list_to_str(messages[msg_i])
|
||||
msg_i += 1
|
||||
|
||||
if user_content_str:
|
||||
prompt += f"### User:\n{user_content_str}\n\n"
|
||||
tool_calls = messages[msg_i].get("tool_calls")
|
||||
ollama_tool_calls = []
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
call_id: str = call["id"]
|
||||
function_name: str = call["function"]["name"]
|
||||
arguments = json.loads(call["function"]["arguments"])
|
||||
|
||||
assistant_content_str = ""
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
msg_content = messages[msg_i].get("content")
|
||||
if msg_content:
|
||||
if isinstance(msg_content, list):
|
||||
for m in msg_content:
|
||||
if m.get("type", "") == "text":
|
||||
assistant_content_str += m["text"]
|
||||
elif isinstance(msg_content, str):
|
||||
# Tool message content will always be a string
|
||||
assistant_content_str += msg_content
|
||||
|
||||
tool_calls = messages[msg_i].get("tool_calls")
|
||||
ollama_tool_calls = []
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
call_id: str = call["id"]
|
||||
function_name: str = call["function"]["name"]
|
||||
arguments = json.loads(call["function"]["arguments"])
|
||||
|
||||
ollama_tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if ollama_tool_calls:
|
||||
assistant_content_str += (
|
||||
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
|
||||
ollama_tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content_str:
|
||||
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
|
||||
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise litellm.BadRequestError(
|
||||
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
|
||||
model=model,
|
||||
llm_provider="ollama",
|
||||
if ollama_tool_calls:
|
||||
assistant_content_str += (
|
||||
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
|
||||
)
|
||||
# prompt = ""
|
||||
# images = []
|
||||
# for message in messages:
|
||||
# if isinstance(message["content"], str):
|
||||
# prompt += message["content"]
|
||||
# elif isinstance(message["content"], list):
|
||||
# # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
|
||||
# for element in message["content"]:
|
||||
# if isinstance(element, dict):
|
||||
# if element["type"] == "text":
|
||||
# prompt += element["text"]
|
||||
# elif element["type"] == "image_url":
|
||||
# base64_image = convert_to_ollama_image(
|
||||
# element["image_url"]["url"]
|
||||
# )
|
||||
# images.append(base64_image)
|
||||
|
||||
# if "tool_calls" in message:
|
||||
# tool_calls = []
|
||||
msg_i += 1
|
||||
|
||||
# for call in message["tool_calls"]:
|
||||
# call_id: str = call["id"]
|
||||
# function_name: str = call["function"]["name"]
|
||||
# arguments = json.loads(call["function"]["arguments"])
|
||||
if assistant_content_str:
|
||||
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
|
||||
|
||||
# tool_calls.append(
|
||||
# {
|
||||
# "id": call_id,
|
||||
# "type": "function",
|
||||
# "function": {"name": function_name, "arguments": arguments},
|
||||
# }
|
||||
# )
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise litellm.BadRequestError(
|
||||
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
|
||||
model=model,
|
||||
llm_provider="ollama",
|
||||
)
|
||||
|
||||
# prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
|
||||
response_dict: OllamaVisionModelObject = {
|
||||
"prompt": prompt,
|
||||
"images": images,
|
||||
}
|
||||
|
||||
# elif "tool_call_id" in message:
|
||||
# prompt += f"### User:\n{message['content']}\n\n"
|
||||
|
||||
return {"prompt": prompt, "images": images}
|
||||
|
||||
return prompt
|
||||
return response_dict
|
||||
|
||||
|
||||
def mistral_instruct_pt(messages):
|
||||
|
|
|
@ -13,6 +13,7 @@ from litellm.types.utils import (
|
|||
Function,
|
||||
FunctionCall,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
PromptTokensDetails,
|
||||
Usage,
|
||||
)
|
||||
|
@ -319,8 +320,12 @@ class ChunkProcessor:
|
|||
usage_chunk: Optional[Usage] = None
|
||||
if "usage" in chunk:
|
||||
usage_chunk = chunk["usage"]
|
||||
elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
|
||||
elif (
|
||||
isinstance(chunk, ModelResponse)
|
||||
or isinstance(chunk, ModelResponseStream)
|
||||
) and hasattr(chunk, "_hidden_params"):
|
||||
usage_chunk = chunk._hidden_params.get("usage", None)
|
||||
|
||||
if usage_chunk is not None:
|
||||
usage_chunk_dict = self._usage_chunk_calculation_helper(usage_chunk)
|
||||
if (
|
||||
|
|
|
@ -898,6 +898,8 @@ class CustomStreamWrapper:
|
|||
return model_response
|
||||
|
||||
# Default - return StopIteration
|
||||
if hasattr(model_response, "usage"):
|
||||
self.chunks.append(model_response)
|
||||
raise StopIteration
|
||||
# flush any remaining holding chunk
|
||||
if len(self.holding_chunk) > 0:
|
||||
|
@ -1470,6 +1472,24 @@ class CustomStreamWrapper:
|
|||
"""
|
||||
self.logging_loop = loop
|
||||
|
||||
def cache_streaming_response(self, processed_chunk, cache_hit: bool):
|
||||
"""
|
||||
Caches the streaming response
|
||||
"""
|
||||
if not cache_hit and self.logging_obj._llm_caching_handler is not None:
|
||||
self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
|
||||
processed_chunk
|
||||
)
|
||||
|
||||
async def async_cache_streaming_response(self, processed_chunk, cache_hit: bool):
|
||||
"""
|
||||
Caches the streaming response
|
||||
"""
|
||||
if not cache_hit and self.logging_obj._llm_caching_handler is not None:
|
||||
await self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
|
||||
processed_chunk
|
||||
)
|
||||
|
||||
def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
|
||||
"""
|
||||
Runs success logging in a thread and adds the response to the cache
|
||||
|
@ -1501,12 +1521,6 @@ class CustomStreamWrapper:
|
|||
## SYNC LOGGING
|
||||
self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
|
||||
|
||||
## Sync store in cache
|
||||
if self.logging_obj._llm_caching_handler is not None:
|
||||
self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
|
||||
processed_chunk
|
||||
)
|
||||
|
||||
def finish_reason_handler(self):
|
||||
model_response = self.model_response_creator()
|
||||
_finish_reason = self.received_finish_reason or self.intermittent_finish_reason
|
||||
|
@ -1553,10 +1567,11 @@ class CustomStreamWrapper:
|
|||
if response is None:
|
||||
continue
|
||||
## LOGGING
|
||||
threading.Thread(
|
||||
target=self.run_success_logging_and_cache_storage,
|
||||
args=(response, cache_hit),
|
||||
).start() # log response
|
||||
executor.submit(
|
||||
self.run_success_logging_and_cache_storage,
|
||||
response,
|
||||
cache_hit,
|
||||
) # log response
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, StreamingChoices):
|
||||
self.response_uptil_now += choice.delta.get("content", "") or ""
|
||||
|
@ -1600,13 +1615,27 @@ class CustomStreamWrapper:
|
|||
"usage",
|
||||
getattr(complete_streaming_response, "usage"),
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
threading.Thread(
|
||||
target=self.logging_obj.success_handler,
|
||||
args=(response, None, None, cache_hit),
|
||||
).start() # log response
|
||||
|
||||
self.cache_streaming_response(
|
||||
processed_chunk=complete_streaming_response.model_copy(
|
||||
deep=True
|
||||
),
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
executor.submit(
|
||||
self.logging_obj.success_handler,
|
||||
complete_streaming_response.model_copy(deep=True),
|
||||
None,
|
||||
None,
|
||||
cache_hit,
|
||||
)
|
||||
else:
|
||||
executor.submit(
|
||||
self.logging_obj.success_handler,
|
||||
response,
|
||||
None,
|
||||
None,
|
||||
cache_hit,
|
||||
)
|
||||
if self.sent_stream_usage is False and self.send_stream_usage is True:
|
||||
self.sent_stream_usage = True
|
||||
return response
|
||||
|
@ -1618,10 +1647,11 @@ class CustomStreamWrapper:
|
|||
usage = calculate_total_usage(chunks=self.chunks)
|
||||
processed_chunk._hidden_params["usage"] = usage
|
||||
## LOGGING
|
||||
threading.Thread(
|
||||
target=self.run_success_logging_and_cache_storage,
|
||||
args=(processed_chunk, cache_hit),
|
||||
).start() # log response
|
||||
executor.submit(
|
||||
self.run_success_logging_and_cache_storage,
|
||||
processed_chunk,
|
||||
cache_hit,
|
||||
) # log response
|
||||
return processed_chunk
|
||||
except Exception as e:
|
||||
traceback_exception = traceback.format_exc()
|
||||
|
@ -1690,13 +1720,6 @@ class CustomStreamWrapper:
|
|||
if processed_chunk is None:
|
||||
continue
|
||||
|
||||
if self.logging_obj._llm_caching_handler is not None:
|
||||
asyncio.create_task(
|
||||
self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
|
||||
processed_chunk=cast(ModelResponse, processed_chunk),
|
||||
)
|
||||
)
|
||||
|
||||
choice = processed_chunk.choices[0]
|
||||
if isinstance(choice, StreamingChoices):
|
||||
self.response_uptil_now += choice.delta.get("content", "") or ""
|
||||
|
@ -1767,6 +1790,14 @@ class CustomStreamWrapper:
|
|||
"usage",
|
||||
getattr(complete_streaming_response, "usage"),
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.async_cache_streaming_response(
|
||||
processed_chunk=complete_streaming_response.model_copy(
|
||||
deep=True
|
||||
),
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
)
|
||||
if self.sent_stream_usage is False and self.send_stream_usage is True:
|
||||
self.sent_stream_usage = True
|
||||
return response
|
||||
|
|
|
@ -29,6 +29,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
|
|||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Coroutine, Iterable, Literal, Optional, Union
|
||||
from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
@ -18,10 +18,10 @@ from ...types.llms.openai import (
|
|||
SyncCursorPage,
|
||||
Thread,
|
||||
)
|
||||
from ..base import BaseLLM
|
||||
from .common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureAssistantsAPI(BaseLLM):
|
||||
class AzureAssistantsAPI(BaseAzureLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -34,18 +34,18 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AzureOpenAI:
|
||||
received_args = locals()
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
azure_openai_client = AzureOpenAI(**data) # type: ignore
|
||||
azure_client_params = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_name="",
|
||||
api_version=api_version,
|
||||
is_async=False,
|
||||
)
|
||||
azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_openai_client = client
|
||||
|
||||
|
@ -60,18 +60,19 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AsyncAzureOpenAI:
|
||||
received_args = locals()
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
azure_openai_client = AsyncAzureOpenAI(**data)
|
||||
azure_client_params = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_name="",
|
||||
api_version=api_version,
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
azure_openai_client = client
|
||||
|
@ -89,6 +90,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AsyncCursorPage[Assistant]:
|
||||
azure_openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -98,6 +100,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await azure_openai_client.beta.assistants.list()
|
||||
|
@ -146,6 +149,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_assistants=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if aget_assistants is not None and aget_assistants is True:
|
||||
return self.async_get_assistants(
|
||||
|
@ -156,6 +160,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -165,6 +170,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries=max_retries,
|
||||
client=client,
|
||||
api_version=api_version,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = azure_openai_client.beta.assistants.list()
|
||||
|
@ -184,6 +190,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> OpenAIMessage:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -193,6 +200,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
|
||||
|
@ -222,6 +230,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
a_add_message: Literal[True],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Coroutine[None, None, OpenAIMessage]:
|
||||
...
|
||||
|
||||
|
@ -238,6 +247,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
a_add_message: Optional[Literal[False]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> OpenAIMessage:
|
||||
...
|
||||
|
||||
|
@ -255,6 +265,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
a_add_message: Optional[bool] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if a_add_message is not None and a_add_message is True:
|
||||
return self.a_add_message(
|
||||
|
@ -267,6 +278,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -300,6 +312,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AsyncCursorPage[OpenAIMessage]:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -309,6 +322,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
@ -329,6 +343,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
aget_messages: Literal[True],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
|
||||
...
|
||||
|
||||
|
@ -344,6 +359,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
aget_messages: Optional[Literal[False]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
...
|
||||
|
||||
|
@ -360,6 +376,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_messages=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if aget_messages is not None and aget_messages is True:
|
||||
return self.async_get_messages(
|
||||
|
@ -371,6 +388,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -380,6 +398,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
@ -399,6 +418,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Thread:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -408,6 +428,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
data = {}
|
||||
|
@ -435,6 +456,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
acreate_thread: Literal[True],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
|
@ -451,6 +473,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[AzureOpenAI],
|
||||
acreate_thread: Optional[Literal[False]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
|
@ -468,6 +491,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client=None,
|
||||
acreate_thread=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Here's an example:
|
||||
|
@ -490,6 +514,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries=max_retries,
|
||||
client=client,
|
||||
messages=messages,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -499,6 +524,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
data = {}
|
||||
|
@ -521,6 +547,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Thread:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -530,6 +557,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||
|
@ -550,6 +578,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
aget_thread: Literal[True],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
|
@ -565,6 +594,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
aget_thread: Optional[Literal[False]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
|
@ -581,6 +611,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_thread=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if aget_thread is not None and aget_thread is True:
|
||||
return self.async_get_thread(
|
||||
|
@ -592,6 +623,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -601,6 +633,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||
|
@ -618,7 +651,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
|
@ -629,6 +662,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Run:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -638,6 +672,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
|
@ -645,7 +680,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
metadata=metadata, # type: ignore
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
@ -659,12 +694,13 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
|
@ -684,12 +720,13 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
event_handler: Optional[AssistantEventHandler],
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> AssistantStreamManager[AssistantEventHandler]:
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"additional_instructions": additional_instructions,
|
||||
|
@ -711,7 +748,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
|
@ -733,7 +770,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
|
@ -756,7 +793,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
metadata: Optional[Dict],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
|
@ -769,6 +806,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
arun_thread=None,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread is True:
|
||||
if stream is not None and stream is True:
|
||||
|
@ -780,6 +818,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
return self.async_run_thread_stream(
|
||||
client=azure_client,
|
||||
|
@ -791,13 +830,14 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
metadata=metadata, # type: ignore
|
||||
model=model,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
|
@ -808,6 +848,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -817,6 +858,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
if stream is not None and stream is True:
|
||||
|
@ -830,6 +872,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
|
@ -837,7 +880,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
metadata=metadata, # type: ignore
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
@ -855,6 +898,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
create_assistant_data: dict,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Assistant:
|
||||
azure_openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -864,6 +908,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await azure_openai_client.beta.assistants.create(
|
||||
|
@ -882,6 +927,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
create_assistant_data: dict,
|
||||
client=None,
|
||||
async_create_assistants=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if async_create_assistants is not None and async_create_assistants is True:
|
||||
return self.async_create_assistants(
|
||||
|
@ -893,6 +939,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries=max_retries,
|
||||
client=client,
|
||||
create_assistant_data=create_assistant_data,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -902,6 +949,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
|
||||
|
@ -918,6 +966,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
assistant_id: str,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -927,6 +976,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = await azure_openai_client.beta.assistants.delete(
|
||||
|
@ -945,6 +995,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
assistant_id: str,
|
||||
async_delete_assistants: Optional[bool] = None,
|
||||
client=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
if async_delete_assistants is not None and async_delete_assistants is True:
|
||||
return self.async_delete_assistant(
|
||||
|
@ -956,6 +1007,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries=max_retries,
|
||||
client=client,
|
||||
assistant_id=assistant_id,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
|
@ -965,6 +1017,7 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||
from litellm.types.utils import FileTypes
|
||||
from litellm.utils import TranscriptionResponse, convert_to_model_response_object
|
||||
|
||||
from .azure import (
|
||||
AzureChatCompletion,
|
||||
get_azure_ad_token_from_oidc,
|
||||
select_azure_base_url_or_endpoint,
|
||||
from litellm.utils import (
|
||||
TranscriptionResponse,
|
||||
convert_to_model_response_object,
|
||||
extract_duration_from_srt_or_vtt,
|
||||
)
|
||||
|
||||
from .azure import AzureChatCompletion
|
||||
from .common_utils import AzureOpenAIError
|
||||
|
||||
|
||||
class AzureAudioTranscription(AzureChatCompletion):
|
||||
def audio_transcriptions(
|
||||
|
@ -32,32 +32,12 @@ class AzureAudioTranscription(AzureChatCompletion):
|
|||
client=None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
atranscription: bool = False,
|
||||
) -> TranscriptionResponse:
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
|
||||
data = {"model": model, "file": audio_file, **optional_params}
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if max_retries is not None:
|
||||
azure_client_params["max_retries"] = max_retries
|
||||
|
||||
if atranscription is True:
|
||||
return self.async_audio_transcriptions( # type: ignore
|
||||
return self.async_audio_transcriptions(
|
||||
audio_file=audio_file,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
|
@ -65,14 +45,26 @@ class AzureAudioTranscription(AzureChatCompletion):
|
|||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
azure_client_params=azure_client_params,
|
||||
max_retries=max_retries,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(http_client=litellm.client_session, **azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -109,25 +101,34 @@ class AzureAudioTranscription(AzureChatCompletion):
|
|||
async def async_audio_transcriptions(
|
||||
self,
|
||||
audio_file: FileTypes,
|
||||
model: str,
|
||||
data: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
azure_client_params: dict,
|
||||
logging_obj: Any,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
):
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> TranscriptionResponse:
|
||||
response = None
|
||||
try:
|
||||
if client is None:
|
||||
async_azure_client = AsyncAzureOpenAI(
|
||||
**azure_client_params,
|
||||
http_client=litellm.aclient_session,
|
||||
async_azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(async_azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="async_azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
else:
|
||||
async_azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -156,6 +157,8 @@ class AzureAudioTranscription(AzureChatCompletion):
|
|||
stringified_response = response.model_dump()
|
||||
else:
|
||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||
duration = extract_duration_from_srt_or_vtt(response)
|
||||
stringified_response["duration"] = duration
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -178,7 +181,12 @@ class AzureAudioTranscription(AzureChatCompletion):
|
|||
model_response_object=model_response,
|
||||
hidden_params=hidden_params,
|
||||
response_type="audio_transcription",
|
||||
) # type: ignore
|
||||
)
|
||||
if not isinstance(response, TranscriptionResponse):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="response is not an instance of TranscriptionResponse",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.constants import DEFAULT_MAX_RETRIES
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
|
@ -25,15 +24,18 @@ from litellm.types.utils import (
|
|||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
convert_to_model_response_object,
|
||||
get_secret,
|
||||
modify_url,
|
||||
)
|
||||
|
||||
from ...types.llms.openai import HttpxBinaryResponseContent
|
||||
from ..base import BaseLLM
|
||||
from .common_utils import AzureOpenAIError, process_azure_headers
|
||||
|
||||
azure_ad_cache = DualCache()
|
||||
from .common_utils import (
|
||||
AzureOpenAIError,
|
||||
BaseAzureLLM,
|
||||
get_azure_ad_token_from_oidc,
|
||||
process_azure_headers,
|
||||
select_azure_base_url_or_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIAssistantsAPIConfig:
|
||||
|
@ -98,93 +100,6 @@ class AzureOpenAIAssistantsAPIConfig:
|
|||
return optional_params
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
azure_endpoint = azure_client_params.get("azure_endpoint", None)
|
||||
if azure_endpoint is not None:
|
||||
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
|
||||
if "/openai/deployments" in azure_endpoint:
|
||||
# this is base_url, not an azure_endpoint
|
||||
azure_client_params["base_url"] = azure_endpoint
|
||||
azure_client_params.pop("azure_endpoint")
|
||||
|
||||
return azure_client_params
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||
azure_authority_host = os.getenv(
|
||||
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
|
||||
)
|
||||
|
||||
if azure_client_id is None or azure_tenant_id is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422,
|
||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||
)
|
||||
|
||||
oidc_token = get_secret(azure_ad_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=401,
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
)
|
||||
|
||||
azure_ad_token_cache_key = json.dumps(
|
||||
{
|
||||
"azure_client_id": azure_client_id,
|
||||
"azure_tenant_id": azure_tenant_id,
|
||||
"azure_authority_host": azure_authority_host,
|
||||
"oidc_token": oidc_token,
|
||||
}
|
||||
)
|
||||
|
||||
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
||||
if azure_ad_token_access_token is not None:
|
||||
return azure_ad_token_access_token
|
||||
|
||||
client = litellm.module_level_client
|
||||
req_token = client.post(
|
||||
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
|
||||
data={
|
||||
"client_id": azure_client_id,
|
||||
"grant_type": "client_credentials",
|
||||
"scope": "https://cognitiveservices.azure.com/.default",
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": oidc_token,
|
||||
},
|
||||
)
|
||||
|
||||
if req_token.status_code != 200:
|
||||
raise AzureOpenAIError(
|
||||
status_code=req_token.status_code,
|
||||
message=req_token.text,
|
||||
)
|
||||
|
||||
azure_ad_token_json = req_token.json()
|
||||
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
|
||||
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
|
||||
|
||||
if azure_ad_token_access_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token access_token not returned"
|
||||
)
|
||||
|
||||
if azure_ad_token_expires_in is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token expires_in not returned"
|
||||
)
|
||||
|
||||
azure_ad_cache.set_cache(
|
||||
key=azure_ad_token_cache_key,
|
||||
value=azure_ad_token_access_token,
|
||||
ttl=azure_ad_token_expires_in,
|
||||
)
|
||||
|
||||
return azure_ad_token_access_token
|
||||
|
||||
|
||||
def _check_dynamic_azure_params(
|
||||
azure_client_params: dict,
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
|
||||
|
@ -206,7 +121,7 @@ def _check_dynamic_azure_params(
|
|||
return False
|
||||
|
||||
|
||||
class AzureChatCompletion(BaseLLM):
|
||||
class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -226,52 +141,6 @@ class AzureChatCompletion(BaseLLM):
|
|||
|
||||
return headers
|
||||
|
||||
def _get_sync_azure_client(
|
||||
self,
|
||||
api_version: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
azure_ad_token_provider: Optional[Callable],
|
||||
model: str,
|
||||
max_retries: int,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
client: Optional[Any],
|
||||
client_type: Literal["sync", "async"],
|
||||
):
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params: Dict[str, Any] = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
if client is None:
|
||||
if client_type == "sync":
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
elif client_type == "async":
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(azure_client._custom_query, dict):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
|
||||
return azure_client
|
||||
|
||||
def make_sync_azure_openai_chat_completion_request(
|
||||
self,
|
||||
azure_client: AzureOpenAI,
|
||||
|
@ -294,11 +163,13 @@ class AzureChatCompletion(BaseLLM):
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@track_llm_api_timing()
|
||||
async def make_azure_openai_chat_completion_request(
|
||||
self,
|
||||
azure_client: AsyncAzureOpenAI,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
):
|
||||
"""
|
||||
Helper to:
|
||||
|
@ -360,37 +231,18 @@ class AzureChatCompletion(BaseLLM):
|
|||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||
### if so - set the model as part of the base url
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
if client is None:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"base_url": f"{api_base}",
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(
|
||||
azure_ad_token
|
||||
)
|
||||
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = (
|
||||
azure_ad_token_provider
|
||||
)
|
||||
|
||||
if acompletion is True:
|
||||
client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params)
|
||||
client = self._init_azure_client_for_cloudflare_ai_gateway(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
|
||||
data = {"model": None, "messages": messages, **optional_params}
|
||||
else:
|
||||
|
@ -417,6 +269,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(
|
||||
|
@ -434,6 +287,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
logging_obj=logging_obj,
|
||||
max_retries=max_retries,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
|
@ -449,6 +303,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -470,43 +325,15 @@ class AzureChatCompletion(BaseLLM):
|
|||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
client=client,
|
||||
_is_async=False,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = (
|
||||
azure_ad_token_provider
|
||||
)
|
||||
|
||||
if (
|
||||
client is None
|
||||
or not isinstance(client, AzureOpenAI)
|
||||
or dynamic_params
|
||||
):
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault(
|
||||
"api-version", api_version
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
|
@ -566,36 +393,22 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
convert_tool_call_to_json_mode: Optional[bool] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
litellm_params: Optional[dict] = {},
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.aclient_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
# setting Azure client
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
client=client,
|
||||
_is_async=True,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
@ -615,6 +428,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client=azure_client,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
logging_obj.model_call_details["response_headers"] = headers
|
||||
|
||||
|
@ -680,6 +494,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
client=None,
|
||||
litellm_params: Optional[dict] = {},
|
||||
):
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
|
@ -702,10 +517,20 @@ class AzureChatCompletion(BaseLLM):
|
|||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
client=client,
|
||||
_is_async=False,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
@ -747,32 +572,21 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
client=None,
|
||||
litellm_params: Optional[dict] = {},
|
||||
):
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.aclient_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
client=client,
|
||||
_is_async=True,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
@ -792,6 +606,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client=azure_client,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
logging_obj.model_call_details["response_headers"] = headers
|
||||
|
||||
|
@ -822,21 +637,36 @@ class AzureChatCompletion(BaseLLM):
|
|||
|
||||
async def aembedding(
|
||||
self,
|
||||
model: str,
|
||||
data: dict,
|
||||
model_response: EmbeddingResponse,
|
||||
azure_client_params: dict,
|
||||
input: list,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
timeout=None,
|
||||
):
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
litellm_params: Optional[dict] = {},
|
||||
) -> EmbeddingResponse:
|
||||
response = None
|
||||
try:
|
||||
if client is None:
|
||||
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
openai_aclient = client
|
||||
|
||||
openai_aclient = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(openai_aclient, AsyncAzureOpenAI):
|
||||
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
|
||||
|
||||
raw_response = await openai_aclient.embeddings.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
|
@ -850,13 +680,19 @@ class AzureChatCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(
|
||||
embedding_response = convert_to_model_response_object(
|
||||
response_object=stringified_response,
|
||||
model_response_object=model_response,
|
||||
hidden_params={"headers": headers},
|
||||
_response_headers=process_azure_headers(headers),
|
||||
response_type="embedding",
|
||||
)
|
||||
if not isinstance(embedding_response, EmbeddingResponse):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="embedding_response is not an instance of EmbeddingResponse",
|
||||
)
|
||||
return embedding_response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -884,7 +720,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
client=None,
|
||||
aembedding=None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> EmbeddingResponse:
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]:
|
||||
if headers:
|
||||
optional_params["extra_headers"] = headers
|
||||
if self._client_session is None:
|
||||
|
@ -893,35 +730,6 @@ class AzureChatCompletion(BaseLLM):
|
|||
data = {"model": model, "input": input, **optional_params}
|
||||
if max_retries is None:
|
||||
max_retries = litellm.DEFAULT_MAX_RETRIES
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
if aembedding:
|
||||
azure_client_params["http_client"] = litellm.aclient_session
|
||||
else:
|
||||
azure_client_params["http_client"] = litellm.client_session
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -933,20 +741,33 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
if aembedding is True:
|
||||
return self.aembedding( # type: ignore
|
||||
return self.aembedding(
|
||||
data=data,
|
||||
input=input,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
model_response=model_response,
|
||||
azure_client_params=azure_client_params,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
|
||||
headers = dict(raw_response.headers)
|
||||
|
@ -1281,6 +1102,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
client=None,
|
||||
aimg_generation=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> ImageResponse:
|
||||
try:
|
||||
if model and len(model) > 0:
|
||||
|
@ -1305,25 +1127,14 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params: Dict[str, Any] = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
model_name=model or "",
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
is_async=False,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
elif azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
if aimg_generation is True:
|
||||
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
|
||||
|
||||
|
@ -1386,6 +1197,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
aspeech: Optional[bool] = None,
|
||||
client=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
|
@ -1404,19 +1216,17 @@ class AzureChatCompletion(BaseLLM):
|
|||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
) # type: ignore
|
||||
|
||||
azure_client: AzureOpenAI = self._get_sync_azure_client(
|
||||
azure_client: AzureOpenAI = self.get_azure_openai_client(
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
model=model,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
client_type="sync",
|
||||
litellm_params=litellm_params,
|
||||
) # type: ignore
|
||||
|
||||
response = azure_client.audio.speech.create(
|
||||
|
@ -1441,19 +1251,17 @@ class AzureChatCompletion(BaseLLM):
|
|||
max_retries: int,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
client=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
|
||||
azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
|
||||
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
model=model,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
client_type="async",
|
||||
litellm_params=litellm_params,
|
||||
) # type: ignore
|
||||
|
||||
azure_response = await azure_client.audio.speech.create(
|
||||
|
|
|
@ -6,7 +6,6 @@ from typing import Any, Coroutine, Optional, Union, cast
|
|||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
|
@ -16,8 +15,10 @@ from litellm.types.llms.openai import (
|
|||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
class AzureBatchesAPI:
|
||||
|
||||
class AzureBatchesAPI(BaseAzureLLM):
|
||||
"""
|
||||
Azure methods to support for batches
|
||||
- create_batch()
|
||||
|
@ -29,38 +30,6 @@ class AzureBatchesAPI:
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_azure_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if "api_version" not in data:
|
||||
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**data)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
async def acreate_batch(
|
||||
self,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
|
@ -79,16 +48,16 @@ class AzureBatchesAPI:
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
)
|
||||
if azure_client is None:
|
||||
|
@ -125,16 +94,16 @@ class AzureBatchesAPI:
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
)
|
||||
if azure_client is None:
|
||||
|
@ -173,16 +142,16 @@ class AzureBatchesAPI:
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
)
|
||||
if azure_client is None:
|
||||
|
@ -212,16 +181,16 @@ class AzureBatchesAPI:
|
|||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
)
|
||||
if azure_client is None:
|
||||
|
|
|
@ -99,6 +99,8 @@ class AzureOpenAIConfig(BaseConfig):
|
|||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"prediction",
|
||||
"modalities",
|
||||
"audio",
|
||||
]
|
||||
|
||||
def _is_response_format_supported_model(self, model: str) -> bool:
|
||||
|
|
|
@ -4,50 +4,69 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
|
|||
Written separately to handle faking streaming for o1 and o3 models.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ...openai.openai import OpenAIChatCompletion
|
||||
from ..common_utils import get_azure_openai_client
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
|
||||
def _get_openai_client(
|
||||
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
|
||||
def completion(
|
||||
self,
|
||||
is_async: bool,
|
||||
model_response: ModelResponse,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
logging_obj: Any,
|
||||
model: Optional[str] = None,
|
||||
messages: Optional[list] = None,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||
max_retries: Optional[int] = 2,
|
||||
dynamic_params: Optional[bool] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
acompletion: bool = False,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
custom_prompt_dict: dict = {},
|
||||
client=None,
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
) -> Optional[
|
||||
Union[
|
||||
OpenAI,
|
||||
AsyncOpenAI,
|
||||
AzureOpenAI,
|
||||
AsyncAzureOpenAI,
|
||||
]
|
||||
]:
|
||||
|
||||
# Override to use Azure-specific client initialization
|
||||
if not isinstance(client, AzureOpenAI) and not isinstance(
|
||||
client, AsyncAzureOpenAI
|
||||
):
|
||||
client = None
|
||||
|
||||
return get_azure_openai_client(
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
drop_params: Optional[bool] = None,
|
||||
):
|
||||
client = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=is_async,
|
||||
_is_async=acompletion,
|
||||
)
|
||||
return super().completion(
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
dynamic_params=dynamic_params,
|
||||
azure_ad_token=azure_ad_token,
|
||||
acompletion=acompletion,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client,
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
|
|
@ -1,13 +1,22 @@
|
|||
from typing import Callable, Optional, Union
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.common_utils import BaseOpenAILLM
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
azure_ad_cache = DualCache()
|
||||
|
||||
|
||||
class AzureOpenAIError(BaseLLMException):
|
||||
def __init__(
|
||||
|
@ -29,39 +38,6 @@ class AzureOpenAIError(BaseLLMException):
|
|||
)
|
||||
|
||||
|
||||
def get_azure_openai_client(
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if "api_version" not in data:
|
||||
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**data)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
|
||||
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
|
||||
openai_headers = {}
|
||||
if "x-ratelimit-limit-requests" in headers:
|
||||
|
@ -180,3 +156,271 @@ def get_azure_ad_token_from_username_password(
|
|||
verbose_logger.debug("token_provider %s", token_provider)
|
||||
|
||||
return token_provider
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||
azure_authority_host = os.getenv(
|
||||
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
|
||||
)
|
||||
|
||||
if azure_client_id is None or azure_tenant_id is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422,
|
||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||
)
|
||||
|
||||
oidc_token = get_secret_str(azure_ad_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=401,
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
)
|
||||
|
||||
azure_ad_token_cache_key = json.dumps(
|
||||
{
|
||||
"azure_client_id": azure_client_id,
|
||||
"azure_tenant_id": azure_tenant_id,
|
||||
"azure_authority_host": azure_authority_host,
|
||||
"oidc_token": oidc_token,
|
||||
}
|
||||
)
|
||||
|
||||
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
||||
if azure_ad_token_access_token is not None:
|
||||
return azure_ad_token_access_token
|
||||
|
||||
client = litellm.module_level_client
|
||||
req_token = client.post(
|
||||
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
|
||||
data={
|
||||
"client_id": azure_client_id,
|
||||
"grant_type": "client_credentials",
|
||||
"scope": "https://cognitiveservices.azure.com/.default",
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": oidc_token,
|
||||
},
|
||||
)
|
||||
|
||||
if req_token.status_code != 200:
|
||||
raise AzureOpenAIError(
|
||||
status_code=req_token.status_code,
|
||||
message=req_token.text,
|
||||
)
|
||||
|
||||
azure_ad_token_json = req_token.json()
|
||||
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
|
||||
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
|
||||
|
||||
if azure_ad_token_access_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token access_token not returned"
|
||||
)
|
||||
|
||||
if azure_ad_token_expires_in is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token expires_in not returned"
|
||||
)
|
||||
|
||||
azure_ad_cache.set_cache(
|
||||
key=azure_ad_token_cache_key,
|
||||
value=azure_ad_token_access_token,
|
||||
ttl=azure_ad_token_expires_in,
|
||||
)
|
||||
|
||||
return azure_ad_token_access_token
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
azure_endpoint = azure_client_params.get("azure_endpoint", None)
|
||||
if azure_endpoint is not None:
|
||||
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
|
||||
if "/openai/deployments" in azure_endpoint:
|
||||
# this is base_url, not an azure_endpoint
|
||||
azure_client_params["base_url"] = azure_endpoint
|
||||
azure_client_params.pop("azure_endpoint")
|
||||
|
||||
return azure_client_params
|
||||
|
||||
|
||||
class BaseAzureLLM(BaseOpenAILLM):
|
||||
def get_azure_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
_is_async: bool = False,
|
||||
model: Optional[str] = None,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
client_initialization_params: dict = locals()
|
||||
if client is None:
|
||||
cached_client = self.get_cached_openai_client(
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type="azure",
|
||||
)
|
||||
if cached_client:
|
||||
if isinstance(cached_client, AzureOpenAI) or isinstance(
|
||||
cached_client, AsyncAzureOpenAI
|
||||
):
|
||||
return cached_client
|
||||
|
||||
azure_client_params = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_name=model,
|
||||
api_version=api_version,
|
||||
is_async=_is_async,
|
||||
)
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
if api_version is not None and isinstance(
|
||||
openai_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
openai_client._custom_query.setdefault("api-version", api_version)
|
||||
|
||||
# save client in-memory cache
|
||||
self.set_cached_openai_client(
|
||||
openai_client=openai_client,
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type="azure",
|
||||
)
|
||||
return openai_client
|
||||
|
||||
def initialize_azure_sdk_client(
|
||||
self,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_name: Optional[str],
|
||||
api_version: Optional[str],
|
||||
is_async: bool,
|
||||
) -> dict:
|
||||
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
# If we have api_key, then we have higher priority
|
||||
azure_ad_token = litellm_params.get("azure_ad_token")
|
||||
tenant_id = litellm_params.get("tenant_id")
|
||||
client_id = litellm_params.get("client_id")
|
||||
client_secret = litellm_params.get("client_secret")
|
||||
azure_username = litellm_params.get("azure_username")
|
||||
azure_password = litellm_params.get("azure_password")
|
||||
max_retries = litellm_params.get("max_retries")
|
||||
timeout = litellm_params.get("timeout")
|
||||
if not api_key and tenant_id and client_id and client_secret:
|
||||
verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
if azure_username and azure_password and client_id:
|
||||
azure_ad_token_provider = get_azure_ad_token_from_username_password(
|
||||
azure_username=azure_username,
|
||||
azure_password=azure_password,
|
||||
client_id=client_id,
|
||||
)
|
||||
|
||||
if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
elif (
|
||||
not api_key
|
||||
and azure_ad_token_provider is None
|
||||
and litellm.enable_azure_ad_token_refresh is True
|
||||
):
|
||||
try:
|
||||
azure_ad_token_provider = get_azure_ad_token_provider()
|
||||
except ValueError:
|
||||
verbose_logger.debug("Azure AD Token Provider could not be used.")
|
||||
if api_version is None:
|
||||
api_version = os.getenv(
|
||||
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
|
||||
)
|
||||
|
||||
_api_key = api_key
|
||||
if _api_key is not None and isinstance(_api_key, str):
|
||||
# only show first 5 chars of api_key
|
||||
_api_key = _api_key[:8] + "*" * 15
|
||||
verbose_logger.debug(
|
||||
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
|
||||
)
|
||||
azure_client_params = {
|
||||
"api_key": api_key,
|
||||
"azure_endpoint": api_base,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
# init http client + SSL Verification settings
|
||||
if is_async is True:
|
||||
azure_client_params["http_client"] = self._get_async_http_client()
|
||||
else:
|
||||
azure_client_params["http_client"] = self._get_sync_http_client()
|
||||
|
||||
if max_retries is not None:
|
||||
azure_client_params["max_retries"] = max_retries
|
||||
if timeout is not None:
|
||||
azure_client_params["timeout"] = timeout
|
||||
|
||||
if azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
|
||||
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
|
||||
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
|
||||
return azure_client_params
|
||||
|
||||
def _init_azure_client_for_cloudflare_ai_gateway(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
api_version: str,
|
||||
max_retries: int,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
api_key: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
azure_ad_token_provider: Optional[Callable[[], str]],
|
||||
acompletion: bool,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
|
||||
## build base url - assume api base includes resource name
|
||||
if client is None:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
|
||||
azure_client_params: Dict[str, Any] = {
|
||||
"api_version": api_version,
|
||||
"base_url": f"{api_base}",
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
if acompletion is True:
|
||||
client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
return client
|
||||
|
|
|
@ -2,30 +2,16 @@ from typing import Any, Callable, Optional
|
|||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
|
||||
|
||||
from ...base import BaseLLM
|
||||
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||
from ..common_utils import AzureOpenAIError
|
||||
from ..common_utils import AzureOpenAIError, BaseAzureLLM
|
||||
|
||||
openai_text_completion_config = OpenAITextCompletionConfig()
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
azure_endpoint = azure_client_params.get("azure_endpoint", None)
|
||||
if azure_endpoint is not None:
|
||||
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
|
||||
if "/openai/deployments" in azure_endpoint:
|
||||
# this is base_url, not an azure_endpoint
|
||||
azure_client_params["base_url"] = azure_endpoint
|
||||
azure_client_params.pop("azure_endpoint")
|
||||
|
||||
return azure_client_params
|
||||
|
||||
|
||||
class AzureTextCompletion(BaseLLM):
|
||||
class AzureTextCompletion(BaseAzureLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -60,7 +46,6 @@ class AzureTextCompletion(BaseLLM):
|
|||
headers: Optional[dict] = None,
|
||||
client=None,
|
||||
):
|
||||
super().completion()
|
||||
try:
|
||||
if model is None or messages is None:
|
||||
raise AzureOpenAIError(
|
||||
|
@ -76,27 +61,18 @@ class AzureTextCompletion(BaseLLM):
|
|||
### if so - set the model as part of the base url
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
if client is None:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"base_url": f"{api_base}",
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if acompletion is True:
|
||||
client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params)
|
||||
client = self._init_azure_client_for_cloudflare_ai_gateway(
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
|
||||
data = {"model": None, "prompt": prompt, **optional_params}
|
||||
else:
|
||||
|
@ -118,6 +94,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(
|
||||
|
@ -132,6 +109,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
max_retries=max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
|
@ -165,33 +143,21 @@ class AzureTextCompletion(BaseLLM):
|
|||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
_is_async=False,
|
||||
model=model,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault(
|
||||
"api-version", api_version
|
||||
)
|
||||
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
raw_response = azure_client.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
|
@ -240,36 +206,27 @@ class AzureTextCompletion(BaseLLM):
|
|||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
# setting Azure client
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
|
@ -312,6 +269,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
|
@ -319,28 +277,21 @@ class AzureTextCompletion(BaseLLM):
|
|||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(azure_client._custom_query, dict):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
|
@ -375,33 +326,24 @@ class AzureTextCompletion(BaseLLM):
|
|||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": data.pop("max_retries", 2),
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
|
|
|
@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
|
|||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.types.llms.openai import *
|
||||
|
||||
from ..common_utils import get_azure_openai_client
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureOpenAIFilesAPI(BaseLLM):
|
||||
class AzureOpenAIFilesAPI(BaseAzureLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support for batches
|
||||
- create_file()
|
||||
|
@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
|
||||
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
|
@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[
|
||||
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
|
||||
]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
organization=None,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
|
@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=None,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
|
@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
organization: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
|
@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
purpose: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=None, # openai param
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
|
|
|
@ -3,11 +3,11 @@ from typing import Optional, Union
|
|||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from litellm.llms.azure.files.handler import get_azure_openai_client
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
|
||||
|
||||
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
|
||||
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
|
||||
"""
|
||||
|
@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
|
|||
] = None,
|
||||
_is_async: bool = False,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Optional[
|
||||
Union[
|
||||
OpenAI,
|
||||
|
@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
|
|||
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
|
||||
client = None
|
||||
|
||||
return get_azure_openai_client(
|
||||
return self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
|
|
|
@ -16,10 +16,23 @@ from litellm.llms.openai.openai import OpenAIConfig
|
|||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, ProviderField
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
from litellm.utils import _add_path_to_api_base, supports_tool_choice
|
||||
|
||||
|
||||
class AzureAIStudioConfig(OpenAIConfig):
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
model_supports_tool_choice = True # azure ai supports this by default
|
||||
if not supports_tool_choice(model=f"azure_ai/{model}"):
|
||||
model_supports_tool_choice = False
|
||||
supported_params = super().get_supported_openai_params(model)
|
||||
if not model_supports_tool_choice:
|
||||
filtered_supported_params = []
|
||||
for param in supported_params:
|
||||
if param != "tool_choice":
|
||||
filtered_supported_params.append(param)
|
||||
return filtered_supported_params
|
||||
return supported_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
|
@ -54,6 +67,7 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
@ -79,12 +93,14 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], optional_params.get("api_version"))
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Check if 'api-version' is already present
|
||||
if "api-version" not in original_url.params and api_version:
|
||||
# Add api_version to optional_params
|
||||
original_url.params["api-version"] = api_version
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "services.ai.azure.com" in api_base:
|
||||
|
@ -96,8 +112,7 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
api_base=api_base, ending_path="/chat/completions"
|
||||
)
|
||||
|
||||
# Convert optional_params to query parameters
|
||||
query_params = original_url.params
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
|
|