forked from phoenix/litellm-mirror

commit 7973665219: Merge branch 'main' into litellm_fix_track_cost_callback

250 changed files with 17,396 additions and 4,808 deletions
|
@ -47,7 +47,7 @@ jobs:
|
|||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.52.0
|
||||
pip install openai==1.54.0
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
|
@ -103,7 +103,7 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "not test_python_38.py and not router and not assistants"
|
||||
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "not test_python_38.py and not router and not assistants and not langfuse and not caching and not cache"
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
|
@ -119,6 +119,204 @@ jobs:
|
|||
paths:
|
||||
- local_testing_coverage.xml
|
||||
- local_testing_coverage
|
||||
langfuse_logging_unit_tests:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Show git commit hash
|
||||
command: |
|
||||
echo "Git commit hash: $CIRCLE_SHA1"
|
||||
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r .circleci/requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install mypy
|
||||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install pyarrow
|
||||
pip install "boto3==1.34.34"
|
||||
pip install "aioboto3==12.3.0"
|
||||
pip install langchain
|
||||
pip install lunary==0.2.5
|
||||
pip install "azure-identity==1.16.1"
|
||||
pip install "langfuse==2.45.0"
|
||||
pip install "logfire==0.29.0"
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.21.1
|
||||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
pip install "respx==0.21.1"
|
||||
pip install fastapi
|
||||
pip install "gunicorn==21.2.0"
|
||||
pip install "anyio==4.2.0"
|
||||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "apscheduler==3.10.4"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install argon2-cffi
|
||||
pip install "pytest-mock==3.12.0"
|
||||
pip install python-multipart
|
||||
pip install google-cloud-aiplatform
|
||||
pip install prometheus-client==0.20.0
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "diskcache==5.6.1"
|
||||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Run prisma ./docker/entrypoint.sh
|
||||
command: |
|
||||
set +e
|
||||
chmod +x docker/entrypoint.sh
|
||||
./docker/entrypoint.sh
|
||||
set -e
|
||||
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "langfuse"
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml langfuse_coverage.xml
|
||||
mv .coverage langfuse_coverage
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- langfuse_coverage.xml
|
||||
- langfuse_coverage
|
||||
caching_unit_tests:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Show git commit hash
|
||||
command: |
|
||||
echo "Git commit hash: $CIRCLE_SHA1"
|
||||
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r .circleci/requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install mypy
|
||||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install pyarrow
|
||||
pip install "boto3==1.34.34"
|
||||
pip install "aioboto3==12.3.0"
|
||||
pip install langchain
|
||||
pip install lunary==0.2.5
|
||||
pip install "azure-identity==1.16.1"
|
||||
pip install "langfuse==2.45.0"
|
||||
pip install "logfire==0.29.0"
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.21.1
|
||||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
pip install "respx==0.21.1"
|
||||
pip install fastapi
|
||||
pip install "gunicorn==21.2.0"
|
||||
pip install "anyio==4.2.0"
|
||||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "apscheduler==3.10.4"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install argon2-cffi
|
||||
pip install "pytest-mock==3.12.0"
|
||||
pip install python-multipart
|
||||
pip install google-cloud-aiplatform
|
||||
pip install prometheus-client==0.20.0
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "diskcache==5.6.1"
|
||||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Run prisma ./docker/entrypoint.sh
|
||||
command: |
|
||||
set +e
|
||||
chmod +x docker/entrypoint.sh
|
||||
./docker/entrypoint.sh
|
||||
set -e
|
||||
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/local_testing --cov=litellm --cov-report=xml -x --junitxml=test-results/junit.xml --durations=5 -k "caching or cache"
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml caching_coverage.xml
|
||||
mv .coverage caching_coverage
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- caching_coverage.xml
|
||||
- caching_coverage
|
||||
auth_ui_unit_tests:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
|
@ -215,6 +413,105 @@ jobs:
|
|||
paths:
|
||||
- litellm_router_coverage.xml
|
||||
- litellm_router_coverage
|
||||
litellm_proxy_unit_testing: # Runs all tests with the "proxy", "key", "jwt" filenames
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Show git commit hash
|
||||
command: |
|
||||
echo "Git commit hash: $CIRCLE_SHA1"
|
||||
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r .circleci/requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install mypy
|
||||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install pyarrow
|
||||
pip install "boto3==1.34.34"
|
||||
pip install "aioboto3==12.3.0"
|
||||
pip install langchain
|
||||
pip install lunary==0.2.5
|
||||
pip install "azure-identity==1.16.1"
|
||||
pip install "langfuse==2.45.0"
|
||||
pip install "logfire==0.29.0"
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.21.1
|
||||
pip install opentelemetry-api==1.25.0
|
||||
pip install opentelemetry-sdk==1.25.0
|
||||
pip install opentelemetry-exporter-otlp==1.25.0
|
||||
pip install openai==1.54.0
|
||||
pip install prisma==0.11.0
|
||||
pip install "detect_secrets==1.5.0"
|
||||
pip install "httpx==0.24.1"
|
||||
pip install "respx==0.21.1"
|
||||
pip install fastapi
|
||||
pip install "gunicorn==21.2.0"
|
||||
pip install "anyio==4.2.0"
|
||||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "apscheduler==3.10.4"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install argon2-cffi
|
||||
pip install "pytest-mock==3.12.0"
|
||||
pip install python-multipart
|
||||
pip install google-cloud-aiplatform
|
||||
pip install prometheus-client==0.20.0
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "diskcache==5.6.1"
|
||||
pip install "Pillow==10.3.0"
|
||||
pip install "jsonschema==4.22.0"
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
key: v1-dependencies-{{ checksum ".circleci/requirements.txt" }}
|
||||
- run:
|
||||
name: Run prisma ./docker/entrypoint.sh
|
||||
command: |
|
||||
set +e
|
||||
chmod +x docker/entrypoint.sh
|
||||
./docker/entrypoint.sh
|
||||
set -e
|
||||
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest tests/proxy_unit_tests --cov=litellm --cov-report=xml -vv -x -v --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml litellm_proxy_unit_tests_coverage.xml
|
||||
mv .coverage litellm_proxy_unit_tests_coverage
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- litellm_proxy_unit_tests_coverage.xml
|
||||
- litellm_proxy_unit_tests_coverage
|
||||
litellm_assistants_api_testing: # Runs all tests with the "assistants" keyword
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
|
@ -328,7 +625,7 @@ jobs:
|
|||
paths:
|
||||
- llm_translation_coverage.xml
|
||||
- llm_translation_coverage
|
||||
logging_testing:
|
||||
image_gen_testing:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
|
@ -349,6 +646,51 @@ jobs:
|
|||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "respx==0.21.1"
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/image_gen_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml image_gen_coverage.xml
|
||||
mv .coverage image_gen_coverage
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- image_gen_coverage.xml
|
||||
- image_gen_coverage
|
||||
logging_testing:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install pytest-mock
|
||||
pip install "respx==0.21.1"
|
||||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
|
@ -392,7 +734,7 @@ jobs:
|
|||
pip install click
|
||||
pip install "boto3==1.34.34"
|
||||
pip install jinja2
|
||||
pip install tokenizers
|
||||
pip install tokenizers=="0.20.0"
|
||||
pip install jsonschema
|
||||
- run:
|
||||
name: Run tests
|
||||
|
@ -425,6 +767,7 @@ jobs:
|
|||
- run: python ./tests/documentation_tests/test_general_setting_keys.py
|
||||
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
||||
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
||||
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
|
||||
- run: python ./tests/documentation_tests/test_env_keys.py
|
||||
- run: helm lint ./deploy/charts/litellm-helm
|
||||
|
||||
|
@ -520,7 +863,7 @@ jobs:
|
|||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install "openai==1.52.0"
|
||||
pip install "openai==1.54.0 "
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Build Docker image
|
||||
|
@ -577,7 +920,7 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/image_gen_tests
|
||||
no_output_timeout: 120m
|
||||
|
||||
# Store test results
|
||||
|
@ -637,7 +980,7 @@ jobs:
|
|||
pip install "aiodynamo==23.10.1"
|
||||
pip install "asyncio==3.4.3"
|
||||
pip install "PyGithub==1.59.1"
|
||||
pip install "openai==1.52.0"
|
||||
pip install "openai==1.54.0 "
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||
|
@ -664,6 +1007,7 @@ jobs:
|
|||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
|
||||
-e COHERE_API_KEY=$COHERE_API_KEY \
|
||||
-e GCS_FLUSH_INTERVAL="1" \
|
||||
--name my-app \
|
||||
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
|
||||
-v $(pwd)/litellm/proxy/example_config_yaml/custom_guardrail.py:/app/custom_guardrail.py \
|
||||
|
@ -729,7 +1073,7 @@ jobs:
|
|||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install aiohttp
|
||||
pip install "openai==1.52.0"
|
||||
pip install "openai==1.54.0 "
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "pytest==7.3.1"
|
||||
|
@ -814,7 +1158,7 @@ jobs:
|
|||
python -m venv venv
|
||||
. venv/bin/activate
|
||||
pip install coverage
|
||||
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage
|
||||
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage
|
||||
coverage xml
|
||||
- codecov/upload:
|
||||
file: ./coverage.xml
|
||||
|
@ -924,7 +1268,7 @@ jobs:
|
|||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install aiohttp
|
||||
pip install "openai==1.52.0"
|
||||
pip install "openai==1.54.0 "
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "pytest==7.3.1"
|
||||
|
@ -986,6 +1330,41 @@ jobs:
|
|||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
test_bad_database_url:
|
||||
machine:
|
||||
image: ubuntu-2204:2023.10.1
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: |
|
||||
docker build -t myapp . -f ./docker/Dockerfile.non_root
|
||||
- run:
|
||||
name: Run Docker container with bad DATABASE_URL
|
||||
command: |
|
||||
docker run --name my-app \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL="postgresql://wrong:wrong@wrong:5432/wrong" \
|
||||
myapp:latest \
|
||||
--port 4000 > docker_output.log 2>&1 || true
|
||||
- run:
|
||||
name: Display Docker logs
|
||||
command: cat docker_output.log
|
||||
- run:
|
||||
name: Check for expected error
|
||||
command: |
|
||||
if grep -q "Error: P1001: Can't reach database server at" docker_output.log && \
|
||||
grep -q "httpx.ConnectError: All connection attempts failed" docker_output.log && \
|
||||
grep -q "ERROR: Application startup failed. Exiting." docker_output.log; then
|
||||
echo "Expected error found. Test passed."
|
||||
else
|
||||
echo "Expected error not found. Test failed."
|
||||
cat docker_output.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build_and_test:
|
||||
|
@ -996,6 +1375,24 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- langfuse_logging_unit_tests:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- caching_unit_tests:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- litellm_proxy_unit_testing:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- litellm_assistants_api_testing:
|
||||
filters:
|
||||
branches:
|
||||
|
@ -1050,6 +1447,12 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- image_gen_testing:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- logging_testing:
|
||||
filters:
|
||||
branches:
|
||||
|
@ -1059,8 +1462,12 @@ workflows:
|
|||
- upload-coverage:
|
||||
requires:
|
||||
- llm_translation_testing
|
||||
- image_gen_testing
|
||||
- logging_testing
|
||||
- litellm_router_testing
|
||||
- caching_unit_tests
|
||||
- litellm_proxy_unit_testing
|
||||
- langfuse_logging_unit_tests
|
||||
- local_testing
|
||||
- litellm_assistants_api_testing
|
||||
- auth_ui_unit_tests
|
||||
|
@ -1082,18 +1489,29 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- test_bad_database_url:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- publish_to_pypi:
|
||||
requires:
|
||||
- local_testing
|
||||
- build_and_test
|
||||
- load_testing
|
||||
- test_bad_database_url
|
||||
- llm_translation_testing
|
||||
- image_gen_testing
|
||||
- logging_testing
|
||||
- litellm_router_testing
|
||||
- caching_unit_tests
|
||||
- langfuse_logging_unit_tests
|
||||
- litellm_assistants_api_testing
|
||||
- auth_ui_unit_tests
|
||||
- db_migration_disable_update_check
|
||||
- e2e_ui_testing
|
||||
- litellm_proxy_unit_testing
|
||||
- installing_litellm_on_python
|
||||
- proxy_logging_guardrails_model_info_tests
|
||||
- proxy_pass_through_endpoint_tests
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# used by CI/CD testing
|
||||
openai==1.52.0
|
||||
openai==1.54.0
|
||||
python-dotenv
|
||||
tiktoken
|
||||
importlib_metadata
|
||||
|
|
.github/workflows/lint.yml (vendored): 35 lines deleted
|
@ -1,35 +0,0 @@
|
|||
name: Lint
|
||||
|
||||
# If a pull-request is pushed then cancel all previously running jobs related
|
||||
# to that pull-request
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
|
||||
env:
|
||||
PY_COLORS: 1
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4.1.1
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install poetry
|
||||
run: pipx install poetry
|
||||
- uses: actions/setup-python@v5.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- run: poetry install
|
||||
- name: Run flake8 (https://flake8.pycqa.org/en/latest/)
|
||||
run: poetry run flake8
|
|
@ -4,7 +4,6 @@ python script to pre-create all views required by LiteLLM Proxy Server
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
from update_unassigned_teams import apply_db_fixes
|
||||
|
||||
# Enter your DATABASE_URL here
|
||||
|
||||
|
@ -205,7 +204,6 @@ async def check_view_exists(): # noqa: PLR0915
|
|||
|
||||
print("Last30dTopEndUsersSpend Created!") # noqa
|
||||
|
||||
await apply_db_fixes(db=db)
|
||||
return
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
from prisma import Prisma
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
async def apply_db_fixes(db: Prisma):
|
||||
"""
|
||||
Do Not Run this in production, only use it as a one-time fix
|
||||
"""
|
||||
verbose_logger.warning(
|
||||
"DO NOT run this in Production....Running update_unassigned_teams"
|
||||
)
|
||||
try:
|
||||
sql_query = """
|
||||
UPDATE "LiteLLM_SpendLogs"
|
||||
|
|
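Below is a minimal sketch (not part of the diff) of how this one-off fix could be invoked standalone; the Prisma connection handling and the assumption that `DATABASE_URL` is set in the environment are illustrative, not the PR author's code:

```python
import asyncio

from prisma import Prisma

from update_unassigned_teams import apply_db_fixes


async def main():
    # Assumes DATABASE_URL is configured for the Prisma client.
    db = Prisma()
    await db.connect()
    try:
        # One-time fix only; the function itself warns against running in production.
        await apply_db_fixes(db=db)
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
```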
docs/my-website/.gitignore (vendored): 1 addition
|
@ -18,3 +18,4 @@
|
|||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
yarn.lock
|
||||
|
|
docs/my-website/docs/benchmarks.md (new file): 41 additions
|
@ -0,0 +1,41 @@
|
|||
# Benchmarks
|
||||
|
||||
Benchmarks for LiteLLM Gateway (Proxy Server)
|
||||
|
||||
Locust Settings:
|
||||
- 2500 Users
|
||||
- 100 user Ramp Up
|
||||
|
||||
|
||||
## Basic Benchmarks
|
||||
|
||||
Overhead when using a Deployed Proxy vs Direct to LLM
|
||||
- Latency overhead added by LiteLLM Proxy: 107ms (140ms median latency via the proxy vs 33ms direct, per the table below)
|
||||
|
||||
| Metric | Direct to Fake Endpoint | Basic Litellm Proxy |
|
||||
|--------|------------------------|---------------------|
|
||||
| RPS | 1196 | 1133.2 |
|
||||
| Median Latency (ms) | 33 | 140 |
|
||||
|
||||
|
||||
## Logging Callbacks
|
||||
|
||||
### [GCS Bucket Logging](https://docs.litellm.ai/docs/proxy/bucket)
|
||||
|
||||
Using GCS Bucket logging has **no impact on latency or RPS compared to the Basic LiteLLM Proxy**
|
||||
|
||||
| Metric | Basic Litellm Proxy | LiteLLM Proxy with GCS Bucket Logging |
|
||||
|--------|------------------------|---------------------|
|
||||
| RPS | 1133.2 | 1137.3 |
|
||||
| Median Latency (ms) | 140 | 138 |
|
||||
|
||||
|
||||
### [LangSmith logging](https://docs.litellm.ai/docs/proxy/logging)
|
||||
|
||||
Using LangSmith has **no impact on latency or RPS compared to the Basic LiteLLM Proxy**
|
||||
|
||||
| Metric | Basic Litellm Proxy | LiteLLM Proxy with LangSmith |
|
||||
|--------|------------------------|---------------------|
|
||||
| RPS | 1133.2 | 1135 |
|
||||
| Median Latency (ms) | 140 | 132 |
|
||||
|
docs/my-website/docs/completion/predict_outputs.md (new file): 109 additions
|
@ -0,0 +1,109 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Predicted Outputs
|
||||
|
||||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | Use this when most of the output of the LLM is known ahead of time. For instance, if you are asking the model to rewrite some text or code with only minor changes, you can reduce your latency significantly by using Predicted Outputs, passing in the existing content as your prediction. |
|
||||
| Supported providers | `openai` |
|
||||
| Link to OpenAI doc on Predicted Outputs | [Predicted Outputs ↗](https://platform.openai.com/docs/guides/latency-optimization#use-predicted-outputs) |
|
||||
| Supported from LiteLLM Version | `v1.51.4` |
|
||||
|
||||
|
||||
|
||||
## Using Predicted Outputs
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="LiteLLM Python SDK" value="Python">
|
||||
|
||||
In this example, we refactor a piece of C# code to replace the Username property with an Email property:
|
||||
```python
|
||||
import os

import litellm
|
||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||
code = """
|
||||
/// <summary>
|
||||
/// Represents a user with a first name, last name, and username.
|
||||
/// </summary>
|
||||
public class User
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the user's first name.
|
||||
/// </summary>
|
||||
public string FirstName { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the user's last name.
|
||||
/// </summary>
|
||||
public string LastName { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the user's username.
|
||||
/// </summary>
|
||||
public string Username { get; set; }
|
||||
}
|
||||
"""
|
||||
|
||||
completion = litellm.completion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
)
|
||||
|
||||
print(completion)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="LiteLLM Proxy Server" value="proxy">
|
||||
|
||||
1. Define models on config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o-mini # OpenAI gpt-4o-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-mini
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
```
|
||||
|
||||
2. Run proxy server
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
3. Test it using the OpenAI Python SDK
|
||||
|
||||
|
||||
```python
|
||||
from openai import OpenAI

# `code` below refers to the C# snippet defined in the SDK example above.
|
||||
|
||||
client = OpenAI(
|
||||
api_key="LITELLM_PROXY_KEY", # sk-1234
|
||||
base_url="LITELLM_PROXY_BASE" # http://0.0.0.0:4000
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
)
|
||||
|
||||
print(completion)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
|
@ -35,7 +35,7 @@ OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
|
|||
|
||||
```shell
|
||||
OTEL_EXPORTER="otlp_http"
|
||||
OTEL_ENDPOINT="http:/0.0.0.0:4317"
|
||||
OTEL_ENDPOINT="http://0.0.0.0:4318"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
@ -44,14 +44,24 @@ OTEL_ENDPOINT="http:/0.0.0.0:4317"
|
|||
|
||||
```shell
|
||||
OTEL_EXPORTER="otlp_grpc"
|
||||
OTEL_ENDPOINT="http:/0.0.0.0:4317"
|
||||
OTEL_ENDPOINT="http://0.0.0.0:4317"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="laminar" label="Log to Laminar">
|
||||
|
||||
```shell
|
||||
OTEL_EXPORTER="otlp_grpc"
|
||||
OTEL_ENDPOINT="https://api.lmnr.ai:8443"
|
||||
OTEL_HEADERS="authorization=Bearer <project-api-key>"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
Use just 2 lines of code, to instantly log your LLM responses **across all providers** with OpenTelemetry:
|
||||
Use just 1 line of code to instantly log your LLM responses **across all providers** with OpenTelemetry:
|
||||
|
||||
```python
|
||||
litellm.callbacks = ["otel"]
|
||||
|
|
|
@ -864,3 +864,96 @@ Human: How do I boil water?
|
|||
Assistant:
|
||||
```
|
||||
|
||||
|
||||
## Usage - PDF
|
||||
|
||||
Pass base64 encoded PDF files to Anthropic models using the `image_url` field.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
### **using base64**
|
||||
```python
|
||||
from litellm import completion, supports_pdf_input
|
||||
import base64
|
||||
import requests
|
||||
|
||||
# URL of the file
|
||||
url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
|
||||
|
||||
# Download the file
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
|
||||
## check if model supports pdf input - (2024/11/11) only claude-3-5-haiku-20241022 supports it
|
||||
supports_pdf_input("anthropic/claude-3-5-haiku-20241022") # True
|
||||
|
||||
response = completion(
|
||||
model="anthropic/claude-3-5-haiku-20241022",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": f"data:application/pdf;base64,{encoded_file}", # 👈 PDF
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
print(response.choices[0])
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" lable="PROXY">
|
||||
|
||||
1. Add model to config
|
||||
|
||||
```yaml
|
||||
- model_name: claude-3-5-haiku-20241022
|
||||
litellm_params:
|
||||
model: anthropic/claude-3-5-haiku-20241022
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||
-d '{
|
||||
"model": "claude-3-5-haiku-20241022",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a very professional document summarization specialist. Please summarize the given document"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:application/pdf;base64,{encoded_file}" # 👈 PDF
|
||||
}
|
||||
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300
|
||||
}'
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
|
|
@ -1082,5 +1082,6 @@ print(f"response: {response}")
|
|||
|
||||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Stable Diffusion 3 - v0 | `embedding(model="bedrock/stability.stability.sd3-large-v1:0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
|
||||
| Stable Diffusion - v0 | `embedding(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
|
|
@ -9,7 +9,7 @@ LiteLLM Supports Logging to the following Cloud Buckets
|
|||
- (Enterprise) ✨ [Google Cloud Storage Buckets](#logging-proxy-inputoutput-to-google-cloud-storage-buckets)
|
||||
- (Free OSS) [Amazon s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
|
||||
|
||||
## Logging Proxy Input/Output to Google Cloud Storage Buckets
|
||||
## Google Cloud Storage Buckets
|
||||
|
||||
Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?hl=en)
|
||||
|
||||
|
@ -20,6 +20,14 @@ Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?
|
|||
:::
|
||||
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Description | Log LLM Input/Output to cloud storage buckets |
|
||||
| Load Test Benchmarks | [Benchmarks](https://docs.litellm.ai/docs/benchmarks) |
|
||||
| Google Cloud Storage docs | [Google Cloud Storage](https://cloud.google.com/storage?hl=en) |
|
||||
|
||||
|
||||
|
||||
### Usage
|
||||
|
||||
1. Add `gcs_bucket` to LiteLLM Config.yaml
|
||||
|
@ -85,7 +93,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
6. Save the JSON file and add the path to `GCS_PATH_SERVICE_ACCOUNT`
|
||||
|
||||
|
||||
## Logging Proxy Input/Output - s3 Buckets
|
||||
## s3 Buckets
|
||||
|
||||
We will use the `--config` to set
|
||||
|
||||
|
|
|
@ -692,9 +692,13 @@ general_settings:
|
|||
allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
|
||||
key_management_system: google_kms # either google_kms or azure_kms
|
||||
master_key: string
|
||||
|
||||
# Database Settings
|
||||
database_url: string
|
||||
database_connection_pool_limit: 0 # default 100
|
||||
database_connection_timeout: 0 # default 60s
|
||||
allow_requests_on_db_unavailable: boolean # if true, requests are still allowed when the DB cannot be reached to verify the Virtual Key
|
||||
|
||||
custom_auth: string
|
||||
max_parallel_requests: 0 # the max parallel requests allowed per deployment
|
||||
global_max_parallel_requests: 0 # the max parallel requests allowed on the proxy all up
|
||||
|
@ -766,6 +770,7 @@ general_settings:
|
|||
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
|
||||
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
|
||||
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
|
||||
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if the DB is unreachable. **Only use this if running LiteLLM in your VPC.** This allows requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key |
|
||||
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
|
||||
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
|
||||
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
|
||||
|
@ -929,6 +934,8 @@ router_settings:
|
|||
| EMAIL_SUPPORT_CONTACT | Support contact email address
|
||||
| GCS_BUCKET_NAME | Name of the Google Cloud Storage bucket
|
||||
| GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file
|
||||
| GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds**
|
||||
| GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048**
|
||||
| GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers
|
||||
| GENERIC_CLIENT_ID | Client ID for generic OAuth providers
|
||||
| GENERIC_CLIENT_SECRET | Client secret for generic OAuth providers
|
||||
|
|
|
@ -107,7 +107,7 @@ class StandardLoggingModelInformation(TypedDict):
|
|||
model_map_value: Optional[ModelInfo]
|
||||
```
|
||||
|
||||
## Logging Proxy Input/Output - Langfuse
|
||||
## Langfuse
|
||||
|
||||
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`; this will log all successful LLM calls to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment
|
||||
|
||||
|
@ -463,7 +463,7 @@ You will see `raw_request` in your Langfuse Metadata. This is the RAW CURL comma
|
|||
|
||||
<Image img={require('../../img/debug_langfuse.png')} />
|
||||
|
||||
## Logging Proxy Input/Output in OpenTelemetry format
|
||||
## OpenTelemetry format
|
||||
|
||||
:::info
|
||||
|
||||
|
@ -1216,7 +1216,7 @@ litellm_settings:
|
|||
|
||||
Start the LiteLLM Proxy and make a test request to verify the logs reached your callback API
|
||||
|
||||
## Logging LLM IO to Langsmith
|
||||
## Langsmith
|
||||
|
||||
1. Set `success_callback: ["langsmith"]` on litellm config.yaml
|
||||
|
||||
|
@ -1261,7 +1261,7 @@ Expect to see your log on Langfuse
|
|||
<Image img={require('../../img/langsmith_new.png')} />
|
||||
|
||||
|
||||
## Logging LLM IO to Arize AI
|
||||
## Arize AI
|
||||
|
||||
1. Set `success_callback: ["arize"]` on litellm config.yaml
|
||||
|
||||
|
@ -1309,7 +1309,7 @@ Expect to see your log on Langfuse
|
|||
<Image img={require('../../img/langsmith_new.png')} />
|
||||
|
||||
|
||||
## Logging LLM IO to Langtrace
|
||||
## Langtrace
|
||||
|
||||
1. Set `success_callback: ["langtrace"]` on litellm config.yaml
|
||||
|
||||
|
@ -1351,7 +1351,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
'
|
||||
```
|
||||
|
||||
## Logging LLM IO to Galileo
|
||||
## Galileo
|
||||
|
||||
[BETA]
|
||||
|
||||
|
@ -1466,7 +1466,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
|
||||
<Image img={require('../../img/openmeter_img_2.png')} />
|
||||
|
||||
## Logging Proxy Input/Output - DataDog
|
||||
## DataDog
|
||||
|
||||
LiteLLM supports logging to the following Datadog integrations:
|
||||
- `datadog` [Datadog Logs](https://docs.datadoghq.com/logs/)
|
||||
|
@ -1543,7 +1543,7 @@ Expected output on Datadog
|
|||
|
||||
<Image img={require('../../img/dd_small1.png')} />
|
||||
|
||||
## Logging Proxy Input/Output - DynamoDB
|
||||
## DynamoDB
|
||||
|
||||
We will use the `--config` to set
|
||||
|
||||
|
@ -1669,7 +1669,7 @@ Your logs should be available on DynamoDB
|
|||
}
|
||||
```
|
||||
|
||||
## Logging Proxy Input/Output - Sentry
|
||||
## Sentry
|
||||
|
||||
If api calls fail (llm/database) you can log those to Sentry:
|
||||
|
||||
|
@ -1711,7 +1711,7 @@ Test Request
|
|||
litellm --test
|
||||
```
|
||||
|
||||
## Logging Proxy Input/Output Athina
|
||||
## Athina
|
||||
|
||||
[Athina](https://athina.ai/) allows you to log LLM Input/Output for monitoring, analytics, and observability.
|
||||
|
||||
|
|
|
@ -20,6 +20,10 @@ general_settings:
|
|||
proxy_batch_write_at: 60 # Batch write spend updates every 60s
|
||||
database_connection_pool_limit: 10 # limit the number of database connections to = MAX Number of DB Connections/Number of instances of litellm proxy (Around 10-20 is good number)
|
||||
|
||||
# OPTIONAL Best Practices
|
||||
disable_spend_logs: True # turn off writing each transaction to the db. We recommend doing this if you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
|
||||
allow_requests_on_db_unavailable: True # Only USE when running LiteLLM in your VPC. Allow requests to still be processed even if the DB is unavailable. We recommend doing this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.
|
||||
|
||||
litellm_settings:
|
||||
request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000 seconds if not set
|
||||
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
|
||||
|
@ -86,7 +90,29 @@ Set `export LITELLM_MODE="PRODUCTION"`
|
|||
|
||||
This disables the load_dotenv() functionality, which will automatically load your environment credentials from the local `.env`.
|
||||
|
||||
## 5. Set LiteLLM Salt Key
|
||||
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
|
||||
|
||||
This allows LiteLLM to keep processing requests even if the DB is unavailable, instead of failing them outright.
|
||||
|
||||
**WARNING: Only do this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.**
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
allow_requests_on_db_unavailable: True
|
||||
```
|
||||
|
||||
## 6. Disable spend_logs if you're not using the LiteLLM UI
|
||||
|
||||
By default LiteLLM will write every request to the `LiteLLM_SpendLogs` table. This is used for viewing Usage on the LiteLLM UI.
|
||||
|
||||
If you're not viewing Usage on the LiteLLM UI (most users use Prometheus when this is disabled), you can disable spend_logs by setting `disable_spend_logs` to `True`.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: True
|
||||
```
|
||||
|
||||
## 7. Set LiteLLM Salt Key
|
||||
|
||||
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ Possible values for `budget_duration`
|
|||
| `budget_duration="1m"` | every 1 min |
|
||||
| `budget_duration="1h"` | every 1 hour |
|
||||
| `budget_duration="1d"` | every 1 day |
|
||||
| `budget_duration="1mo"` | every 1 month |
|
||||
| `budget_duration="30d"` | every 1 month |
|
||||
|
||||
|
||||
### 2. Create a key for the `team`
|
||||
|
|
|
@ -30,11 +30,11 @@ This config would send langfuse logs to 2 different langfuse projects, based on
|
|||
```yaml
|
||||
litellm_settings:
|
||||
default_team_settings:
|
||||
- team_id: my-secret-project
|
||||
- team_id: "dbe2f686-a686-4896-864a-4c3924458709"
|
||||
success_callback: ["langfuse"]
|
||||
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
|
||||
langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
|
||||
- team_id: ishaans-secret-project
|
||||
- team_id: "06ed1e01-3fa7-4b9e-95bc-f2e59b74f3a8"
|
||||
success_callback: ["langfuse"]
|
||||
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_2 # Project 2
|
||||
langfuse_secret: os.environ/LANGFUSE_SECRET_2 # Project 2
|
||||
|
@ -46,7 +46,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
|
|||
curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"team_id": "ishaans-secret-project"}'
|
||||
-d '{"team_id": "06ed1e01-3fa7-4b9e-95bc-f2e59b74f3a8"}'
|
||||
```
|
||||
|
||||
All requests made with these keys will log data to their team-specific logging. -->
|
||||
|
@ -281,6 +281,51 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
|||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem label="Langsmith" value="langsmith">
|
||||
|
||||
1. Create Virtual Key to log to a specific Langsmith Project
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"metadata": {
|
||||
"logging": [{
|
||||
"callback_name": "langsmith", # "otel", "gcs_bucket"
|
||||
"callback_type": "success", # "success", "failure", "success_and_failure"
|
||||
"callback_vars": {
|
||||
"langsmith_api_key": "os.environ/LANGSMITH_API_KEY", # API Key for Langsmith logging
|
||||
"langsmith_project": "pr-brief-resemblance-72", # project name on langsmith
|
||||
"langsmith_base_url": "https://api.smith.langchain.com"
|
||||
}
|
||||
}]
|
||||
}
|
||||
}'
|
||||
|
||||
```
|
||||
|
||||
2. Test it - `/chat/completions` request
|
||||
|
||||
Use the virtual key from step 3 to make a `/chat/completions` request
|
||||
|
||||
You should see your logs on your Langsmith project on a successful request
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-Fxq5XSyWKeXDKfPdqXZhPg" \
|
||||
-d '{
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude"}
|
||||
],
|
||||
"user": "hello",
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
|
|
@ -205,6 +205,7 @@ const sidebars = {
|
|||
"completion/prompt_caching",
|
||||
"completion/audio",
|
||||
"completion/vision",
|
||||
"completion/predict_outputs",
|
||||
"completion/prefix",
|
||||
"completion/drop_params",
|
||||
"completion/prompt_formatting",
|
||||
|
@ -265,6 +266,7 @@ const sidebars = {
|
|||
type: "category",
|
||||
label: "Load Testing",
|
||||
items: [
|
||||
"benchmarks",
|
||||
"load_test",
|
||||
"load_test_advanced",
|
||||
"load_test_sdk",
|
||||
|
|
|
@ -137,6 +137,8 @@ safe_memory_mode: bool = False
|
|||
enable_azure_ad_token_refresh: Optional[bool] = False
|
||||
### DEFAULT AZURE API VERSION ###
|
||||
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
|
||||
### DEFAULT WATSONX API VERSION ###
|
||||
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
|
||||
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
||||
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
|
||||
### GUARDRAILS ###
|
||||
|
@ -282,7 +284,9 @@ priority_reservation: Optional[Dict[str, float]] = None
|
|||
#### RELIABILITY ####
|
||||
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
||||
request_timeout: float = 6000 # time in seconds
|
||||
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
|
||||
module_level_aclient = AsyncHTTPHandler(
|
||||
timeout=request_timeout, client_alias="module level aclient"
|
||||
)
|
||||
module_level_client = HTTPHandler(timeout=request_timeout)
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
max_fallbacks: Optional[int] = None
|
||||
|
@ -371,6 +375,7 @@ open_ai_text_completion_models: List = []
|
|||
cohere_models: List = []
|
||||
cohere_chat_models: List = []
|
||||
mistral_chat_models: List = []
|
||||
text_completion_codestral_models: List = []
|
||||
anthropic_models: List = []
|
||||
empower_models: List = []
|
||||
openrouter_models: List = []
|
||||
|
@ -397,6 +402,19 @@ deepinfra_models: List = []
|
|||
perplexity_models: List = []
|
||||
watsonx_models: List = []
|
||||
gemini_models: List = []
|
||||
xai_models: List = []
|
||||
deepseek_models: List = []
|
||||
azure_ai_models: List = []
|
||||
voyage_models: List = []
|
||||
databricks_models: List = []
|
||||
cloudflare_models: List = []
|
||||
codestral_models: List = []
|
||||
friendliai_models: List = []
|
||||
palm_models: List = []
|
||||
groq_models: List = []
|
||||
azure_models: List = []
|
||||
anyscale_models: List = []
|
||||
cerebras_models: List = []
|
||||
|
||||
|
||||
def add_known_models():
|
||||
|
@ -473,6 +491,34 @@ def add_known_models():
|
|||
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
|
||||
if "-to-" not in key:
|
||||
fireworks_ai_embedding_models.append(key)
|
||||
elif value.get("litellm_provider") == "text-completion-codestral":
|
||||
text_completion_codestral_models.append(key)
|
||||
elif value.get("litellm_provider") == "xai":
|
||||
xai_models.append(key)
|
||||
elif value.get("litellm_provider") == "deepseek":
|
||||
deepseek_models.append(key)
|
||||
elif value.get("litellm_provider") == "azure_ai":
|
||||
azure_ai_models.append(key)
|
||||
elif value.get("litellm_provider") == "voyage":
|
||||
voyage_models.append(key)
|
||||
elif value.get("litellm_provider") == "databricks":
|
||||
databricks_models.append(key)
|
||||
elif value.get("litellm_provider") == "cloudflare":
|
||||
cloudflare_models.append(key)
|
||||
elif value.get("litellm_provider") == "codestral":
|
||||
codestral_models.append(key)
|
||||
elif value.get("litellm_provider") == "friendliai":
|
||||
friendliai_models.append(key)
|
||||
elif value.get("litellm_provider") == "palm":
|
||||
palm_models.append(key)
|
||||
elif value.get("litellm_provider") == "groq":
|
||||
groq_models.append(key)
|
||||
elif value.get("litellm_provider") == "azure":
|
||||
azure_models.append(key)
|
||||
elif value.get("litellm_provider") == "anyscale":
|
||||
anyscale_models.append(key)
|
||||
elif value.get("litellm_provider") == "cerebras":
|
||||
cerebras_models.append(key)
|
||||
|
||||
|
||||
add_known_models()
|
||||
|
@ -527,7 +573,11 @@ openai_text_completion_compatible_providers: List = (
|
|||
"hosted_vllm",
|
||||
]
|
||||
)
|
||||
|
||||
_openai_like_providers: List = [
|
||||
"predibase",
|
||||
"databricks",
|
||||
"watsonx",
|
||||
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
|
||||
# well supported replicate llms
|
||||
replicate_models: List = [
|
||||
# llama replicate supported LLMs
|
||||
|
@ -714,6 +764,20 @@ model_list = (
|
|||
+ vertex_language_models
|
||||
+ watsonx_models
|
||||
+ gemini_models
|
||||
+ text_completion_codestral_models
|
||||
+ xai_models
|
||||
+ deepseek_models
|
||||
+ azure_ai_models
|
||||
+ voyage_models
|
||||
+ databricks_models
|
||||
+ cloudflare_models
|
||||
+ codestral_models
|
||||
+ friendliai_models
|
||||
+ palm_models
|
||||
+ groq_models
|
||||
+ azure_models
|
||||
+ anyscale_models
|
||||
+ cerebras_models
|
||||
)
|
||||
|
||||
|
||||
|
@ -770,6 +834,7 @@ class LlmProviders(str, Enum):
|
|||
FIREWORKS_AI = "fireworks_ai"
|
||||
FRIENDLIAI = "friendliai"
|
||||
WATSONX = "watsonx"
|
||||
WATSONX_TEXT = "watsonx_text"
|
||||
TRITON = "triton"
|
||||
PREDIBASE = "predibase"
|
||||
DATABRICKS = "databricks"
|
||||
|
@ -786,6 +851,7 @@ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
|
|||
|
||||
models_by_provider: dict = {
|
||||
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,
|
||||
"text-completion-openai": open_ai_text_completion_models,
|
||||
"cohere": cohere_models + cohere_chat_models,
|
||||
"cohere_chat": cohere_chat_models,
|
||||
"anthropic": anthropic_models,
|
||||
|
@ -809,6 +875,23 @@ models_by_provider: dict = {
|
|||
"watsonx": watsonx_models,
|
||||
"gemini": gemini_models,
|
||||
"fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
|
||||
"aleph_alpha": aleph_alpha_models,
|
||||
"text-completion-codestral": text_completion_codestral_models,
|
||||
"xai": xai_models,
|
||||
"deepseek": deepseek_models,
|
||||
"mistral": mistral_chat_models,
|
||||
"azure_ai": azure_ai_models,
|
||||
"voyage": voyage_models,
|
||||
"databricks": databricks_models,
|
||||
"cloudflare": cloudflare_models,
|
||||
"codestral": codestral_models,
|
||||
"nlp_cloud": nlp_cloud_models,
|
||||
"friendliai": friendliai_models,
|
||||
"palm": palm_models,
|
||||
"groq": groq_models,
|
||||
"azure": azure_models,
|
||||
"anyscale": anyscale_models,
|
||||
"cerebras": cerebras_models,
|
||||
}
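As a quick illustration (not part of the diff), the expanded mapping can be read like any other module-level dict, assuming a litellm build that includes these new provider buckets:

```python
import litellm

# Look up the model names registered under one of the newly added provider
# keys; an unknown provider simply yields an empty list here.
xai_models = litellm.models_by_provider.get("xai", [])
print(f"xai models known to litellm: {xai_models}")
```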
|
||||
|
||||
# mapping for those models which have larger equivalents
|
||||
|
@ -881,7 +964,6 @@ from .utils import (
|
|||
supports_system_messages,
|
||||
get_litellm_params,
|
||||
acreate,
|
||||
get_model_list,
|
||||
get_max_tokens,
|
||||
get_model_info,
|
||||
register_prompt_template,
|
||||
|
@ -976,10 +1058,11 @@ from .llms.bedrock.common_utils import (
|
|||
AmazonAnthropicClaude3Config,
|
||||
AmazonCohereConfig,
|
||||
AmazonLlamaConfig,
|
||||
AmazonStabilityConfig,
|
||||
AmazonMistralConfig,
|
||||
AmazonBedrockGlobalConfig,
|
||||
)
|
||||
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
|
||||
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
|
||||
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
||||
AmazonTitanMultimodalEmbeddingG1Config,
|
||||
|
@ -1037,10 +1120,12 @@ from .llms.AzureOpenAI.azure import (
|
|||
|
||||
from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
|
||||
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
|
||||
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
|
||||
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
|
||||
from .llms.perplexity.chat.transformation import PerplexityChatConfig
|
||||
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
|
||||
from .llms.watsonx import IBMWatsonXAIConfig
|
||||
from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
|
||||
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
||||
from .main import * # type: ignore
|
||||
from .integrations import *
|
||||
from .exceptions import (
|
||||
|
|
|
@ -241,7 +241,8 @@ class ServiceLogging(CustomLogger):
|
|||
if callback == "prometheus_system":
|
||||
await self.init_prometheus_services_logger_if_none()
|
||||
await self.prometheusServicesLogger.async_service_failure_hook(
|
||||
payload=payload
|
||||
payload=payload,
|
||||
error=error,
|
||||
)
|
||||
elif callback == "datadog":
|
||||
await self.init_datadog_logger_if_none()
|
||||
|
|
|
@ -8,6 +8,7 @@ Has 4 methods:
|
|||
- async_get_cache
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -18,7 +19,7 @@ else:
|
|||
Span = Any
|
||||
|
||||
|
||||
class BaseCache:
|
||||
class BaseCache(ABC):
|
||||
def __init__(self, default_ttl: int = 60):
|
||||
self.default_ttl = default_ttl
|
||||
|
||||
|
@ -37,6 +38,10 @@ class BaseCache:
|
|||
async def async_set_cache(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
pass
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
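Making `async_set_cache_pipeline` abstract means every cache backend must now provide a bulk-write path. A minimal sketch of the contract is below; the in-memory class and the import path are illustrative assumptions, not an existing litellm backend:

```python
from typing import Any, List, Optional, Tuple

# Import path is an assumption; adjust to wherever BaseCache lives in your litellm version.
from litellm.caching.base_cache import BaseCache


class InMemoryExampleCache(BaseCache):
    """Hypothetical subclass showing the new abstract bulk-write contract."""

    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self._store: dict = {}

    async def async_set_cache(self, key, value, **kwargs):
        self._store[key] = value

    async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs):
        # Bulk write: each entry is a (cache_key, cached_data) pair, matching
        # what Cache.async_add_cache_pipeline accumulates before flushing.
        for key, value in cache_list:
            self._store[key] = value

    async def async_get_cache(self, key, **kwargs) -> Optional[Any]:
        return self._store.get(key)
```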
|
@ -233,19 +233,18 @@ class Cache:
|
|||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
def get_cache_key(self, *args, **kwargs) -> str:
|
||||
def get_cache_key(self, **kwargs) -> str:
|
||||
"""
|
||||
Get the cache key for the given arguments.
|
||||
|
||||
Args:
|
||||
*args: args to litellm.completion() or embedding()
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||
"""
|
||||
cache_key = ""
|
||||
verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
|
||||
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
|
||||
|
||||
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
if preset_cache_key is not None:
|
||||
|
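With positional `*args` removed, cache-key generation and lookups become keyword-only. A minimal sketch, under the assumption that `litellm.Cache` (in-memory by default) is available in your installed version:

```python
import litellm

cache = litellm.Cache()  # defaults to the local in-memory cache

request_kwargs = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "hello"}],
}

# Everything is passed as keyword arguments now; positional args are gone.
key = cache.get_cache_key(**request_kwargs)
cached = cache.get_cache(cache_key=key, **request_kwargs)
print(key, cached)  # cached is None until something is stored under this key
```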
@@ -521,7 +520,7 @@ class Cache:
                return cached_response
        return cached_result

    def get_cache(self, *args, **kwargs):
    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

@@ -533,13 +532,13 @@ class Cache:
            The cached result if it exists, otherwise None.
        """
        try: # never block execution
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(

@@ -553,29 +552,28 @@ class Cache:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    async def async_get_cache(self, *args, **kwargs):
    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """

        try: # never block execution
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(
                    cache_key, *args, **kwargs
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
@@ -583,7 +581,7 @@ class Cache:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, *args, **kwargs):
    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """

@@ -591,7 +589,7 @@ class Cache:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

@@ -613,7 +611,7 @@ class Cache:
        except Exception as e:
            raise e

    def add_cache(self, result, *args, **kwargs):
    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

@@ -625,41 +623,42 @@ class Cache:
            None
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, *args, **kwargs
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache(self, result, *args, **kwargs):
    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, *args, **kwargs)
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, *args, **kwargs
                    result=result, **kwargs
                )

                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, *args, **kwargs):
    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set
@@ -668,29 +667,27 @@ class Cache:

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    *args,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))
            async_set_cache_pipeline = getattr(
                self.cache, "async_set_cache_pipeline", None
            )
            if async_set_cache_pipeline:
                await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            else:
                tasks = []
                for val in cache_list:
                    tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
                await asyncio.gather(*tasks)

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # if async_set_cache_pipeline:
            # await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # else:
            # tasks = []
            # for val in cache_list:
            # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
            # await asyncio.gather(*tasks)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    def should_use_cache(self, *args, **kwargs):
    def should_use_cache(self, **kwargs):
        """
        Returns true if we should use the cache for LLM API calls

@@ -708,10 +705,8 @@ class Cache:
            return True
        return False

    async def batch_cache_write(self, result, *args, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(
            result=result, *args, **kwargs
        )
    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):

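Note: the caching-handler hunks below route positional arguments through a convert_args_to_kwargs helper before calling the now kwargs-only Cache API. A minimal sketch of that idea, assuming the helper only needs the wrapped function's signature (illustrative, not the exact shipped helper, whose signature changes in this diff):

import inspect
from typing import Any, Callable, Dict, Optional, Tuple

def convert_args_to_kwargs_sketch(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Map positional arguments onto the wrapped function's parameter names,
    # so a kwargs-only cache key function still sees them as named values.
    args = args or ()
    param_names = list(inspect.signature(original_function).parameters.keys())
    return dict(zip(param_names, args))

# Usage with a stand-in function:
def fake_completion(model, messages, temperature=0.0):
    pass

convert_args_to_kwargs_sketch(fake_completion, ("gpt-4o", [{"role": "user", "content": "hi"}]))
# -> {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}
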
@ -137,7 +137,7 @@ class LLMCachingHandler:
|
|||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
print_verbose("Checking Cache")
|
||||
verbose_logger.debug("Checking Cache")
|
||||
cached_result = await self._retrieve_from_cache(
|
||||
call_type=call_type,
|
||||
kwargs=kwargs,
|
||||
|
@ -145,7 +145,7 @@ class LLMCachingHandler:
|
|||
)
|
||||
|
||||
if cached_result is not None and not isinstance(cached_result, list):
|
||||
print_verbose("Cache Hit!")
|
||||
verbose_logger.debug("Cache Hit!")
|
||||
cache_hit = True
|
||||
end_time = datetime.datetime.now()
|
||||
model, _, _, _ = litellm.get_llm_provider(
|
||||
|
@ -215,6 +215,7 @@ class LLMCachingHandler:
|
|||
final_embedding_cached_response=final_embedding_cached_response,
|
||||
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
|
||||
)
|
||||
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
|
||||
return CachingHandlerResponse(
|
||||
cached_result=cached_result,
|
||||
final_embedding_cached_response=final_embedding_cached_response,
|
||||
|
@ -233,12 +234,19 @@ class LLMCachingHandler:
|
|||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
args = args or ()
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
cached_result: Optional[Any] = None
|
||||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
print_verbose("Checking Cache")
|
||||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
cached_result = litellm.cache.get_cache(**new_kwargs)
|
||||
if cached_result is not None:
|
||||
if "detail" in cached_result:
|
||||
# implies an error occurred
|
||||
|
@ -475,14 +483,21 @@ class LLMCachingHandler:
|
|||
if litellm.cache is None:
|
||||
return None
|
||||
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
cached_result: Optional[Any] = None
|
||||
if call_type == CallTypes.aembedding.value and isinstance(
|
||||
kwargs["input"], list
|
||||
new_kwargs["input"], list
|
||||
):
|
||||
tasks = []
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
for idx, i in enumerate(new_kwargs["input"]):
|
||||
preset_cache_key = litellm.cache.get_cache_key(
|
||||
*args, **{**kwargs, "input": i}
|
||||
**{**new_kwargs, "input": i}
|
||||
)
|
||||
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
|
||||
cached_result = await asyncio.gather(*tasks)
|
||||
|
@ -493,9 +508,9 @@ class LLMCachingHandler:
|
|||
cached_result = None
|
||||
else:
|
||||
if litellm.cache._supports_async() is True:
|
||||
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
|
||||
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
|
||||
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
|
||||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
cached_result = litellm.cache.get_cache(**new_kwargs)
|
||||
return cached_result
|
||||
|
||||
def _convert_cached_result_to_model_response(
|
||||
|
@ -580,6 +595,7 @@ class LLMCachingHandler:
|
|||
model_response_object=EmbeddingResponse(),
|
||||
response_type="embedding",
|
||||
)
|
||||
|
||||
elif (
|
||||
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
|
||||
) and isinstance(cached_result, dict):
|
||||
|
@ -603,6 +619,13 @@ class LLMCachingHandler:
|
|||
response_type="audio_transcription",
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(cached_result, "_hidden_params")
|
||||
and cached_result._hidden_params is not None
|
||||
and isinstance(cached_result._hidden_params, dict)
|
||||
):
|
||||
cached_result._hidden_params["cache_hit"] = True
|
||||
return cached_result
|
||||
|
||||
def _convert_cached_stream_response(
|
||||
|
@ -658,12 +681,19 @@ class LLMCachingHandler:
|
|||
Raises:
|
||||
None
|
||||
"""
|
||||
kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
|
||||
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
if litellm.cache is None:
|
||||
return
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if self._should_store_result_in_cache(
|
||||
original_function=original_function, kwargs=kwargs
|
||||
original_function=original_function, kwargs=new_kwargs
|
||||
):
|
||||
if (
|
||||
isinstance(result, litellm.ModelResponse)
|
||||
|
@ -673,29 +703,29 @@ class LLMCachingHandler:
|
|||
):
|
||||
if (
|
||||
isinstance(result, EmbeddingResponse)
|
||||
and isinstance(kwargs["input"], list)
|
||||
and isinstance(new_kwargs["input"], list)
|
||||
and litellm.cache is not None
|
||||
and not isinstance(
|
||||
litellm.cache.cache, S3Cache
|
||||
) # s3 doesn't support bulk writing. Exclude.
|
||||
):
|
||||
asyncio.create_task(
|
||||
litellm.cache.async_add_cache_pipeline(result, **kwargs)
|
||||
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
|
||||
)
|
||||
elif isinstance(litellm.cache.cache, S3Cache):
|
||||
threading.Thread(
|
||||
target=litellm.cache.add_cache,
|
||||
args=(result,),
|
||||
kwargs=kwargs,
|
||||
kwargs=new_kwargs,
|
||||
).start()
|
||||
else:
|
||||
asyncio.create_task(
|
||||
litellm.cache.async_add_cache(
|
||||
result.model_dump_json(), **kwargs
|
||||
result.model_dump_json(), **new_kwargs
|
||||
)
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))
|
||||
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
|
||||
|
||||
def sync_set_cache(
|
||||
self,
|
||||
|
@ -706,16 +736,20 @@ class LLMCachingHandler:
|
|||
"""
|
||||
Sync internal method to add the result to the cache
|
||||
"""
|
||||
kwargs.update(
|
||||
convert_args_to_kwargs(result, self.original_function, kwargs, args)
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
if litellm.cache is None:
|
||||
return
|
||||
|
||||
if self._should_store_result_in_cache(
|
||||
original_function=self.original_function, kwargs=kwargs
|
||||
original_function=self.original_function, kwargs=new_kwargs
|
||||
):
|
||||
litellm.cache.add_cache(result, **kwargs)
|
||||
litellm.cache.add_cache(result, **new_kwargs)
|
||||
|
||||
return
|
||||
|
||||
|
@ -865,9 +899,7 @@ class LLMCachingHandler:
|
|||
|
||||
|
||||
def convert_args_to_kwargs(
|
||||
result: Any,
|
||||
original_function: Callable,
|
||||
kwargs: Dict[str, Any],
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Get the signature of the original function
|
||||
|
|
|
@@ -24,7 +24,6 @@ class DiskCache(BaseCache):
        self.disk_cache = dc.Cache(disk_cache_dir)

    def set_cache(self, key, value, **kwargs):
        print_verbose("DiskCache: set_cache")
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
        else:

@@ -33,10 +32,10 @@ class DiskCache(BaseCache):
    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None):
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        for cache_key, cache_value in cache_list:
            if ttl is not None:
                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
            if "ttl" in kwargs:
                self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
            else:
                self.set_cache(key=cache_key, value=cache_value)

@@ -70,7 +70,7 @@ class DualCache(BaseCache):
        self.redis_batch_cache_expiry = (
            default_redis_batch_cache_expiry
            or litellm.default_redis_batch_cache_expiry
            or 5
            or 10
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl or litellm.default_in_memory_ttl

@@ -314,7 +314,8 @@ class DualCache(BaseCache):
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    async def async_batch_set_cache(
    # async_batch_set_cache
    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """

@@ -9,6 +9,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any

@@ -422,3 +423,9 @@ class QdrantSemanticCache(BaseCache):

    async def _collection_info(self):
        return self.collection_info

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)

@@ -404,7 +404,7 @@ class RedisCache(BaseCache):
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
                )
            )
            return results
            return None
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()

@@ -9,6 +9,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any

@@ -331,3 +332,9 @@ class RedisSemanticCache(BaseCache):

    async def _index_info(self):
        return await self.index.ainfo()

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)

@@ -10,6 +10,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any, Optional

@@ -153,3 +154,9 @@ class S3Cache(BaseCache):

    async def disconnect(self):
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)

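Note: Qdrant, Redis-semantic, and S3 each gain the same per-key fan-out now that BaseCache declares async_set_cache_pipeline as abstract. A sketch of how that shared default could live in one place instead of three, assuming a subclass only needs per-key writes (design observation, not part of this diff):

import asyncio

class FanOutPipelineMixin:
    async def async_set_cache(self, key, value, **kwargs):
        # Provided by the concrete cache backend.
        raise NotImplementedError

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # Write every (key, value) pair concurrently, one task per entry.
        await asyncio.gather(
            *(self.async_set_cache(key, value, **kwargs) for key, value in cache_list)
        )
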
@@ -28,6 +28,9 @@ from litellm.llms.azure_ai.cost_calculator import (
from litellm.llms.AzureOpenAI.cost_calculation import (
    cost_per_token as azure_openai_cost_per_token,
)
from litellm.llms.bedrock.image.cost_calculator import (
    cost_calculator as bedrock_image_cost_calculator,
)
from litellm.llms.cohere.cost_calculator import (
    cost_per_query as cohere_rerank_cost_per_query,
)

@@ -521,12 +524,13 @@ def completion_cost( # noqa: PLR0915
    custom_llm_provider=None,
    region_name=None, # used for bedrock pricing
    ### IMAGE GEN ###
    size=None,
    size: Optional[str] = None,
    quality=None,
    n=None, # number of images
    ### CUSTOM PRICING ###
    custom_cost_per_token: Optional[CostPerToken] = None,
    custom_cost_per_second: Optional[float] = None,
    optional_params: Optional[dict] = None,
) -> float:
    """
    Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.

@@ -667,7 +671,17 @@ def completion_cost( # noqa: PLR0915
                # https://cloud.google.com/vertex-ai/generative-ai/pricing
                # Vertex Charges Flat $0.20 per image
                return 0.020

            elif custom_llm_provider == "bedrock":
                if isinstance(completion_response, ImageResponse):
                    return bedrock_image_cost_calculator(
                        model=model,
                        size=size,
                        image_response=completion_response,
                        optional_params=optional_params,
                    )
                raise TypeError(
                    "completion_response must be of type ImageResponse for bedrock image cost calculation"
                )
            if size is None:
                size = "1024-x-1024" # openai default
            # fix size to match naming convention

@@ -677,9 +691,9 @@ def completion_cost( # noqa: PLR0915
            image_gen_model_name_with_quality = image_gen_model_name
            if quality is not None:
                image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
            size = size.split("-x-")
            height = int(size[0]) # if it's 1024-x-1024 vs. 1024x1024
            width = int(size[1])
            size_parts = size.split("-x-")
            height = int(size_parts[0]) # if it's 1024-x-1024 vs. 1024x1024
            width = int(size_parts[1])
            verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
            verbose_logger.debug(
                f"image_gen_model_name_with_quality: {image_gen_model_name_with_quality}"

@@ -844,6 +858,7 @@ def response_cost_calculator(
                model=model,
                call_type=call_type,
                custom_llm_provider=custom_llm_provider,
                optional_params=optional_params,
            )
        else:
            if custom_pricing is True: # override defaults if custom pricing is set

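Note: the hunk above stops reassigning the string `size` to a list and splits into `size_parts` instead. An illustrative parser (not litellm's code) that accepts both the "1024-x-1024" default used here and the plain "1024x1024" form:

from typing import Tuple

def parse_image_size(size: str) -> Tuple[int, int]:
    # Normalize "1024-x-1024" to "1024x1024", then split into height and width.
    normalized = size.replace("-x-", "x")
    height_str, width_str = normalized.split("x")
    return int(height_str), int(width_str)

assert parse_image_size("1024-x-1024") == (1024, 1024)
assert parse_image_size("512x512") == (512, 512)
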
@@ -423,7 +423,7 @@ class SlackAlerting(CustomBatchLogger):
        latency_cache_keys = [(key, 0) for key in latency_keys]
        failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
        combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
        await self.internal_usage_cache.async_batch_set_cache(
        await self.internal_usage_cache.async_set_cache_pipeline(
            cache_list=combined_metrics_cache_keys
        )

@@ -23,7 +23,7 @@ from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.utils import get_formatted_prompt
from litellm.utils import get_formatted_prompt, print_verbose

global_braintrust_http_handler = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.LoggingCallback

@@ -229,6 +229,9 @@ class BraintrustLogger(CustomLogger):
            request_data["metrics"] = metrics

        try:
            print_verbose(
                f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
            )
            global_braintrust_sync_http_handler.post(
                url=f"{self.api_base}/project_logs/{project_id}/insert",
                json={"events": [request_data]},

@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
        self,
        flush_lock: Optional[asyncio.Lock] = None,
        batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
        **kwargs,
    ) -> None:
        """

@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
            flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
        """
        self.log_queue: List = []
        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
        self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
        self.last_flush_time = time.time()
        self.flush_lock = flush_lock

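Note: with flush_interval now configurable, subclasses can pass their own batch settings instead of inheriting the hardcoded 10-second flush. A sketch that mirrors the pattern the GCS logger adopts later in this diff (subclass name, batch sizes, and the body of async_send_batch are placeholders, not litellm code):

import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger

class MyBatchLogger(CustomBatchLogger):
    def __init__(self, **kwargs):
        self.flush_lock = asyncio.Lock()
        # Start the periodic flush loop, as the GCS logger does in this diff.
        asyncio.create_task(self.periodic_flush())
        super().__init__(
            flush_lock=self.flush_lock,
            batch_size=512,
            flush_interval=30,
            **kwargs,
        )

    async def async_send_batch(self):
        if not self.log_queue:
            return
        # Ship self.log_queue to the destination here, then clear it.
        self.log_queue.clear()
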
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
@ -10,10 +11,12 @@ from pydantic import BaseModel, Field
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import (
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingMetadata,
|
||||
|
@ -26,10 +29,9 @@ else:
|
|||
VertexBase = Any
|
||||
|
||||
|
||||
class GCSLoggingConfig(TypedDict):
|
||||
bucket_name: str
|
||||
vertex_instance: VertexBase
|
||||
path_service_account: str
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
GCS_DEFAULT_BATCH_SIZE = 2048
|
||||
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
|
||||
|
||||
|
||||
class GCSBucketLogger(GCSBucketBase):
|
||||
|
@ -38,6 +40,21 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
|
||||
super().__init__(bucket_name=bucket_name)
|
||||
self.vertex_instances: Dict[str, VertexBase] = {}
|
||||
|
||||
# Init Batch logging settings
|
||||
self.log_queue: List[GCSLogQueueItem] = []
|
||||
self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
|
||||
self.flush_interval = int(
|
||||
os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
|
||||
)
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(
|
||||
flush_lock=self.flush_lock,
|
||||
batch_size=self.batch_size,
|
||||
flush_interval=self.flush_interval,
|
||||
)
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
||||
|
@ -57,54 +74,23 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
|
||||
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Modify the object_name to include the date-based folder
|
||||
object_name = f"{current_date}/{response_obj['id']}"
|
||||
try:
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(f"GCS Bucket logging error: {e.response.text}")
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
|
||||
|
@ -112,51 +98,138 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
response_obj,
|
||||
)
|
||||
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
|
||||
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
|
||||
_litellm_params = kwargs.get("litellm_params") or {}
|
||||
metadata = _litellm_params.get("metadata") or {}
|
||||
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Modify the object_name to include the date-based folder
|
||||
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
|
||||
|
||||
if "gcs_log_id" in metadata:
|
||||
object_name = metadata["gcs_log_id"]
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
# Add to logging queue - this will be flushed periodically
|
||||
self.log_queue.append(
|
||||
GCSLogQueueItem(
|
||||
payload=logging_payload, kwargs=kwargs, response_obj=response_obj
|
||||
)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
Process queued logs in batch - sends logs to GCS Bucket
|
||||
|
||||
|
||||
GCS Bucket does not have a Batch endpoint to batch upload logs
|
||||
|
||||
Instead, we
|
||||
- collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
|
||||
- during async_send_batch, we make 1 POST request per log to GCS Bucket
|
||||
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
try:
|
||||
for log_item in self.log_queue:
|
||||
logging_payload = log_item["payload"]
|
||||
kwargs = log_item["kwargs"]
|
||||
response_obj = log_item.get("response_obj", None) or {}
|
||||
|
||||
gcs_logging_config: GCSLoggingConfig = (
|
||||
await self.get_gcs_logging_config(kwargs)
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
object_name = self._get_object_name(
|
||||
kwargs, logging_payload, response_obj
|
||||
)
|
||||
await self._log_json_data_on_gcs(
|
||||
headers=headers,
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
logging_payload=logging_payload,
|
||||
)
|
||||
|
||||
# Clear the queue after processing
|
||||
self.log_queue.clear()
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"GCS Bucket batch logging error: {str(e)}")
|
||||
|
||||
def _get_object_name(
|
||||
self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
|
||||
) -> str:
|
||||
"""
|
||||
Get the object name to use for the current payload
|
||||
"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
if logging_payload.get("error_str", None) is not None:
|
||||
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
|
||||
else:
|
||||
object_name = f"{current_date}/{response_obj.get('id', '')}"
|
||||
|
||||
# used for testing
|
||||
_litellm_params = kwargs.get("litellm_params", None) or {}
|
||||
_metadata = _litellm_params.get("metadata", None) or {}
|
||||
if "gcs_log_id" in _metadata:
|
||||
object_name = _metadata["gcs_log_id"]
|
||||
|
||||
return object_name
|
||||
|
||||
def _handle_folders_in_bucket_name(
|
||||
self,
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Handles when the user passes a bucket name with a folder postfix
|
||||
|
||||
|
||||
Example:
|
||||
- Bucket name: "my-bucket/my-folder/dev"
|
||||
- Object name: "my-object"
|
||||
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
||||
|
||||
"""
|
||||
if "/" in bucket_name:
|
||||
bucket_name, prefix = bucket_name.split("/", 1)
|
||||
object_name = f"{prefix}/{object_name}"
|
||||
return bucket_name, object_name
|
||||
return bucket_name, object_name
|
||||
|
||||
async def _log_json_data_on_gcs(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
logging_payload: StandardLoggingPayload,
|
||||
):
|
||||
"""
|
||||
Helper function to make POST request to GCS Bucket in the specified bucket.
|
||||
"""
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
|
||||
async def get_gcs_logging_config(
|
||||
self, kwargs: Optional[Dict[str, Any]] = {}
|
||||
) -> GCSLoggingConfig:
|
||||
|
@ -173,7 +246,7 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
)
|
||||
|
||||
bucket_name: str
|
||||
path_service_account: str
|
||||
path_service_account: Optional[str]
|
||||
if standard_callback_dynamic_params is not None:
|
||||
verbose_logger.debug("Using dynamic GCS logging")
|
||||
verbose_logger.debug(
|
||||
|
@ -193,10 +266,6 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
if _path_service_account is None:
|
||||
raise ValueError(
|
||||
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
|
||||
)
|
||||
bucket_name = _bucket_name
|
||||
path_service_account = _path_service_account
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
|
@ -208,10 +277,6 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
if self.path_service_account_json is None:
|
||||
raise ValueError(
|
||||
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
|
||||
)
|
||||
bucket_name = self.BUCKET_NAME
|
||||
path_service_account = self.path_service_account_json
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
|
@ -224,7 +289,9 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
path_service_account=path_service_account,
|
||||
)
|
||||
|
||||
async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase:
|
||||
async def get_or_create_vertex_instance(
|
||||
self, credentials: Optional[str]
|
||||
) -> VertexBase:
|
||||
"""
|
||||
This function is used to get the Vertex instance for the GCS Bucket Logger.
|
||||
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
|
||||
|
@ -233,15 +300,27 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
VertexBase,
|
||||
)
|
||||
|
||||
if credentials not in self.vertex_instances:
|
||||
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
|
||||
if _in_memory_key not in self.vertex_instances:
|
||||
vertex_instance = VertexBase()
|
||||
await vertex_instance._ensure_access_token_async(
|
||||
credentials=credentials,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
self.vertex_instances[credentials] = vertex_instance
|
||||
return self.vertex_instances[credentials]
|
||||
self.vertex_instances[_in_memory_key] = vertex_instance
|
||||
return self.vertex_instances[_in_memory_key]
|
||||
|
||||
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
|
||||
"""
|
||||
Returns key to use for caching the Vertex instance in-memory.
|
||||
|
||||
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
|
||||
|
||||
- If a credentials string is provided, it is used as the key.
|
||||
- If no credentials string is provided, "IAM_AUTH" is used as the key.
|
||||
"""
|
||||
return credentials or IAM_AUTH_KEY
|
||||
|
||||
async def download_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
|
@ -258,6 +337,11 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
|
||||
# Send the GET request to download the object
|
||||
|
@ -293,6 +377,11 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
|
||||
|
||||
# Send the DELETE request to delete the object
|
||||
|
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
|
@ -21,8 +21,8 @@ else:
|
|||
VertexBase = Any
|
||||
|
||||
|
||||
class GCSBucketBase(CustomLogger):
|
||||
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||
class GCSBucketBase(CustomBatchLogger):
|
||||
def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
|
@ -30,10 +30,11 @@ class GCSBucketBase(CustomLogger):
|
|||
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
|
||||
self.path_service_account_json: Optional[str] = _path_service_account
|
||||
self.BUCKET_NAME: Optional[str] = _bucket_name
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def construct_request_headers(
|
||||
self,
|
||||
service_account_json: str,
|
||||
service_account_json: Optional[str],
|
||||
vertex_instance: Optional[VertexBase] = None,
|
||||
) -> Dict[str, str]:
|
||||
from litellm import vertex_chat_completion
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
import types
|
||||
from collections.abc import MutableMapping, MutableSequence, MutableSet
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import BaseModel
|
||||
|
@ -14,7 +15,7 @@ from litellm._logging import verbose_logger
|
|||
from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.langfuse import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
|
||||
|
@ -355,6 +356,73 @@ class LangFuseLogger:
|
|||
)
|
||||
)
|
||||
|
||||
def is_base_type(self, value: Any) -> bool:
|
||||
# Check if the value is of a base type
|
||||
base_types = (int, float, str, bool, list, dict, tuple)
|
||||
return isinstance(value, base_types)
|
||||
|
||||
def _prepare_metadata(self, metadata: Optional[dict]) -> Any:
|
||||
try:
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
# Filter out function types from the metadata
|
||||
sanitized_metadata = {k: v for k, v in metadata.items() if not callable(v)}
|
||||
|
||||
return copy.deepcopy(sanitized_metadata)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {e}, metadata: {metadata}")
|
||||
|
||||
new_metadata: Dict[str, Any] = {}
|
||||
|
||||
# if metadata is not a MutableMapping, return an empty dict since we can't call items() on it
|
||||
if not isinstance(metadata, MutableMapping):
|
||||
verbose_logger.debug(
|
||||
"Langfuse Layer Logging - metadata is not a MutableMapping, returning empty dict"
|
||||
)
|
||||
return new_metadata
|
||||
|
||||
for key, value in metadata.items():
|
||||
try:
|
||||
if isinstance(value, MutableMapping):
|
||||
new_metadata[key] = self._prepare_metadata(cast(dict, value))
|
||||
elif isinstance(value, MutableSequence):
|
||||
# For lists or other mutable sequences
|
||||
new_metadata[key] = list(
|
||||
(
|
||||
self._prepare_metadata(cast(dict, v))
|
||||
if isinstance(v, MutableMapping)
|
||||
else copy.deepcopy(v)
|
||||
)
|
||||
for v in value
|
||||
)
|
||||
elif isinstance(value, MutableSet):
|
||||
# For sets specifically, create a new set by passing an iterable
|
||||
new_metadata[key] = set(
|
||||
(
|
||||
self._prepare_metadata(cast(dict, v))
|
||||
if isinstance(v, MutableMapping)
|
||||
else copy.deepcopy(v)
|
||||
)
|
||||
for v in value
|
||||
)
|
||||
elif isinstance(value, BaseModel):
|
||||
new_metadata[key] = value.model_dump()
|
||||
elif self.is_base_type(value):
|
||||
new_metadata[key] = value
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Layer Error - Unsupported metadata type: {type(value)} for key: {key}"
|
||||
)
|
||||
continue
|
||||
|
||||
except (TypeError, copy.Error):
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Layer Error - Couldn't copy metadata key: {key}, type of key: {type(key)}, type of value: {type(value)} - {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
return new_metadata
|
||||
|
||||
def _log_langfuse_v2( # noqa: PLR0915
|
||||
self,
|
||||
user_id,
|
||||
|
@ -373,40 +441,19 @@ class LangFuseLogger:
|
|||
) -> tuple:
|
||||
import langfuse
|
||||
|
||||
print_verbose("Langfuse Layer Logging - logging to langfuse v2")
|
||||
|
||||
try:
|
||||
tags = []
|
||||
try:
|
||||
optional_params.pop("metadata")
|
||||
metadata = copy.deepcopy(
|
||||
metadata
|
||||
) # Avoid modifying the original metadata
|
||||
except Exception:
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = copy.deepcopy(value)
|
||||
elif isinstance(value, BaseModel):
|
||||
new_metadata[key] = value.model_dump()
|
||||
metadata = new_metadata
|
||||
metadata = self._prepare_metadata(metadata)
|
||||
|
||||
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
||||
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
supports_completion_start_time = Version(
|
||||
langfuse.version.__version__
|
||||
) >= Version("2.7.3")
|
||||
langfuse_version = Version(langfuse.version.__version__)
|
||||
|
||||
print_verbose("Langfuse Layer Logging - logging to langfuse v2 ")
|
||||
supports_tags = langfuse_version >= Version("2.6.3")
|
||||
supports_prompt = langfuse_version >= Version("2.7.3")
|
||||
supports_costs = langfuse_version >= Version("2.7.3")
|
||||
supports_completion_start_time = langfuse_version >= Version("2.7.3")
|
||||
|
||||
if supports_tags:
|
||||
metadata_tags = metadata.pop("tags", [])
|
||||
tags = metadata_tags
|
||||
tags = metadata.pop("tags", []) if supports_tags else []
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
|
|
|
@ -23,34 +23,8 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class LangsmithInputs(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: Optional[List[Any]] = None
|
||||
stream: Optional[bool] = None
|
||||
call_type: Optional[str] = None
|
||||
litellm_call_id: Optional[str] = None
|
||||
completion_start_time: Optional[datetime] = None
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
custom_llm_provider: Optional[str] = None
|
||||
input: Optional[List[Any]] = None
|
||||
log_event_type: Optional[str] = None
|
||||
original_response: Optional[Any] = None
|
||||
response_cost: Optional[float] = None
|
||||
|
||||
# LiteLLM Virtual Key specific fields
|
||||
user_api_key: Optional[str] = None
|
||||
user_api_key_user_id: Optional[str] = None
|
||||
user_api_key_team_alias: Optional[str] = None
|
||||
|
||||
|
||||
class LangsmithCredentialsObject(TypedDict):
|
||||
LANGSMITH_API_KEY: str
|
||||
LANGSMITH_PROJECT: str
|
||||
LANGSMITH_BASE_URL: str
|
||||
from litellm.types.integrations.langsmith import *
|
||||
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
|
@ -93,15 +67,16 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
)
|
||||
if _batch_size:
|
||||
self.batch_size = int(_batch_size)
|
||||
self.log_queue: List[LangsmithQueueObject] = []
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
||||
|
||||
def get_credentials_from_env(
|
||||
self,
|
||||
langsmith_api_key: Optional[str],
|
||||
langsmith_project: Optional[str],
|
||||
langsmith_base_url: Optional[str],
|
||||
langsmith_api_key: Optional[str] = None,
|
||||
langsmith_project: Optional[str] = None,
|
||||
langsmith_base_url: Optional[str] = None,
|
||||
) -> LangsmithCredentialsObject:
|
||||
|
||||
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
|
||||
|
@ -132,42 +107,19 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
LANGSMITH_PROJECT=_credentials_project,
|
||||
)
|
||||
|
||||
def _prepare_log_data( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
def _prepare_log_data(
|
||||
self,
|
||||
kwargs,
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
credentials: LangsmithCredentialsObject,
|
||||
):
|
||||
import json
|
||||
from datetime import datetime as dt
|
||||
|
||||
try:
|
||||
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
metadata = _litellm_params.get("metadata", {}) or {}
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = value
|
||||
elif isinstance(value, BaseModel):
|
||||
new_metadata[key] = value.model_dump_json()
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
if isinstance(v, dt):
|
||||
value[k] = v.isoformat()
|
||||
new_metadata[key] = value
|
||||
|
||||
metadata = new_metadata
|
||||
|
||||
kwargs["user_api_key"] = metadata.get("user_api_key", None)
|
||||
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
|
||||
kwargs["user_api_key_team_alias"] = metadata.get(
|
||||
"user_api_key_team_alias", None
|
||||
)
|
||||
|
||||
project_name = metadata.get(
|
||||
"project_name", self.default_credentials["LANGSMITH_PROJECT"]
|
||||
"project_name", credentials["LANGSMITH_PROJECT"]
|
||||
)
|
||||
run_name = metadata.get("run_name", self.langsmith_default_run_name)
|
||||
run_id = metadata.get("id", None)
|
||||
|
@ -175,16 +127,10 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
trace_id = metadata.get("trace_id", None)
|
||||
session_id = metadata.get("session_id", None)
|
||||
dotted_order = metadata.get("dotted_order", None)
|
||||
tags = metadata.get("tags", []) or []
|
||||
verbose_logger.debug(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
)
|
||||
|
||||
# filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs
|
||||
# logged_kwargs = LangsmithInputs(**kwargs)
|
||||
# kwargs = logged_kwargs.model_dump()
|
||||
|
||||
# new_kwargs = {}
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
|
@ -193,7 +139,6 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
if payload is None:
|
||||
raise Exception("Error logging request payload. Payload=none.")
|
||||
|
||||
new_kwargs = payload
|
||||
metadata = payload[
|
||||
"metadata"
|
||||
] # ensure logged metadata is json serializable
|
||||
|
@ -201,12 +146,12 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
data = {
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": new_kwargs,
|
||||
"outputs": new_kwargs["response"],
|
||||
"inputs": payload,
|
||||
"outputs": payload["response"],
|
||||
"session_name": project_name,
|
||||
"start_time": new_kwargs["startTime"],
|
||||
"end_time": new_kwargs["endTime"],
|
||||
"tags": tags,
|
||||
"start_time": payload["startTime"],
|
||||
"end_time": payload["endTime"],
|
||||
"tags": payload["request_tags"],
|
||||
"extra": metadata,
|
||||
}
|
||||
|
||||
|
@ -243,37 +188,6 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
except Exception:
|
||||
raise
|
||||
|
||||
def _send_batch(self):
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
|
||||
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
|
||||
|
||||
url = f"{langsmith_api_base}/runs/batch"
|
||||
|
||||
headers = {"x-api-key": langsmith_api_key}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=url,
|
||||
json=self.log_queue,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
verbose_logger.error(
|
||||
f"Langsmith Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Batch of {len(self.log_queue)} runs successfully created"
|
||||
)
|
||||
|
||||
self.log_queue.clear()
|
||||
except Exception:
|
||||
verbose_logger.exception("Langsmith Layer Error - Error sending batch.")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
sampling_rate = (
|
||||
|
@ -295,8 +209,20 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
|
||||
)
|
||||
|
@ -323,8 +249,20 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
kwargs,
|
||||
response_obj,
|
||||
)
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
|
@ -349,8 +287,20 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
return # Skip logging
|
||||
verbose_logger.info("Langsmith Failure Event Logging!")
|
||||
try:
|
||||
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
|
||||
self.log_queue.append(data)
|
||||
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
|
||||
data = self._prepare_log_data(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
credentials=credentials,
|
||||
)
|
||||
self.log_queue.append(
|
||||
LangsmithQueueObject(
|
||||
data=data,
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Langsmith logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
|
@ -365,31 +315,58 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
sends runs to /batch endpoint
|
||||
Handles sending batches of runs to Langsmith
|
||||
|
||||
Sends runs from self.log_queue
|
||||
self.log_queue contains LangsmithQueueObjects
|
||||
Each LangsmithQueueObject has the following:
|
||||
- "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url)
|
||||
- "data" - data to log on to langsmith for the request
|
||||
|
||||
|
||||
This function
|
||||
- groups the queue objects by credentials
|
||||
- loops through each unique credentials and sends batches to Langsmith
|
||||
|
||||
|
||||
This was added to support key/team based logging on langsmith
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
batch_groups = self._group_batches_by_credentials()
|
||||
for batch_group in batch_groups.values():
|
||||
await self._log_batch_on_langsmith(
|
||||
credentials=batch_group.credentials,
|
||||
queue_objects=batch_group.queue_objects,
|
||||
)
|
||||
|
||||
async def _log_batch_on_langsmith(
|
||||
self,
|
||||
credentials: LangsmithCredentialsObject,
|
||||
queue_objects: List[LangsmithQueueObject],
|
||||
):
|
||||
"""
|
||||
Logs a batch of runs to Langsmith
|
||||
sends runs to /batch endpoint for the given credentials
|
||||
|
||||
Args:
|
||||
credentials: LangsmithCredentialsObject
|
||||
queue_objects: List[LangsmithQueueObject]
|
||||
|
||||
Returns: None
|
||||
|
||||
Raises: Does not raise an exception, will only verbose_logger.exception()
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
|
||||
|
||||
langsmith_api_base = credentials["LANGSMITH_BASE_URL"]
|
||||
langsmith_api_key = credentials["LANGSMITH_API_KEY"]
|
||||
url = f"{langsmith_api_base}/runs/batch"
|
||||
|
||||
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
|
||||
|
||||
headers = {"x-api-key": langsmith_api_key}
|
||||
elements_to_log = [queue_object["data"] for queue_object in queue_objects]
|
||||
|
||||
try:
|
||||
response = await self.async_httpx_client.post(
|
||||
url=url,
|
||||
json={
|
||||
"post": self.log_queue,
|
||||
},
|
||||
json={"post": elements_to_log},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
@ -411,6 +388,74 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
f"Langsmith Layer Error - {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]:
|
||||
"""Groups queue objects by credentials using a proper key structure"""
|
||||
log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {}
|
||||
|
||||
for queue_object in self.log_queue:
|
||||
credentials = queue_object["credentials"]
|
||||
key = CredentialsKey(
|
||||
api_key=credentials["LANGSMITH_API_KEY"],
|
||||
project=credentials["LANGSMITH_PROJECT"],
|
||||
base_url=credentials["LANGSMITH_BASE_URL"],
|
||||
)
|
||||
|
||||
if key not in log_queue_by_credentials:
|
||||
log_queue_by_credentials[key] = BatchGroup(
|
||||
credentials=credentials, queue_objects=[]
|
||||
)
|
||||
|
||||
log_queue_by_credentials[key].queue_objects.append(queue_object)
|
||||
|
||||
return log_queue_by_credentials
|
||||
|
||||
def _get_credentials_to_use_for_request(
|
||||
self, kwargs: Dict[str, Any]
|
||||
) -> LangsmithCredentialsObject:
|
||||
"""
|
||||
Handles key/team based logging
|
||||
|
||||
If standard_callback_dynamic_params are provided, use those credentials.
|
||||
|
||||
Otherwise, use the default credentials.
|
||||
"""
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
if standard_callback_dynamic_params is not None:
|
||||
credentials = self.get_credentials_from_env(
|
||||
langsmith_api_key=standard_callback_dynamic_params.get(
|
||||
"langsmith_api_key", None
|
||||
),
|
||||
langsmith_project=standard_callback_dynamic_params.get(
|
||||
"langsmith_project", None
|
||||
),
|
||||
langsmith_base_url=standard_callback_dynamic_params.get(
|
||||
"langsmith_base_url", None
|
||||
),
|
||||
)
|
||||
else:
|
||||
credentials = self.default_credentials
|
||||
return credentials
|
||||
|
||||
def _send_batch(self):
|
||||
"""Calls async_send_batch in an event loop"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
try:
|
||||
# Try to get the existing event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If we're already in an event loop, create a task
|
||||
asyncio.create_task(self.async_send_batch())
|
||||
else:
|
||||
# If no event loop is running, run the coroutine directly
|
||||
loop.run_until_complete(self.async_send_batch())
|
||||
except RuntimeError:
|
||||
# If we can't get an event loop, create a new one
|
||||
asyncio.run(self.async_send_batch())
|
||||
|
||||
def get_run_by_id(self, run_id):
|
||||
|
||||
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
|
||||
|
|
|
@ -2,20 +2,23 @@ import os
|
|||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.services import ServiceLoggerPayload
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
EmbeddingResponse,
|
||||
Function,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy._types import (
|
||||
|
@ -24,10 +27,12 @@ if TYPE_CHECKING:
|
|||
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
|
||||
|
||||
Span = _Span
|
||||
SpanExporter = _SpanExporter
|
||||
UserAPIKeyAuth = _UserAPIKeyAuth
|
||||
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
|
||||
else:
|
||||
Span = Any
|
||||
SpanExporter = Any
|
||||
UserAPIKeyAuth = Any
|
||||
ManagementEndpointLoggingPayload = Any
|
||||
|
||||
|
@ -44,7 +49,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
|
|||
|
||||
@dataclass
|
||||
class OpenTelemetryConfig:
|
||||
from opentelemetry.sdk.trace.export import SpanExporter
|
||||
|
||||
exporter: Union[str, SpanExporter] = "console"
|
||||
endpoint: Optional[str] = None
|
||||
|
@ -77,7 +81,7 @@ class OpenTelemetryConfig:
|
|||
class OpenTelemetry(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
|
||||
config: Optional[OpenTelemetryConfig] = None,
|
||||
callback_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -85,6 +89,9 @@ class OpenTelemetry(CustomLogger):
|
|||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
if config is None:
|
||||
config = OpenTelemetryConfig.from_env()
|
||||
|
||||
self.config = config
|
||||
self.OTEL_EXPORTER = self.config.exporter
|
||||
self.OTEL_ENDPOINT = self.config.endpoint
|
||||
|
@ -281,21 +288,6 @@ class OpenTelemetry(CustomLogger):
|
|||
# End Parent OTEL Span
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
):
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
if parent_otel_span is not None:
|
||||
parent_otel_span.set_status(Status(StatusCode.OK))
|
||||
# End Parent OTEL Span
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
@ -334,8 +326,8 @@ class OpenTelemetry(CustomLogger):
|
|||
|
||||
span.end(end_time=self._to_ns(end_time))
|
||||
|
||||
# if parent_otel_span is not None:
|
||||
# parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
if parent_otel_span is not None:
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
@ -413,6 +405,28 @@ class OpenTelemetry(CustomLogger):
|
|||
except Exception:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _tool_calls_kv_pair(
|
||||
tool_calls: List[ChatCompletionMessageToolCall],
|
||||
) -> Dict[str, Any]:
|
||||
from litellm.proxy._types import SpanAttributes
|
||||
|
||||
kv_pairs: Dict[str, Any] = {}
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
_function = tool_call.get("function")
|
||||
if not _function:
|
||||
continue
|
||||
|
||||
keys = Function.__annotations__.keys()
|
||||
for key in keys:
|
||||
_value = _function.get(key)
|
||||
if _value:
|
||||
kv_pairs[
|
||||
f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.{key}"
|
||||
] = _value
|
||||
|
||||
return kv_pairs
|
||||
|
||||
def set_attributes( # noqa: PLR0915
|
||||
self, span: Span, kwargs, response_obj: Optional[Any]
|
||||
):
|
||||
|
@ -607,18 +621,13 @@ class OpenTelemetry(CustomLogger):
|
|||
message = choice.get("message")
|
||||
tool_calls = message.get("tool_calls")
|
||||
if tool_calls:
|
||||
self.safe_set_attribute(
|
||||
span=span,
|
||||
key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.name",
|
||||
value=tool_calls[0].get("function").get("name"),
|
||||
)
|
||||
self.safe_set_attribute(
|
||||
span=span,
|
||||
key=f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.function_call.arguments",
|
||||
value=tool_calls[0]
|
||||
.get("function")
|
||||
.get("arguments"),
|
||||
)
|
||||
kv_pairs = OpenTelemetry._tool_calls_kv_pair(tool_calls) # type: ignore
|
||||
for key, value in kv_pairs.items():
|
||||
self.safe_set_attribute(
|
||||
span=span,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
|
@ -715,10 +724,10 @@ class OpenTelemetry(CustomLogger):
|
|||
TraceContextTextMapPropagator,
|
||||
)
|
||||
|
||||
verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
|
||||
propagator = TraceContextTextMapPropagator()
|
||||
_parent_context = propagator.extract(carrier={"traceparent": _traceparent})
|
||||
verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
|
||||
carrier = {"traceparent": _traceparent}
|
||||
_parent_context = propagator.extract(carrier=carrier)
|
||||
|
||||
return _parent_context
|
||||
|
||||
def _get_span_context(self, kwargs):
|
||||
|
|
|
@ -9,6 +9,7 @@ import subprocess
|
|||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import dotenv
|
||||
import requests # type: ignore
|
||||
|
@ -51,7 +52,9 @@ class PrometheusServicesLogger:
|
|||
for service in self.services:
|
||||
histogram = self.create_histogram(service, type_of_request="latency")
|
||||
counter_failed_request = self.create_counter(
|
||||
service, type_of_request="failed_requests"
|
||||
service,
|
||||
type_of_request="failed_requests",
|
||||
additional_labels=["error_class", "function_name"],
|
||||
)
|
||||
counter_total_requests = self.create_counter(
|
||||
service, type_of_request="total_requests"
|
||||
|
@ -99,7 +102,12 @@ class PrometheusServicesLogger:
|
|||
buckets=LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
def create_counter(self, service: str, type_of_request: str):
|
||||
def create_counter(
|
||||
self,
|
||||
service: str,
|
||||
type_of_request: str,
|
||||
additional_labels: Optional[List[str]] = None,
|
||||
):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
|
@ -107,7 +115,7 @@ class PrometheusServicesLogger:
|
|||
return self.Counter(
|
||||
metric_name,
|
||||
"Total {} for {} service".format(type_of_request, service),
|
||||
labelnames=[service],
|
||||
labelnames=[service] + (additional_labels or []),
|
||||
)
|
||||
|
||||
def observe_histogram(
|
||||
|
@ -125,10 +133,14 @@ class PrometheusServicesLogger:
|
|||
counter,
|
||||
labels: str,
|
||||
amount: float,
|
||||
additional_labels: Optional[List[str]] = [],
|
||||
):
|
||||
assert isinstance(counter, self.Counter)
|
||||
|
||||
counter.labels(labels).inc(amount)
|
||||
if additional_labels:
|
||||
counter.labels(labels, *additional_labels).inc(amount)
|
||||
else:
|
||||
counter.labels(labels).inc(amount)
|
||||
|
||||
def service_success_hook(self, payload: ServiceLoggerPayload):
|
||||
if self.mock_testing:
|
||||
|
@ -187,16 +199,25 @@ class PrometheusServicesLogger:
|
|||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Union[str, Exception],
|
||||
):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
error_class = error.__class__.__name__
|
||||
function_name = payload.call_type
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
# increment both failed and total requests
|
||||
if isinstance(obj, self.Counter):
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
# log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
|
||||
additional_labels=[error_class, function_name],
|
||||
amount=1, # LOG ERROR COUNT TO PROMETHEUS
|
||||
)
|
||||
|
|
11
litellm/litellm_core_utils/README.md
Normal file
|
@ -0,0 +1,11 @@
|
|||
## Folder Contents
|
||||
|
||||
This folder contains general-purpose utilities that are used in multiple places in the codebase.
|
||||
|
||||
Core files:
|
||||
- `streaming_handler.py`: The core streaming logic + streaming related helper utils
|
||||
- `core_helpers.py`: code used in `types/` - e.g. `map_finish_reason`.
|
||||
- `exception_mapping_utils.py`: utils for mapping exceptions to openai-compatible error types.
|
||||
- `default_encoding.py`: code for loading the default encoding (tiktoken)
|
||||
- `get_llm_provider_logic.py`: code for inferring the LLM provider from a given model name.
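
A minimal usage sketch for one of the helpers listed above, assuming `get_llm_provider` is reachable via the top-level `litellm` package (the rest of this change set calls it as `litellm.get_llm_provider` and reads the provider name at index `[1]`):

```python
import litellm

# index [1] is the inferred provider name, mirroring how
# get_supported_openai_params uses the return value in this change set
provider = litellm.get_llm_provider(model="claude-3-5-sonnet-20240620")[1]
print(provider)  # expected: "anthropic"
```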
|
||||
|
|
@ -3,6 +3,8 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -99,3 +101,28 @@ def _get_parent_otel_span_from_kwargs(
|
|||
"Error in _get_parent_otel_span_from_kwargs: " + str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> dict:
|
||||
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
|
||||
|
||||
openai_headers = {}
|
||||
processed_headers = {}
|
||||
additional_headers = {}
|
||||
|
||||
for k, v in response_headers.items():
|
||||
if k in OPENAI_RESPONSE_HEADERS: # return openai-compatible headers
|
||||
openai_headers[k] = v
|
||||
if k.startswith(
|
||||
"llm_provider-"
|
||||
): # return raw provider headers (incl. openai-compatible ones)
|
||||
processed_headers[k] = v
|
||||
else:
|
||||
additional_headers["{}-{}".format("llm_provider", k)] = v
|
||||
|
||||
additional_headers = {
|
||||
**openai_headers,
|
||||
**processed_headers,
|
||||
**additional_headers,
|
||||
}
|
||||
return additional_headers
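
A hedged illustration of the merge above. The import path and the membership of `x-ratelimit-remaining-requests` in `OPENAI_RESPONSE_HEADERS` are assumptions; only the three branching rules come from the code itself:

```python
# Assumed module path - this hunk appears to modify litellm/litellm_core_utils/core_helpers.py
from litellm.litellm_core_utils.core_helpers import process_response_headers

raw = {
    "x-ratelimit-remaining-requests": "99",  # assumed to be in OPENAI_RESPONSE_HEADERS
    "llm_provider-x-request-id": "abc123",   # already provider-prefixed, kept as-is
    "server": "uvicorn",                     # anything else gets an llm_provider- prefix
}
merged = process_response_headers(raw)
# expected keys: "x-ratelimit-remaining-requests" (openai-compatible copy),
# "llm_provider-x-ratelimit-remaining-requests" (prefixed copy, since the raw key
# does not start with llm_provider-), "llm_provider-x-request-id", "llm_provider-server"
```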
|
||||
|
|
21
litellm/litellm_core_utils/default_encoding.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
|
||||
import litellm
|
||||
|
||||
try:
|
||||
# New and recommended way to access resources
|
||||
from importlib import resources
|
||||
|
||||
filename = str(resources.files(litellm).joinpath("llms/tokenizers"))
|
||||
except (ImportError, AttributeError):
|
||||
# Old way to access resources, which setuptools deprecated some time ago
|
||||
import pkg_resources # type: ignore
|
||||
|
||||
filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
|
||||
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
|
||||
"CUSTOM_TIKTOKEN_CACHE_DIR", filename
|
||||
) # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
|
||||
import tiktoken
|
||||
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
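
A brief usage sketch of the new module (import path taken from the file header above; the count itself is illustrative):

```python
from litellm.litellm_core_utils.default_encoding import encoding

# cl100k_base tokenization, using the locally cached tiktoken files set up above
num_tokens = len(encoding.encode("hello world"))
```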
|
|
@ -612,19 +612,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
url="https://api.replicate.com/v1/deployments",
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
if "token_quota_reached" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise RateLimitError(
|
||||
message=f"WatsonxException: Rate Limit Errror - {error_str}",
|
||||
llm_provider="watsonx",
|
||||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif (
|
||||
custom_llm_provider == "predibase"
|
||||
or custom_llm_provider == "databricks"
|
||||
):
|
||||
elif custom_llm_provider in litellm._openai_like_providers:
|
||||
if "authorization denied for" in error_str:
|
||||
exception_mapping_worked = True
|
||||
|
||||
|
@ -646,6 +634,24 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
response=original_exception.response,
|
||||
litellm_debug_info=extra_information,
|
||||
)
|
||||
elif "token_quota_reached" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise RateLimitError(
|
||||
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
|
||||
llm_provider=custom_llm_provider,
|
||||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif (
|
||||
"The server received an invalid response from an upstream server."
|
||||
in error_str
|
||||
):
|
||||
exception_mapping_worked = True
|
||||
raise litellm.InternalServerError(
|
||||
message=f"{custom_llm_provider}Exception - {original_exception.message}",
|
||||
llm_provider=custom_llm_provider,
|
||||
model=model,
|
||||
)
|
||||
elif hasattr(original_exception, "status_code"):
|
||||
if original_exception.status_code == 500:
|
||||
exception_mapping_worked = True
|
||||
|
|
|
@ -226,7 +226,7 @@ def get_llm_provider( # noqa: PLR0915
|
|||
## openrouter
|
||||
elif model in litellm.openrouter_models:
|
||||
custom_llm_provider = "openrouter"
|
||||
## openrouter
|
||||
## maritalk
|
||||
elif model in litellm.maritalk_models:
|
||||
custom_llm_provider = "maritalk"
|
||||
## vertex - text + chat + language (gemini) models
|
||||
|
|
288
litellm/litellm_core_utils/get_supported_openai_params.py
Normal file
|
@ -0,0 +1,288 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
|
||||
|
||||
def get_supported_openai_params( # noqa: PLR0915
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
|
||||
) -> Optional[list]:
|
||||
"""
|
||||
Returns the supported openai params for a given model + provider
|
||||
|
||||
Example:
|
||||
```
|
||||
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
|
||||
```
|
||||
|
||||
Returns:
|
||||
- List if custom_llm_provider is mapped
|
||||
- None if unmapped
|
||||
"""
|
||||
if not custom_llm_provider:
|
||||
try:
|
||||
custom_llm_provider = litellm.get_llm_provider(model=model)[1]
|
||||
except BadRequestError:
|
||||
return None
|
||||
if custom_llm_provider == "bedrock":
|
||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ollama":
|
||||
return litellm.OllamaConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "ollama_chat":
|
||||
return litellm.OllamaChatConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "anthropic":
|
||||
return litellm.AnthropicConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
if request_type == "embeddings":
|
||||
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.FireworksAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "nvidia_nim":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
|
||||
elif request_type == "embeddings":
|
||||
return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
|
||||
elif custom_llm_provider == "cerebras":
|
||||
return litellm.CerebrasConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "xai":
|
||||
return litellm.XAIChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ai21_chat":
|
||||
return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "volcengine":
|
||||
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "groq":
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "deepseek":
|
||||
return [
|
||||
# https://platform.deepseek.com/api-docs/api/create-chat-completion
|
||||
"frequency_penalty",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
"stop",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "cohere":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "cohere_chat":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "maritalk":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "openai":
|
||||
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "azure":
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "openrouter":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"repetition_penalty",
|
||||
"seed",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
||||
# mistral and codestral APIs have the exact same params
|
||||
if request_type == "chat_completion":
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "text-completion-codestral":
|
||||
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "replicate":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"functions",
|
||||
"function_call",
|
||||
]
|
||||
elif custom_llm_provider == "huggingface":
|
||||
return litellm.HuggingfaceConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "together_ai":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
]
|
||||
elif custom_llm_provider == "ai21":
|
||||
return [
|
||||
"stream",
|
||||
"n",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "databricks":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.DatabricksConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
|
||||
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
if request_type == "chat_completion":
|
||||
if model.startswith("meta/"):
|
||||
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
||||
if model.startswith("mistral"):
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
if model.startswith("codestral"):
|
||||
return (
|
||||
litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
)
|
||||
if model.startswith("claude"):
|
||||
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai_beta":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.VertexGeminiConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
elif custom_llm_provider == "aleph_alpha":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"top_p",
|
||||
"temperature",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "cloudflare":
|
||||
return ["max_tokens", "stream"]
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "petals":
|
||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||
elif custom_llm_provider == "deepinfra":
|
||||
return litellm.DeepInfraConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "perplexity":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "anyscale":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "watsonx":
|
||||
return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "custom_openai" or "text-completion-openai":
|
||||
return [
|
||||
"functions",
|
||||
"function_call",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers",
|
||||
]
|
||||
return None
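
A hedged usage sketch for the new helper (import path taken from the file header above; the filtering pattern is an illustration, not litellm's own drop_params logic):

```python
from litellm.litellm_core_utils.get_supported_openai_params import (
    get_supported_openai_params,
)

supported = get_supported_openai_params(
    model="command-r", custom_llm_provider="cohere_chat"
)
requested = {"temperature": 0.2, "logit_bias": {"50256": -100}}
# keep only the params cohere_chat understands - logit_bias is dropped here
safe_kwargs = {k: v for k, v in requested.items() if supported and k in supported}
```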
|
|
@ -2474,6 +2474,14 @@ class StandardLoggingPayloadSetup:
|
|||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Convert datetime objects to floats
|
||||
|
||||
Args:
|
||||
start_time: Union[dt_object, float]
|
||||
end_time: Union[dt_object, float]
|
||||
completion_start_time: Union[dt_object, float]
|
||||
|
||||
Returns:
|
||||
Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats.
|
||||
"""
|
||||
|
||||
if isinstance(start_time, datetime.datetime):
|
||||
|
@ -2534,13 +2542,10 @@ class StandardLoggingPayloadSetup:
|
|||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
clean_metadata = StandardLoggingMetadata(
|
||||
**{ # type: ignore
|
||||
key: metadata[key]
|
||||
for key in StandardLoggingMetadata.__annotations__.keys()
|
||||
if key in metadata
|
||||
}
|
||||
)
|
||||
supported_keys = StandardLoggingMetadata.__annotations__.keys()
|
||||
for key in supported_keys:
|
||||
if key in metadata:
|
||||
clean_metadata[key] = metadata[key] # type: ignore
|
||||
|
||||
if metadata.get("user_api_key") is not None:
|
||||
if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
|
||||
|
@ -2769,11 +2774,6 @@ def get_standard_logging_object_payload(
|
|||
metadata=metadata
|
||||
)
|
||||
|
||||
if litellm.cache is not None:
|
||||
cache_key = litellm.cache.get_cache_key(**kwargs)
|
||||
else:
|
||||
cache_key = None
|
||||
|
||||
saved_cache_cost: float = 0.0
|
||||
if cache_hit is True:
|
||||
|
||||
|
@ -2815,7 +2815,7 @@ def get_standard_logging_object_payload(
|
|||
completionStartTime=completion_start_time_float,
|
||||
model=kwargs.get("model", "") or "",
|
||||
metadata=clean_metadata,
|
||||
cache_key=cache_key,
|
||||
cache_key=clean_hidden_params["cache_key"],
|
||||
response_cost=response_cost,
|
||||
total_tokens=usage.total_tokens,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
|
|
|
@ -14,11 +14,17 @@ from litellm.types.utils import (
|
|||
Delta,
|
||||
EmbeddingResponse,
|
||||
Function,
|
||||
HiddenParams,
|
||||
ImageResponse,
|
||||
)
|
||||
from litellm.types.utils import Logprobs as TextCompletionLogprobs
|
||||
from litellm.types.utils import (
|
||||
Message,
|
||||
ModelResponse,
|
||||
RerankResponse,
|
||||
StreamingChoices,
|
||||
TextChoices,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
Usage,
|
||||
)
|
||||
|
@ -235,6 +241,77 @@ class LiteLLMResponseObjectHandler:
|
|||
model_response_object = ImageResponse(**model_response_dict)
|
||||
return model_response_object
|
||||
|
||||
@staticmethod
|
||||
def convert_chat_to_text_completion(
|
||||
response: ModelResponse,
|
||||
text_completion_response: TextCompletionResponse,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> TextCompletionResponse:
|
||||
"""
|
||||
Converts a chat completion response to a text completion response format.
|
||||
|
||||
Note: This is used for Hugging Face. For OpenAI / Azure text completions, the provider files directly return a TextCompletionResponse, which is then sent to the user.
|
||||
|
||||
Args:
|
||||
response (ModelResponse): The chat completion response to convert
|
||||
|
||||
Returns:
|
||||
TextCompletionResponse: The converted text completion response
|
||||
|
||||
Example:
|
||||
chat_response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}])
|
||||
text_response = convert_chat_to_text_completion(chat_response)
|
||||
"""
|
||||
transformed_logprobs = LiteLLMResponseObjectHandler._convert_provider_response_logprobs_to_text_completion_logprobs(
|
||||
response=response,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
text_completion_response["id"] = response.get("id", None)
|
||||
text_completion_response["object"] = "text_completion"
|
||||
text_completion_response["created"] = response.get("created", None)
|
||||
text_completion_response["model"] = response.get("model", None)
|
||||
choices_list: List[TextChoices] = []
|
||||
|
||||
# Convert each choice to TextChoices
|
||||
for choice in response["choices"]:
|
||||
text_choices = TextChoices()
|
||||
text_choices["text"] = choice["message"]["content"]
|
||||
text_choices["index"] = choice["index"]
|
||||
text_choices["logprobs"] = transformed_logprobs
|
||||
text_choices["finish_reason"] = choice["finish_reason"]
|
||||
choices_list.append(text_choices)
|
||||
|
||||
text_completion_response["choices"] = choices_list
|
||||
text_completion_response["usage"] = response.get("usage", None)
|
||||
text_completion_response._hidden_params = HiddenParams(
|
||||
**response._hidden_params
|
||||
)
|
||||
return text_completion_response
|
||||
|
||||
@staticmethod
|
||||
def _convert_provider_response_logprobs_to_text_completion_logprobs(
|
||||
response: ModelResponse,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> Optional[TextCompletionLogprobs]:
|
||||
"""
|
||||
Convert logprobs from provider to OpenAI.Completion() format
|
||||
|
||||
Only supported for HF TGI models
|
||||
"""
|
||||
transformed_logprobs: Optional[TextCompletionLogprobs] = None
|
||||
if custom_llm_provider == "huggingface":
|
||||
# only supported for TGI models
|
||||
try:
|
||||
raw_response = response._hidden_params.get("original_response", None)
|
||||
transformed_logprobs = litellm.huggingface._transform_logprobs(
|
||||
hf_response=raw_response
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
|
||||
|
||||
return transformed_logprobs
|
||||
|
||||
|
||||
def convert_to_model_response_object( # noqa: PLR0915
|
||||
response_object: Optional[dict] = None,
|
||||
|
|
50
litellm/litellm_core_utils/rules.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class Rules:
|
||||
"""
|
||||
Fail calls based on the input or llm api output
|
||||
|
||||
Example usage:
|
||||
import litellm
|
||||
def my_custom_rule(input): # receives the model response
|
||||
if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer
|
||||
return False
|
||||
return True
|
||||
|
||||
litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call
|
||||
|
||||
response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user",
|
||||
"content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def pre_call_rules(self, input: str, model: str):
|
||||
for rule in litellm.pre_call_rules:
|
||||
if callable(rule):
|
||||
decision = rule(input)
|
||||
if decision is False:
|
||||
raise litellm.APIResponseValidationError(message="LLM Request failed pre-call-rule check", llm_provider="", model=model) # type: ignore
|
||||
return True
|
||||
|
||||
def post_call_rules(self, input: Optional[str], model: str) -> bool:
|
||||
if input is None:
|
||||
return True
|
||||
for rule in litellm.post_call_rules:
|
||||
if callable(rule):
|
||||
decision = rule(input)
|
||||
if isinstance(decision, bool):
|
||||
if decision is False:
|
||||
raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore
|
||||
elif isinstance(decision, dict):
|
||||
decision_val = decision.get("decision", True)
|
||||
decision_message = decision.get(
|
||||
"message", "LLM Response failed post-call-rule check"
|
||||
)
|
||||
if decision_val is False:
|
||||
raise litellm.APIResponseValidationError(message=decision_message, llm_provider="", model=model) # type: ignore
|
||||
return True
|
|
@ -243,6 +243,49 @@ class ChunkProcessor:
|
|||
id=id,
|
||||
)
|
||||
|
||||
def _usage_chunk_calculation_helper(self, usage_chunk: Usage) -> dict:
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
## anthropic prompt caching information ##
|
||||
cache_creation_input_tokens: Optional[int] = None
|
||||
cache_read_input_tokens: Optional[int] = None
|
||||
completion_tokens_details: Optional[CompletionTokensDetails] = None
|
||||
prompt_tokens_details: Optional[PromptTokensDetails] = None
|
||||
|
||||
if "prompt_tokens" in usage_chunk:
|
||||
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
|
||||
if "completion_tokens" in usage_chunk:
|
||||
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
|
||||
if "cache_creation_input_tokens" in usage_chunk:
|
||||
cache_creation_input_tokens = usage_chunk.get("cache_creation_input_tokens")
|
||||
if "cache_read_input_tokens" in usage_chunk:
|
||||
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
|
||||
if hasattr(usage_chunk, "completion_tokens_details"):
|
||||
if isinstance(usage_chunk.completion_tokens_details, dict):
|
||||
completion_tokens_details = CompletionTokensDetails(
|
||||
**usage_chunk.completion_tokens_details
|
||||
)
|
||||
elif isinstance(
|
||||
usage_chunk.completion_tokens_details, CompletionTokensDetails
|
||||
):
|
||||
completion_tokens_details = usage_chunk.completion_tokens_details
|
||||
if hasattr(usage_chunk, "prompt_tokens_details"):
|
||||
if isinstance(usage_chunk.prompt_tokens_details, dict):
|
||||
prompt_tokens_details = PromptTokensDetails(
|
||||
**usage_chunk.prompt_tokens_details
|
||||
)
|
||||
elif isinstance(usage_chunk.prompt_tokens_details, PromptTokensDetails):
|
||||
prompt_tokens_details = usage_chunk.prompt_tokens_details
|
||||
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
"completion_tokens_details": completion_tokens_details,
|
||||
"prompt_tokens_details": prompt_tokens_details,
|
||||
}
|
||||
|
||||
def calculate_usage(
|
||||
self,
|
||||
chunks: List[Union[Dict[str, Any], ModelResponse]],
|
||||
|
@ -269,37 +312,30 @@ class ChunkProcessor:
|
|||
elif isinstance(chunk, ModelResponse) and hasattr(chunk, "_hidden_params"):
|
||||
usage_chunk = chunk._hidden_params.get("usage", None)
|
||||
if usage_chunk is not None:
|
||||
if "prompt_tokens" in usage_chunk:
|
||||
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
|
||||
if "completion_tokens" in usage_chunk:
|
||||
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
|
||||
if "cache_creation_input_tokens" in usage_chunk:
|
||||
cache_creation_input_tokens = usage_chunk.get(
|
||||
usage_chunk_dict = self._usage_chunk_calculation_helper(usage_chunk)
|
||||
if (
|
||||
usage_chunk_dict["prompt_tokens"] is not None
|
||||
and usage_chunk_dict["prompt_tokens"] > 0
|
||||
):
|
||||
prompt_tokens = usage_chunk_dict["prompt_tokens"]
|
||||
if (
|
||||
usage_chunk_dict["completion_tokens"] is not None
|
||||
and usage_chunk_dict["completion_tokens"] > 0
|
||||
):
|
||||
completion_tokens = usage_chunk_dict["completion_tokens"]
|
||||
if usage_chunk_dict["cache_creation_input_tokens"] is not None:
|
||||
cache_creation_input_tokens = usage_chunk_dict[
|
||||
"cache_creation_input_tokens"
|
||||
)
|
||||
if "cache_read_input_tokens" in usage_chunk:
|
||||
cache_read_input_tokens = usage_chunk.get("cache_read_input_tokens")
|
||||
if hasattr(usage_chunk, "completion_tokens_details"):
|
||||
if isinstance(usage_chunk.completion_tokens_details, dict):
|
||||
completion_tokens_details = CompletionTokensDetails(
|
||||
**usage_chunk.completion_tokens_details
|
||||
)
|
||||
elif isinstance(
|
||||
usage_chunk.completion_tokens_details, CompletionTokensDetails
|
||||
):
|
||||
completion_tokens_details = (
|
||||
usage_chunk.completion_tokens_details
|
||||
)
|
||||
if hasattr(usage_chunk, "prompt_tokens_details"):
|
||||
if isinstance(usage_chunk.prompt_tokens_details, dict):
|
||||
prompt_tokens_details = PromptTokensDetails(
|
||||
**usage_chunk.prompt_tokens_details
|
||||
)
|
||||
elif isinstance(
|
||||
usage_chunk.prompt_tokens_details, PromptTokensDetails
|
||||
):
|
||||
prompt_tokens_details = usage_chunk.prompt_tokens_details
|
||||
|
||||
]
|
||||
if usage_chunk_dict["cache_read_input_tokens"] is not None:
|
||||
cache_read_input_tokens = usage_chunk_dict[
|
||||
"cache_read_input_tokens"
|
||||
]
|
||||
if usage_chunk_dict["completion_tokens_details"] is not None:
|
||||
completion_tokens_details = usage_chunk_dict[
|
||||
"completion_tokens_details"
|
||||
]
|
||||
prompt_tokens_details = usage_chunk_dict["prompt_tokens_details"]
|
||||
try:
|
||||
returned_usage.prompt_tokens = prompt_tokens or token_counter(
|
||||
model=model, messages=messages
|
||||
|
|
2020
litellm/litellm_core_utils/streaming_handler.py
Normal file
File diff suppressed because it is too large
|
@ -1,14 +0,0 @@
|
|||
from litellm.types.utils import GenericStreamingChunk as GChunk
|
||||
|
||||
|
||||
def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
|
||||
"""
|
||||
Checks if the provided chunk dictionary contains all required fields for GenericStreamingChunk.
|
||||
|
||||
:param chunk: The dictionary to check.
|
||||
:return: True if all required fields are present, False otherwise.
|
||||
"""
|
||||
_all_fields = GChunk.__annotations__
|
||||
|
||||
decision = all(key in _all_fields for key in chunk)
|
||||
return decision
|
|
@ -3,7 +3,7 @@ Support for gpt model family
|
|||
"""
|
||||
|
||||
import types
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
|
@ -94,6 +94,7 @@ class OpenAIGPTConfig:
|
|||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
|
@ -162,3 +163,8 @@ class OpenAIGPTConfig:
|
|||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
return messages
|
||||
|
|
|
@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
|
|||
return True
|
||||
return False
|
||||
|
||||
def o1_prompt_factory(self, messages: List[AllMessageValues]):
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Handles limitations of O-1 model family.
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
|
|
|
@ -15,6 +15,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import overload, override
|
||||
|
||||
import litellm
|
||||
from litellm import LlmProviders
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
@ -24,6 +25,7 @@ from litellm.utils import (
|
|||
CustomStreamWrapper,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ProviderConfigManager,
|
||||
TextCompletionResponse,
|
||||
Usage,
|
||||
convert_to_model_response_object,
|
||||
|
@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
messages=messages,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
if (
|
||||
litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
|
||||
and messages is not None
|
||||
):
|
||||
messages = litellm.openAIO1Config.o1_prompt_factory(
|
||||
messages=messages,
|
||||
if messages is not None and custom_llm_provider is not None:
|
||||
provider_config = ProviderConfigManager.get_provider_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider)
|
||||
)
|
||||
messages = provider_config._transform_messages(messages)
|
||||
|
||||
for _ in range(
|
||||
2
|
||||
|
|
|
@ -71,11 +71,12 @@ def validate_environment(
|
|||
|
||||
prompt_caching_set = AnthropicConfig().is_cache_control_set(messages=messages)
|
||||
computer_tool_used = AnthropicConfig().is_computer_tool_used(tools=tools)
|
||||
|
||||
pdf_used = AnthropicConfig().is_pdf_used(messages=messages)
|
||||
headers = AnthropicConfig().get_anthropic_headers(
|
||||
anthropic_version=anthropic_version,
|
||||
computer_tool_used=computer_tool_used,
|
||||
prompt_caching_set=prompt_caching_set,
|
||||
pdf_used=pdf_used,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
@ -769,6 +770,7 @@ class ModelResponseIterator:
|
|||
message=message,
|
||||
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
|
||||
)
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
|
|
|
@ -104,6 +104,7 @@ class AnthropicConfig:
|
|||
anthropic_version: Optional[str] = None,
|
||||
computer_tool_used: bool = False,
|
||||
prompt_caching_set: bool = False,
|
||||
pdf_used: bool = False,
|
||||
) -> dict:
|
||||
import json
|
||||
|
||||
|
@ -112,6 +113,8 @@ class AnthropicConfig:
|
|||
betas.append("prompt-caching-2024-07-31")
|
||||
if computer_tool_used:
|
||||
betas.append("computer-use-2024-10-22")
|
||||
if pdf_used:
|
||||
betas.append("pdfs-2024-09-25")
|
||||
headers = {
|
||||
"anthropic-version": anthropic_version or "2023-06-01",
|
||||
"x-api-key": api_key,
|
||||
|
@ -365,6 +368,21 @@ class AnthropicConfig:
|
|||
return True
|
||||
return False
|
||||
|
||||
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
|
||||
"""
|
||||
Set to True if media (e.g. a PDF) is passed into the messages.
|
||||
"""
|
||||
for message in messages:
|
||||
if (
|
||||
"content" in message
|
||||
and message["content"] is not None
|
||||
and isinstance(message["content"], list)
|
||||
):
|
||||
for content in message["content"]:
|
||||
if "type" in content:
|
||||
return True
|
||||
return False
|
||||
|
||||
def translate_system_message(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AnthropicSystemMessageContent]:
|
||||
|
|
|
@ -1,16 +1,28 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache, InMemoryCache
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.credentials import Credentials
|
||||
else:
|
||||
Credentials = Any
|
||||
|
||||
|
||||
class Boto3CredentialsInfo(BaseModel):
|
||||
credentials: Credentials
|
||||
aws_region_name: str
|
||||
aws_bedrock_runtime_endpoint: Optional[str]
|
||||
|
||||
|
||||
class AwsAuthError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
@ -311,3 +323,74 @@ class BaseAWSLLM(BaseLLM):
|
|||
proxy_endpoint_url = endpoint_url
|
||||
|
||||
return endpoint_url, proxy_endpoint_url
|
||||
|
||||
def _get_boto_credentials_from_optional_params(
|
||||
self, optional_params: dict
|
||||
) -> Boto3CredentialsInfo:
|
||||
"""
|
||||
Get boto3 credentials from optional params
|
||||
|
||||
Args:
|
||||
optional_params (dict): Optional parameters for the model call
|
||||
|
||||
Returns:
|
||||
Boto3CredentialsInfo: botocore Credentials, the resolved AWS region name, and the optional Bedrock runtime endpoint
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret_str("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
return Boto3CredentialsInfo(
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ from ..common_utils import BedrockError
|
|||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||
|
||||
BEDROCK_CONVERSE_MODELS = [
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
|
|
|
@ -484,73 +484,6 @@ class AmazonMistralConfig:
|
|||
}
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
|
||||
|
|
104
litellm/llms/bedrock/image/amazon_stability1_transformation.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
return ["size"]
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(
|
||||
cls,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
_size = non_default_params.get("size")
|
||||
if _size is not None:
|
||||
width, height = _size.split("x")
|
||||
optional_params["width"] = int(width)
|
||||
optional_params["height"] = int(height)
|
||||
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_dict["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
|
||||
return model_response
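
A small sketch of the OpenAI-style `size` mapping handled by `map_openai_params` above (import path taken from the file header; the values are illustrative):

```python
from litellm.llms.bedrock.image.amazon_stability1_transformation import (
    AmazonStabilityConfig,
)

params = AmazonStabilityConfig.map_openai_params(
    non_default_params={"size": "512x768"},  # OpenAI-style WxH string
    optional_params={},
)
# expected: {"width": 512, "height": 768}
```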
|
|
@ -0,0 +1,94 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonStability3TextToImageRequest,
|
||||
AmazonStability3TextToImageResponse,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStability3Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
"""
|
||||
No additional OpenAI params are mapped for stability 3
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Stability 3 model
|
||||
|
||||
Stability 3 models follow this pattern:
|
||||
sd3-large
|
||||
sd3-large-turbo
|
||||
sd3-medium
|
||||
sd3.5-large
|
||||
sd3.5-large-turbo
|
||||
"""
|
||||
if model and ("sd3" in model or "sd3.5" in model):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls, prompt: str, optional_params: dict
|
||||
) -> AmazonStability3TextToImageRequest:
|
||||
"""
|
||||
Transform the request body for the Stability 3 models
|
||||
"""
|
||||
data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Bedrock params
|
||||
|
||||
No OpenAI params are mapped for Stability 3, so directly return the optional_params
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the response dict to the OpenAI response
|
||||
"""
|
||||
|
||||
stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
|
||||
openai_images: List[Image] = []
|
||||
for _img in stability_3_response.get("images", []):
|
||||
openai_images.append(Image(b64_json=_img))
|
||||
|
||||
model_response.data = openai_images
|
||||
return model_response
|
41
litellm/llms/bedrock/image/cost_calculator.py
Normal file
@@ -0,0 +1,41 @@
from typing import Optional

import litellm
from litellm.types.utils import ImageResponse


def cost_calculator(
    model: str,
    image_response: ImageResponse,
    size: Optional[str] = None,
    optional_params: Optional[dict] = None,
) -> float:
    """
    Bedrock image generation cost calculator

    Handles both Stability 1 and Stability 3 models
    """
    if litellm.AmazonStability3Config()._is_stability_3_model(model=model):
        pass
    else:
        # Stability 1 models
        optional_params = optional_params or {}

        # see model_prices_and_context_window.json for details on how steps is used
        # Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
        _steps = optional_params.get("steps", 50)
        steps = "max-steps" if _steps > 50 else "50-steps"

        # size is stored in model_prices_and_context_window.json as 1024-x-1024
        # current size has 1024x1024
        size = size or "1024-x-1024"
        model = f"{size}/{steps}/{model}"

    _model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="bedrock",
    )

    output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
    num_images: int = len(image_response.data)
    return output_cost_per_image * num_images
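A small worked example of the Stability 1 pricing-key construction above; the per-image price is a hypothetical placeholder, not a value looked up from model_prices_and_context_window.json.

# Sketch of the Stability 1 branch; the price is a made-up placeholder.
size = "1024-x-1024"
_steps = 75                                            # e.g. steps=75 passed by the caller
steps = "max-steps" if _steps > 50 else "50-steps"     # -> "max-steps"
pricing_key = f"{size}/{steps}/stability.stable-diffusion-xl-v1"

output_cost_per_image = 0.08                           # hypothetical $/image for this key
num_images = 2
print(pricing_key)                                     # 1024-x-1024/max-steps/stability.stable-diffusion-xl-v1
print(output_cost_per_image * num_images)              # 0.16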
304  litellm/llms/bedrock/image/image_handler.py  Normal file
|
@@ -0,0 +1,304 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ...base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockImagePreparedRequest(BaseModel):
|
||||
"""
|
||||
Internal/Helper class for preparing the request for bedrock image generation
|
||||
"""
|
||||
|
||||
endpoint_url: str
|
||||
prepped: AWSPreparedRequest
|
||||
body: bytes
|
||||
data: dict
|
||||
|
||||
|
||||
class BedrockImageGeneration(BaseAWSLLM):
|
||||
"""
|
||||
Bedrock Image Generation handler
|
||||
"""
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: LitellmLogging,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
aimg_generation: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if aimg_generation is True:
|
||||
return self.async_image_generation(
|
||||
prepared_request=prepared_request,
|
||||
timeout=timeout,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
)
|
||||
return model_response
|
||||
|
||||
async def async_image_generation(
|
||||
self,
|
||||
prepared_request: BedrockImagePreparedRequest,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Asynchronous handler for bedrock image generation
|
||||
|
||||
Awaits the response from the bedrock image generation endpoint
|
||||
"""
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
model_response=model_response,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
) -> BedrockImagePreparedRequest:
|
||||
"""
|
||||
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
|
||||
|
||||
Args:
|
||||
model (str): The model to use for the image generation
|
||||
optional_params (dict): The optional parameters for the image generation
|
||||
api_base (Optional[str]): The base URL for the Bedrock API
|
||||
extra_headers (Optional[dict]): The extra headers to include in the request
|
||||
logging_obj (LitellmLogging): The logging object to use for logging
|
||||
prompt (str): The prompt to use for the image generation
|
||||
Returns:
|
||||
BedrockImagePreparedRequest: The prepared request object
|
||||
|
||||
The BedrockImagePreparedRequest contains:
|
||||
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
|
||||
prepped (httpx.Request): The prepared request object
|
||||
body (bytes): The request body
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
modelId = model
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
|
||||
sigv4 = SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"bedrock",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
|
||||
data = self._get_request_body(
|
||||
model=model, prompt=prompt, optional_params=optional_params
|
||||
)
|
||||
|
||||
# Make POST Request
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=proxy_endpoint_url, data=body, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": proxy_endpoint_url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
return BedrockImagePreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def _get_request_body(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the request body for the Bedrock Image Generation API
|
||||
|
||||
Checks the model/provider and transforms the request body accordingly
|
||||
|
||||
Returns:
|
||||
dict: The request body to use for the Bedrock Image Generation API
|
||||
"""
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
data = {}
|
||||
if provider == "stability":
|
||||
if litellm.AmazonStability3Config._is_stability_3_model(model):
|
||||
request_body = litellm.AmazonStability3Config.transform_request_body(
|
||||
prompt=prompt, optional_params=optional_params
|
||||
)
|
||||
return dict(request_body)
|
||||
else:
|
||||
prompt = prompt.replace(os.linesep, " ")
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonStabilityConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = {
|
||||
"text_prompts": [{"text": prompt, "weight": 1}],
|
||||
**inference_params,
|
||||
}
|
||||
else:
|
||||
raise BedrockError(
|
||||
status_code=422, message=f"Unsupported model={model}, passed in"
|
||||
)
|
||||
return data
|
||||
|
||||
def _transform_response_dict_to_openai_response(
|
||||
self,
|
||||
model_response: ImageResponse,
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
response: httpx.Response,
|
||||
data: dict,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transforms the Image Generation response from Bedrock to OpenAI format
|
||||
"""
|
||||
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
verbose_logger.debug("raw model_response: %s", response.text)
|
||||
response_dict = response.json()
|
||||
if response_dict is None:
|
||||
raise ValueError("Error in response object format, got None")
|
||||
|
||||
config_class = (
|
||||
litellm.AmazonStability3Config
|
||||
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
|
||||
else litellm.AmazonStabilityConfig
|
||||
)
|
||||
config_class.transform_response_dict_to_openai_response(
|
||||
model_response=model_response,
|
||||
response_dict=response_dict,
|
||||
)
|
||||
|
||||
return model_response
|
|
@@ -1,127 +0,0 @@
|
|||
"""
|
||||
Handles image gen calls to Bedrock's `/invoke` endpoint
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from .common_utils import BedrockError, init_bedrock_client
|
||||
|
||||
|
||||
def image_generation(
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
timeout=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
"""
|
||||
Bedrock Image Gen endpoint support
|
||||
"""
|
||||
### BOTO3 INIT ###
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
### FORMAT IMAGE GENERATION INPUT ###
|
||||
modelId = model
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
data = {}
|
||||
if provider == "stability":
|
||||
prompt = prompt.replace(os.linesep, " ")
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonStabilityConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
|
||||
else:
|
||||
raise BedrockError(
|
||||
status_code=422, message=f"Unsupported model={model}, passed in"
|
||||
)
|
||||
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
body={body}, # type: ignore
|
||||
modelId={modelId},
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={
|
||||
"complete_input_dict": {"model": modelId, "texts": prompt},
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
body=body,
|
||||
modelId=modelId,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message=f"Embedding Error with model {model}: {e}", status_code=500
|
||||
)
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
if response_body is None:
|
||||
raise Exception("Error in response object format")
|
||||
|
||||
if model_response is None:
|
||||
model_response = ImageResponse()
|
||||
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_body["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
return model_response
|
|
@@ -34,12 +34,14 @@ class AsyncHTTPHandler:
|
|||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
|
||||
concurrent_limit=1000,
|
||||
client_alias: Optional[str] = None, # name for client in logs
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.event_hooks = event_hooks
|
||||
self.client = self.create_client(
|
||||
timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
|
||||
)
|
||||
self.client_alias = client_alias
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
|
@@ -112,6 +114,7 @@ class AsyncHTTPHandler:
|
|||
try:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
|
||||
)
|
||||
|
|
|
@@ -1,7 +1,8 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionResponseMessage,
|
||||
|
@@ -9,7 +10,7 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
)
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
|
@@ -109,7 +110,17 @@ class ModelResponseIterator:
|
|||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
verbose_logger.debug(
|
||||
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
|
||||
)
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
|
@@ -123,6 +134,8 @@ class ModelResponseIterator:
|
|||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
|
@@ -144,4 +157,14 @@ class ModelResponseIterator:
|
|||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
verbose_logger.debug(
|
||||
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
|
||||
)
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
|
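The parse-error handling above now logs and returns an empty chunk instead of raising. A minimal sketch of that fallback object, constructed exactly as in the code; purely illustrative.

# Sketch of the empty-chunk fallback returned when a streamed line fails to parse.
from litellm.types.utils import GenericStreamingChunk

empty_chunk = GenericStreamingChunk(
    text="",
    is_finished=False,
    finish_reason="",
    usage=None,
    index=0,
    tool_use=None,
)
print(empty_chunk)  # empty text, not finished, no usage or tool call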
41  litellm/llms/deepseek/chat/transformation.py  Normal file
@@ -0,0 +1,41 @@
"""
Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
"""

import types
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel

import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
)


class DeepSeekChatConfig(OpenAIGPTConfig):

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        DeepSeek does not support content in list format.
        """
        messages = handle_messages_with_content_list_to_str_conversion(messages)
        return super()._transform_messages(messages)

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = (
            api_base
            or get_secret_str("DEEPSEEK_API_BASE")
            or "https://api.deepseek.com/beta"
        )  # type: ignore
        dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
        return api_base, dynamic_api_key
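A minimal sketch of the base-URL/key resolution in this new config, assuming the config can be instantiated with no arguments and that DEEPSEEK_API_BASE is unset; the key value is a placeholder.

# Sketch only: placeholder key, no real DeepSeek call is made.
import os
from litellm.llms.deepseek.chat.transformation import DeepSeekChatConfig

os.environ.pop("DEEPSEEK_API_BASE", None)
os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"

api_base, api_key = DeepSeekChatConfig()._get_openai_compatible_provider_info(
    api_base=None, api_key=None
)
print(api_base)  # https://api.deepseek.com/beta
print(api_key)   # sk-placeholder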
@@ -15,6 +15,7 @@ import litellm
|
|||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||
from litellm.types.utils import Logprobs as TextCompletionLogprobs
|
||||
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
|
||||
|
||||
from .base import BaseLLM
|
||||
|
@@ -1183,3 +1184,73 @@ class Huggingface(BaseLLM):
|
|||
input=input,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def _transform_logprobs(
|
||||
self, hf_response: Optional[List]
|
||||
) -> Optional[TextCompletionLogprobs]:
|
||||
"""
|
||||
Transform Hugging Face logprobs to OpenAI.Completion() format
|
||||
"""
|
||||
if hf_response is None:
|
||||
return None
|
||||
|
||||
# Initialize an empty list for the transformed logprobs
|
||||
_logprob: TextCompletionLogprobs = TextCompletionLogprobs(
|
||||
text_offset=[],
|
||||
token_logprobs=[],
|
||||
tokens=[],
|
||||
top_logprobs=[],
|
||||
)
|
||||
|
||||
# For each Hugging Face response, transform the logprobs
|
||||
for response in hf_response:
|
||||
# Extract the relevant information from the response
|
||||
response_details = response["details"]
|
||||
top_tokens = response_details.get("top_tokens", {})
|
||||
|
||||
for i, token in enumerate(response_details["prefill"]):
|
||||
# Extract the text of the token
|
||||
token_text = token["text"]
|
||||
|
||||
# Extract the logprob of the token
|
||||
token_logprob = token["logprob"]
|
||||
|
||||
# Add the token information to the 'token_info' list
|
||||
_logprob.tokens.append(token_text)
|
||||
_logprob.token_logprobs.append(token_logprob)
|
||||
|
||||
# stub this to work with llm eval harness
|
||||
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
|
||||
_logprob.top_logprobs.append(top_alt_tokens)
|
||||
|
||||
# For each element in the 'tokens' list, extract the relevant information
|
||||
for i, token in enumerate(response_details["tokens"]):
|
||||
# Extract the text of the token
|
||||
token_text = token["text"]
|
||||
|
||||
# Extract the logprob of the token
|
||||
token_logprob = token["logprob"]
|
||||
|
||||
top_alt_tokens = {}
|
||||
temp_top_logprobs = []
|
||||
if top_tokens != {}:
|
||||
temp_top_logprobs = top_tokens[i]
|
||||
|
||||
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
|
||||
for elem in temp_top_logprobs:
|
||||
text = elem["text"]
|
||||
logprob = elem["logprob"]
|
||||
top_alt_tokens[text] = logprob
|
||||
|
||||
# Add the token information to the 'token_info' list
|
||||
_logprob.tokens.append(token_text)
|
||||
_logprob.token_logprobs.append(token_logprob)
|
||||
_logprob.top_logprobs.append(top_alt_tokens)
|
||||
|
||||
# Add the text offset of the token
|
||||
# This is computed as the sum of the lengths of all previous tokens
|
||||
_logprob.text_offset.append(
|
||||
sum(len(t["text"]) for t in response_details["tokens"][:i])
|
||||
)
|
||||
|
||||
return _logprob
|
||||
|
|
|
@@ -10,6 +10,7 @@ import types
|
|||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class MistralConfig:
|
||||
|
@@ -148,3 +149,59 @@ class MistralConfig:
|
|||
or get_secret_str("MISTRAL_API_KEY")
|
||||
)
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
@classmethod
|
||||
def _transform_messages(cls, messages: List[AllMessageValues]):
|
||||
"""
|
||||
- handles scenario where content is list and not string
|
||||
- content list is just text, and no images
|
||||
- if image passed in, then just return as is (user-intended)
|
||||
- if `name` is passed, then drop it for mistral API: https://github.com/BerriAI/litellm/issues/6696
|
||||
|
||||
Motivation: mistral api doesn't support content as a list
|
||||
"""
|
||||
new_messages = []
|
||||
for m in messages:
|
||||
special_keys = ["role", "content", "tool_calls", "function_call"]
|
||||
extra_args = {}
|
||||
if isinstance(m, dict):
|
||||
for k, v in m.items():
|
||||
if k not in special_keys:
|
||||
extra_args[k] = v
|
||||
texts = ""
|
||||
_content = m.get("content")
|
||||
if _content is not None and isinstance(_content, list):
|
||||
for c in _content:
|
||||
_text: Optional[str] = c.get("text")
|
||||
if c["type"] == "image_url":
|
||||
return messages
|
||||
elif c["type"] == "text" and isinstance(_text, str):
|
||||
texts += _text
|
||||
elif _content is not None and isinstance(_content, str):
|
||||
texts = _content
|
||||
|
||||
new_m = {"role": m["role"], "content": texts, **extra_args}
|
||||
|
||||
if m.get("tool_calls"):
|
||||
new_m["tool_calls"] = m.get("tool_calls")
|
||||
|
||||
new_m = cls._handle_name_in_message(new_m)
|
||||
|
||||
new_messages.append(new_m)
|
||||
return new_messages
|
||||
|
||||
@classmethod
|
||||
def _handle_name_in_message(cls, message: dict) -> dict:
|
||||
"""
|
||||
Mistral API only supports `name` in tool messages
|
||||
|
||||
If role == tool, then we keep `name`
|
||||
Otherwise, we drop `name`
|
||||
"""
|
||||
if message.get("name") is not None:
|
||||
if message["role"] == "tool":
|
||||
message["name"] = message.get("name")
|
||||
else:
|
||||
message.pop("name", None)
|
||||
|
||||
return message
|
||||
|
|
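A minimal sketch of the flattening and `name`-dropping behaviour this classmethod implements (prompt_factory routes Mistral messages here, per the change further below); the message content is illustrative.

# Sketch: list content is flattened to a string; `name` is kept only for tool messages.
import litellm

messages = [
    {
        "role": "user",
        "name": "alice",  # dropped: only tool messages may carry `name`
        "content": [
            {"type": "text", "text": "Summarize "},
            {"type": "text", "text": "this document."},
        ],
    }
]
print(litellm.MistralConfig._transform_messages(messages))
# [{'role': 'user', 'content': 'Summarize this document.'}]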
372  litellm/llms/openai_like/chat/handler.py  Normal file
|
@@ -0,0 +1,372 @@
|
|||
"""
|
||||
OpenAI-like chat completion handler
|
||||
|
||||
For handling OpenAI-like chat completions, like IBM WatsonX, etc.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
|
||||
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper, EmbeddingResponse
|
||||
|
||||
from ..common_utils import OpenAILikeBase, OpenAILikeError
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if streaming_decoder is not None:
|
||||
completion_stream: Any = streaming_decoder.aiter_bytes(
|
||||
response.aiter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.aiter_lines(), sync_stream=False
|
||||
)
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = litellm.module_level_client # Create a new client if none provided
|
||||
|
||||
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OpenAILikeError(status_code=response.status_code, message=response.read())
|
||||
|
||||
if streaming_decoder is not None:
|
||||
completion_stream = streaming_decoder.iter_bytes(
|
||||
response.iter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
class OpenAILikeChatHandler(OpenAILikeBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_llm_provider: str,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
|
||||
data["stream"] = True
|
||||
completion_stream = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return streamwrapper
|
||||
|
||||
async def acompletion_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
custom_llm_provider: str,
|
||||
print_verbose: Callable,
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
base_model: Optional[str],
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> ModelResponse:
|
||||
if timeout is None:
|
||||
timeout = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base, headers=headers, data=json.dumps(data), timeout=timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=response_json,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
return response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[
|
||||
CustomStreamingDecoder
|
||||
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
|
||||
):
|
||||
custom_endpoint = custom_endpoint or optional_params.pop(
|
||||
"custom_endpoint", None
|
||||
)
|
||||
base_model: Optional[str] = optional_params.pop("base_model", None)
|
||||
api_base, headers = self._validate_environment(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
endpoint_type="chat_completions",
|
||||
custom_endpoint=custom_endpoint,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
stream: bool = optional_params.get("stream", None) or False
|
||||
optional_params["stream"] = stream
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
if acompletion is True:
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
client = None
|
||||
if (
|
||||
stream is True
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
data["stream"] = stream
|
||||
return self.acompletion_stream_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
client=client,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
else:
|
||||
return self.acompletion_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
base_model=base_model,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
## COMPLETION CALL
|
||||
if stream is True:
|
||||
completion_stream = make_sync_call(
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
# completion_stream.__iter__()
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(timeout=timeout) # type: ignore
|
||||
try:
|
||||
response = client.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(
|
||||
status_code=408, message="Timeout error occurred."
|
||||
)
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=response_json,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
|
||||
return response
|
|
@@ -1,3 +1,5 @@
|
|||
from typing import Literal, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
|
@@ -10,3 +12,43 @@ class OpenAILikeError(Exception):
|
|||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class OpenAILikeBase:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
headers: Optional[dict],
|
||||
custom_endpoint: Optional[bool],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if not custom_endpoint:
|
||||
if endpoint_type == "chat_completions":
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings":
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
||||
|
|
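A minimal sketch of what the shared `_validate_environment` returns, assuming the base class is importable from litellm.llms.openai_like.common_utils (inferred from the relative `..common_utils` imports above); the key and host are placeholders.

# Sketch only: placeholder key and host, no request is sent.
from litellm.llms.openai_like.common_utils import OpenAILikeBase

api_base, headers = OpenAILikeBase()._validate_environment(
    api_key="sk-123",
    api_base="https://my-host/v1",
    endpoint_type="chat_completions",
    headers=None,
    custom_endpoint=False,
)
print(api_base)                  # https://my-host/v1/chat/completions
print(headers["Authorization"])  # Bearer sk-123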
|
@@ -23,46 +23,13 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.utils import EmbeddingResponse
|
||||
|
||||
from ..common_utils import OpenAILikeError
|
||||
from ..common_utils import OpenAILikeBase, OpenAILikeError
|
||||
|
||||
|
||||
class OpenAILikeEmbeddingHandler:
|
||||
class OpenAILikeEmbeddingHandler(OpenAILikeBase):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if endpoint_type == "chat_completions":
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings":
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
|
@@ -133,6 +100,7 @@ class OpenAILikeEmbeddingHandler:
|
|||
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> EmbeddingResponse:
|
||||
api_base, headers = self._validate_environment(
|
||||
|
@@ -140,6 +108,7 @@ class OpenAILikeEmbeddingHandler:
|
|||
api_key=api_key,
|
||||
endpoint_type="embeddings",
|
||||
headers=headers,
|
||||
custom_endpoint=custom_endpoint,
|
||||
)
|
||||
model = model
|
||||
data = {"model": model, "input": input, **optional_params}
|
||||
|
|
|
@@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
|
|||
)
|
||||
|
||||
|
||||
def handle_messages_with_content_list_to_str_conversion(
|
||||
messages: List[AllMessageValues],
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Handles messages with content list conversion
|
||||
"""
|
||||
for message in messages:
|
||||
texts = convert_content_list_to_str(message=message)
|
||||
if texts:
|
||||
message["content"] = texts
|
||||
return messages
|
||||
|
||||
|
||||
def convert_content_list_to_str(message: AllMessageValues) -> str:
|
||||
"""
|
||||
- handles scenario where content is list and not string
|
||||
|
|
|
@@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
|
|||
return prompt
|
||||
|
||||
|
||||
def mistral_api_pt(messages):
|
||||
"""
|
||||
- handles scenario where content is list and not string
|
||||
- content list is just text, and no images
|
||||
- if image passed in, then just return as is (user-intended)
|
||||
|
||||
Motivation: mistral api doesn't support content as a list
|
||||
"""
|
||||
new_messages = []
|
||||
for m in messages:
|
||||
special_keys = ["role", "content", "tool_calls", "function_call"]
|
||||
extra_args = {}
|
||||
if isinstance(m, dict):
|
||||
for k, v in m.items():
|
||||
if k not in special_keys:
|
||||
extra_args[k] = v
|
||||
texts = ""
|
||||
if m.get("content", None) is not None and isinstance(m["content"], list):
|
||||
for c in m["content"]:
|
||||
if c["type"] == "image_url":
|
||||
return messages
|
||||
elif c["type"] == "text" and isinstance(c["text"], str):
|
||||
texts += c["text"]
|
||||
elif m.get("content", None) is not None and isinstance(m["content"], str):
|
||||
texts = m["content"]
|
||||
|
||||
new_m = {"role": m["role"], "content": texts, **extra_args}
|
||||
|
||||
if new_m["role"] == "tool" and m.get("name"):
|
||||
new_m["name"] = m["name"]
|
||||
if m.get("tool_calls"):
|
||||
new_m["tool_calls"] = m["tool_calls"]
|
||||
|
||||
new_messages.append(new_m)
|
||||
return new_messages
|
||||
|
||||
|
||||
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
|
||||
def falcon_instruct_pt(messages):
|
||||
prompt = ""
|
||||
|
@@ -1330,7 +1293,10 @@ def convert_to_anthropic_tool_invoke(
|
|||
|
||||
def add_cache_control_to_content(
|
||||
anthropic_content_element: Union[
|
||||
dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam
|
||||
dict,
|
||||
AnthropicMessagesImageParam,
|
||||
AnthropicMessagesTextParam,
|
||||
AnthropicMessagesDocumentParam,
|
||||
],
|
||||
orignal_content_element: Union[dict, AllMessageValues],
|
||||
):
|
||||
|
@@ -1343,6 +1309,32 @@ def add_cache_control_to_content(
|
|||
return anthropic_content_element
|
||||
|
||||
|
||||
def _anthropic_content_element_factory(
|
||||
image_chunk: GenericImageParsingChunk,
|
||||
) -> Union[AnthropicMessagesImageParam, AnthropicMessagesDocumentParam]:
|
||||
if image_chunk["media_type"] == "application/pdf":
|
||||
_anthropic_content_element: Union[
|
||||
AnthropicMessagesDocumentParam, AnthropicMessagesImageParam
|
||||
] = AnthropicMessagesDocumentParam(
|
||||
type="document",
|
||||
source=AnthropicContentParamSource(
|
||||
type="base64",
|
||||
media_type=image_chunk["media_type"],
|
||||
data=image_chunk["data"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
_anthropic_content_element = AnthropicMessagesImageParam(
|
||||
type="image",
|
||||
source=AnthropicContentParamSource(
|
||||
type="base64",
|
||||
media_type=image_chunk["media_type"],
|
||||
data=image_chunk["data"],
|
||||
),
|
||||
)
|
||||
return _anthropic_content_element
|
||||
|
||||
|
||||
def anthropic_messages_pt( # noqa: PLR0915
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
|
@@ -1400,15 +1392,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
openai_image_url=m["image_url"]["url"]
|
||||
)
|
||||
|
||||
_anthropic_content_element = AnthropicMessagesImageParam(
|
||||
type="image",
|
||||
source=AnthropicImageParamSource(
|
||||
type="base64",
|
||||
media_type=image_chunk["media_type"],
|
||||
data=image_chunk["data"],
|
||||
),
|
||||
_anthropic_content_element = (
|
||||
_anthropic_content_element_factory(image_chunk)
|
||||
)
|
||||
|
||||
_content_element = add_cache_control_to_content(
|
||||
anthropic_content_element=_anthropic_content_element,
|
||||
orignal_content_element=dict(m),
|
||||
|
@@ -2830,7 +2816,7 @@ def prompt_factory(
|
|||
else:
|
||||
return gemini_text_image_pt(messages=messages)
|
||||
elif custom_llm_provider == "mistral":
|
||||
return mistral_api_pt(messages=messages)
|
||||
return litellm.MistralConfig._transform_messages(messages=messages)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if "amazon.titan-text" in model:
|
||||
return amazon_titan_pt(messages=messages)
|
||||
|
|
|
@@ -51,6 +51,9 @@ from ..common_utils import (
|
|||
|
||||
|
||||
def _process_gemini_image(image_url: str) -> PartType:
|
||||
"""
|
||||
Given an image URL, return the appropriate PartType for Gemini
|
||||
"""
|
||||
try:
|
||||
# GCS URIs
|
||||
if "gs://" in image_url:
|
||||
|
@@ -68,9 +71,14 @@ def _process_gemini_image(image_url: str) -> PartType:
|
|||
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
|
||||
|
||||
return PartType(file_data=file_data)
|
||||
|
||||
# Direct links
|
||||
elif "https:/" in image_url or "base64" in image_url:
|
||||
elif (
|
||||
"https://" in image_url
|
||||
and (image_type := _get_image_mime_type_from_url(image_url)) is not None
|
||||
):
|
||||
file_data = FileDataType(file_uri=image_url, mime_type=image_type)
|
||||
return PartType(file_data=file_data)
|
||||
elif "https://" in image_url or "base64" in image_url:
|
||||
# https links for unsupported mime types and base64 images
|
||||
image = convert_to_anthropic_image_obj(image_url)
|
||||
_blob = BlobType(data=image["data"], mime_type=image["media_type"])
|
||||
return PartType(inline_data=_blob)
|
||||
|
@@ -79,6 +87,29 @@ def _process_gemini_image(image_url: str) -> PartType:
|
|||
raise e
|
||||
|
||||
|
||||
def _get_image_mime_type_from_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
Get mime type for common image URLs
|
||||
See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
|
||||
|
||||
Supported by Gemini:
|
||||
- PNG (`image/png`)
|
||||
- JPEG (`image/jpeg`)
|
||||
- WebP (`image/webp`)
|
||||
Example:
|
||||
url = https://example.com/image.jpg
|
||||
Returns: image/jpeg
|
||||
"""
|
||||
url = url.lower()
|
||||
if url.endswith((".jpg", ".jpeg")):
|
||||
return "image/jpeg"
|
||||
elif url.endswith(".png"):
|
||||
return "image/png"
|
||||
elif url.endswith(".webp"):
|
||||
return "image/webp"
|
||||
return None
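Since the enclosing file path is not shown in this hunk, the sketch below mirrors the suffix check as a standalone function rather than importing the real helper; the URLs are placeholders.

# Standalone mirror of _get_image_mime_type_from_url above; illustrative only.
from typing import Optional


def get_image_mime_type_from_url(url: str) -> Optional[str]:
    url = url.lower()
    if url.endswith((".jpg", ".jpeg")):
        return "image/jpeg"
    if url.endswith(".png"):
        return "image/png"
    if url.endswith(".webp"):
        return "image/webp"
    return None


print(get_image_mime_type_from_url("https://example.com/photo.JPG"))    # image/jpeg
print(get_image_mime_type_from_url("https://example.com/diagram.svg"))  # None -> base64 fallback path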
|
||||
|
||||
|
||||
def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||
messages: List[AllMessageValues],
|
||||
) -> List[ContentType]:
|
||||
|
|
123  litellm/llms/watsonx/chat/handler.py  Normal file
|
@@ -0,0 +1,123 @@
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
|
||||
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
|
||||
|
||||
from ...openai_like.chat.handler import OpenAILikeChatHandler
|
||||
from ..common_utils import WatsonXAIError, _get_api_params
|
||||
|
||||
|
||||
class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare_url(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> str:
|
||||
if model.startswith("deployment/"):
|
||||
if api_params.get("space_id") is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=api_params["url"],
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
deployment_id = "/".join(model.split("/")[1:])
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
|
||||
)
|
||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||
else:
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.CHAT.value
|
||||
)
|
||||
base_url = httpx.URL(api_params["url"])
|
||||
base_url = base_url.join(endpoint)
|
||||
full_url = str(
|
||||
base_url.copy_add_param(key="version", value=api_params["api_version"])
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
def _prepare_payload(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> dict:
|
||||
payload: dict = {}
|
||||
if model.startswith("deployment/"):
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
return payload
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[
|
||||
CustomStreamingDecoder
|
||||
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
|
||||
):
|
||||
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_params['token']}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
stream: Optional[bool] = optional_params.get("stream", False)
|
||||
|
||||
## get api url and payload
|
||||
api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
|
||||
watsonx_auth_payload = self._prepare_payload(
|
||||
model=model, api_params=api_params, stream=stream
|
||||
)
|
||||
optional_params.update(watsonx_auth_payload)
|
||||
|
||||
return super().completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
custom_endpoint=True,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
82  litellm/llms/watsonx/chat/transformation.py  Normal file
|
@@ -0,0 +1,82 @@
|
|||
"""
|
||||
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
|
||||
|
||||
Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class IBMWatsonXChatConfig(OpenAIGPTConfig):
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
"temperature", # equivalent to temperature
|
||||
"max_tokens", # equivalent to max_new_tokens
|
||||
"top_p", # equivalent to top_p
|
||||
"frequency_penalty", # equivalent to repetition_penalty
|
||||
"stop", # equivalent to stop_sequences
|
||||
"seed", # equivalent to random_seed
|
||||
"stream", # equivalent to stream
|
||||
"tools",
|
||||
"tool_choice", # equivalent to tool_choice + tool_choice_options
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
|
||||
if tool_choice is None:
|
||||
return False
|
||||
if isinstance(tool_choice, str):
|
||||
return tool_choice in ["auto", "none", "required"]
|
||||
return False
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
## TOOLS ##
|
||||
_tools = non_default_params.pop("tools", None)
|
||||
if _tools is not None:
|
||||
# remove 'additionalProperties' from tools
|
||||
_tools = _remove_additional_properties(_tools)
|
||||
# remove 'strict' from tools
|
||||
_tools = _remove_strict_from_schema(_tools)
|
||||
if _tools is not None:
|
||||
non_default_params["tools"] = _tools
|
||||
|
||||
## TOOL CHOICE ##
|
||||
|
||||
_tool_choice = non_default_params.pop("tool_choice", None)
|
||||
if self.is_tool_choice_option(_tool_choice):
|
||||
optional_params["tool_choice_options"] = _tool_choice
|
||||
elif _tool_choice is not None:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
return super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE") # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
|
||||
) # vllm does not require an api key
|
||||
return api_base, dynamic_api_key
|
172  litellm/llms/watsonx/common_utils.py  Normal file
|
@@ -0,0 +1,172 @@
|
|||
from typing import Callable, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.watsonx import WatsonXAPIParams
|
||||
|
||||
|
||||
class WatsonXAIError(Exception):
|
||||
def __init__(self, status_code, message, url: Optional[str] = None):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
url = url or "https://us-south.ml.cloud.ibm.com"
|
||||
self.request = httpx.Request(method="POST", url=url)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
iam_token_cache = InMemoryCache()
|
||||
|
||||
|
||||
def generate_iam_token(api_key=None, **params) -> str:
|
||||
result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore
|
||||
|
||||
if result is None:
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError("API key is required")
|
||||
headers["Accept"] = "application/json"
|
||||
data = {
|
||||
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||
"apikey": api_key,
|
||||
}
|
||||
verbose_logger.debug(
|
||||
"calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
|
||||
"https://iam.cloud.ibm.com/identity/token",
|
||||
headers,
|
||||
data,
|
||||
)
|
||||
response = httpx.post(
|
||||
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
json_data = response.json()
|
||||
|
||||
result = json_data["access_token"]
|
||||
iam_token_cache.set_cache(
|
||||
key=api_key,
|
||||
value=result,
|
||||
ttl=json_data["expires_in"] - 10, # leave some buffer
|
||||
)
|
||||
|
||||
return cast(str, result)
|
||||
|
||||
|
||||
def _get_api_params(
|
||||
params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
generate_token: Optional[bool] = True,
|
||||
) -> WatsonXAPIParams:
|
||||
"""
|
||||
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
|
||||
"""
|
||||
# Load auth variables from params
|
||||
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
|
||||
api_key = params.pop("apikey", None)
|
||||
token = params.pop("token", None)
|
||||
project_id = params.pop(
|
||||
"project_id", params.pop("watsonx_project", None)
|
||||
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
|
||||
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
|
||||
region_name = params.pop("region_name", params.pop("region", None))
|
||||
if region_name is None:
|
||||
region_name = params.pop(
|
||||
"watsonx_region_name", params.pop("watsonx_region", None)
|
||||
) # consistent with how vertex ai + aws regions are accepted
|
||||
wx_credentials = params.pop(
|
||||
"wx_credentials",
|
||||
params.pop(
|
||||
"watsonx_credentials", None
|
||||
), # follow {provider}_credentials, same as vertex ai
|
||||
)
|
||||
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
|
||||
# Load auth variables from environment variables
|
||||
if url is None:
|
||||
url = (
|
||||
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
get_secret_str("WATSONX_APIKEY")
|
||||
or get_secret_str("WATSONX_API_KEY")
|
||||
or get_secret_str("WX_API_KEY")
|
||||
)
|
||||
if token is None:
|
||||
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
|
||||
if project_id is None:
|
||||
project_id = (
|
||||
get_secret_str("WATSONX_PROJECT_ID")
|
||||
or get_secret_str("WX_PROJECT_ID")
|
||||
or get_secret_str("PROJECT_ID")
|
||||
)
|
||||
if region_name is None:
|
||||
region_name = (
|
||||
get_secret_str("WATSONX_REGION")
|
||||
or get_secret_str("WX_REGION")
|
||||
or get_secret_str("REGION")
|
||||
)
|
||||
if space_id is None:
|
||||
space_id = (
|
||||
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
|
||||
or get_secret_str("WATSONX_SPACE_ID")
|
||||
or get_secret_str("WX_SPACE_ID")
|
||||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
# credentials parsing
|
||||
if wx_credentials is not None:
|
||||
url = wx_credentials.get("url", url)
|
||||
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
|
||||
token = wx_credentials.get(
|
||||
"token",
|
||||
wx_credentials.get(
|
||||
"watsonx_token", token
|
||||
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
|
||||
)
|
||||
|
||||
# verify that all required credentials are present
|
||||
if url is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
if token is None and api_key is not None and generate_token:
|
||||
# generate the auth token
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
|
||||
)
|
||||
if project_id is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
return WatsonXAPIParams(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
token=cast(str, token),
|
||||
project_id=project_id,
|
||||
space_id=space_id,
|
||||
region_name=region_name,
|
||||
api_version=api_version,
|
||||
)
|
|
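For reference, a rough usage sketch of `_get_api_params` above. All values below are placeholders; in practice the same fields can also be resolved from env vars such as WATSONX_URL / WATSONX_APIKEY / WATSONX_PROJECT_ID, and passing `generate_token=False` with a pre-fetched token skips the IAM exchange entirely (no network call):

# Hedged usage sketch, assuming this module is importable as shown in the diff header.
from litellm.llms.watsonx.common_utils import _get_api_params

api_params = _get_api_params(
    params={
        "url": "https://us-south.ml.cloud.ibm.com",  # placeholder region endpoint
        "project_id": "placeholder-project-id",      # placeholder
        "token": "placeholder-iam-token",            # pre-fetched bearer token
    },
    generate_token=False,
)
print(api_params["url"], api_params["project_id"])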
@@ -26,22 +26,12 @@ import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

from .base import BaseLLM
from .prompt_templates import factory as ptf


class WatsonXAIError(Exception):
    def __init__(self, status_code, message, url: Optional[str] = None):
        self.status_code = status_code
        self.message = message
        url = url or "https://https://us-south.ml.cloud.ibm.com"
        self.request = httpx.Request(method="POST", url=url)
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
from ...base import BaseLLM
from ...prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token


class IBMWatsonXAIConfig:
@@ -140,6 +130,29 @@ class IBMWatsonXAIConfig:
            and v is not None
        }

    def is_watsonx_text_param(self, param: str) -> bool:
        """
        Determine if user passed in a watsonx.ai text generation param
        """
        text_generation_params = [
            "decoding_method",
            "max_new_tokens",
            "min_new_tokens",
            "length_penalty",
            "stop_sequences",
            "top_k",
            "repetition_penalty",
            "truncate_input_tokens",
            "include_stop_sequences",
            "return_options",
            "random_seed",
            "moderations",
            "decoding_method",
            "min_tokens",
        ]

        return param in text_generation_params

    def get_supported_openai_params(self):
        return [
            "temperature",  # equivalent to temperature
@@ -151,6 +164,44 @@ class IBMWatsonXAIConfig:
            "stream",  # equivalent to stream
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        extra_body = {}
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_new_tokens"] = v
            elif k == "stream":
                optional_params["stream"] = v
            elif k == "temperature":
                optional_params["temperature"] = v
            elif k == "top_p":
                optional_params["top_p"] = v
            elif k == "frequency_penalty":
                optional_params["repetition_penalty"] = v
            elif k == "seed":
                optional_params["random_seed"] = v
            elif k == "stop":
                optional_params["stop_sequences"] = v
            elif k == "decoding_method":
                extra_body["decoding_method"] = v
            elif k == "min_tokens":
                extra_body["min_new_tokens"] = v
            elif k == "top_k":
                extra_body["top_k"] = v
            elif k == "truncate_input_tokens":
                extra_body["truncate_input_tokens"] = v
            elif k == "length_penalty":
                extra_body["length_penalty"] = v
            elif k == "time_limit":
                extra_body["time_limit"] = v
            elif k == "return_options":
                extra_body["return_options"] = v

        if extra_body:
            optional_params["extra_body"] = extra_body
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
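To make the translation above concrete, here is an illustrative mirror of the mapping (simplified, only a few keys; it does not import litellm): OpenAI-style params are renamed to their watsonx text-generation equivalents, and watsonx-only params are collected under `extra_body`:

# Simplified mirror of IBMWatsonXAIConfig.map_openai_params for illustration.
def map_text_params(non_default_params: dict) -> dict:
    optional_params: dict = {}
    extra_body: dict = {}
    renamed = {"max_tokens": "max_new_tokens", "seed": "random_seed", "stop": "stop_sequences"}
    watsonx_only = {"top_k", "truncate_input_tokens", "length_penalty", "time_limit", "return_options", "decoding_method"}
    for k, v in non_default_params.items():
        if k in renamed:
            optional_params[renamed[k]] = v      # direct OpenAI -> watsonx rename
        elif k == "min_tokens":
            extra_body["min_new_tokens"] = v     # watsonx-only, renamed
        elif k in watsonx_only:
            extra_body[k] = v                    # watsonx-only, passed through
    if extra_body:
        optional_params["extra_body"] = extra_body
    return optional_params


print(map_text_params({"max_tokens": 100, "seed": 42, "top_k": 50}))
# {'max_new_tokens': 100, 'random_seed': 42, 'extra_body': {'top_k': 50}}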
@@ -212,18 +263,6 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) ->
    return prompt


class WatsonXAIEndpoint(str, Enum):
    TEXT_GENERATION = "/ml/v1/text/generation"
    TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
    DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
    DEPLOYMENT_TEXT_GENERATION_STREAM = (
        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
    )
    EMBEDDINGS = "/ml/v1/text/embeddings"
    PROMPTS = "/ml/v1/prompts"
    AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"


class IBMWatsonXAI(BaseLLM):
    """
    Class to interface with IBM watsonx.ai API for text generation and embeddings.
@@ -247,10 +286,10 @@ class IBMWatsonXAI(BaseLLM):
        """
        Get the request parameters for text generation.
        """
        api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
        api_params = _get_api_params(optional_params, print_verbose=print_verbose)
        # build auth headers
        api_token = api_params.get("token")

        self.token = api_token
        headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
@@ -294,118 +333,6 @@ class IBMWatsonXAI(BaseLLM):
            method="POST", url=url, headers=headers, json=payload, params=request_params
        )

    def _get_api_params(
        self,
        params: dict,
        print_verbose: Optional[Callable] = None,
        generate_token: Optional[bool] = True,
    ) -> dict:
        """
        Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
        """
        # Load auth variables from params
        url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
        api_key = params.pop("apikey", None)
        token = params.pop("token", None)
        project_id = params.pop(
            "project_id", params.pop("watsonx_project", None)
        )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
        space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
        region_name = params.pop("region_name", params.pop("region", None))
        if region_name is None:
            region_name = params.pop(
                "watsonx_region_name", params.pop("watsonx_region", None)
            )  # consistent with how vertex ai + aws regions are accepted
        wx_credentials = params.pop(
            "wx_credentials",
            params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )
        api_version = params.pop("api_version", IBMWatsonXAI.api_version)
        # Load auth variables from environment variables
        if url is None:
            url = (
                get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
                or get_secret_str("WATSONX_URL")
                or get_secret_str("WX_URL")
                or get_secret_str("WML_URL")
            )
        if api_key is None:
            api_key = (
                get_secret_str("WATSONX_APIKEY")
                or get_secret_str("WATSONX_API_KEY")
                or get_secret_str("WX_API_KEY")
            )
        if token is None:
            token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
        if project_id is None:
            project_id = (
                get_secret_str("WATSONX_PROJECT_ID")
                or get_secret_str("WX_PROJECT_ID")
                or get_secret_str("PROJECT_ID")
            )
        if region_name is None:
            region_name = (
                get_secret_str("WATSONX_REGION")
                or get_secret_str("WX_REGION")
                or get_secret_str("REGION")
            )
        if space_id is None:
            space_id = (
                get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
                or get_secret_str("WATSONX_SPACE_ID")
                or get_secret_str("WX_SPACE_ID")
                or get_secret_str("SPACE_ID")
            )

        # credentials parsing
        if wx_credentials is not None:
            url = wx_credentials.get("url", url)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", token
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )

        # verify that all required credentials are present
        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
            )
        if token is None and api_key is not None and generate_token:
            # generate the auth token
            if print_verbose is not None:
                print_verbose("Generating IAM token for Watsonx.ai")
            token = self.generate_iam_token(api_key)
        elif token is None and api_key is None:
            raise WatsonXAIError(
                status_code=401,
                url=url,
                message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
            )
        if project_id is None:
            raise WatsonXAIError(
                status_code=401,
                url=url,
                message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
            )

        return {
            "url": url,
            "api_key": api_key,
            "token": token,
            "project_id": project_id,
            "space_id": space_id,
            "region_name": region_name,
            "api_version": api_version,
        }

    def _process_text_gen_response(
        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
    ) -> ModelResponse:
@@ -616,9 +543,10 @@ class IBMWatsonXAI(BaseLLM):
            input = [input]
        if api_key is not None:
            optional_params["api_key"] = api_key
        api_params = self._get_api_params(optional_params)
        api_params = _get_api_params(optional_params)
        # build auth headers
        api_token = api_params.get("token")
        self.token = api_token
        headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
@@ -664,29 +592,9 @@ class IBMWatsonXAI(BaseLLM):
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e))

    def generate_iam_token(self, api_key=None, **params):
        headers = {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        if api_key is None:
            api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
        if api_key is None:
            raise ValueError("API key is required")
        headers["Accept"] = "application/json"
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        response = httpx.post(
            "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
        )
        response.raise_for_status()
        json_data = response.json()
        iam_access_token = json_data["access_token"]
        self.token = iam_access_token
        return iam_access_token

    def get_available_models(self, *, ids_only: bool = True, **params):
        api_params = self._get_api_params(params)
        api_params = _get_api_params(params)
        self.token = api_params["token"]
        headers = {
            "Authorization": f"Bearer {api_params['token']}",
            "Content-Type": "application/json",
@@ -77,6 +77,7 @@ from litellm.utils import (
    read_config_args,
    supports_httpx_timeout,
    token_counter,
    validate_chat_completion_user_messages,
)

from ._logging import verbose_logger
@@ -107,9 +108,9 @@ from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion  # type: ignore
from .llms.cohere.embed import handler as cohere_embed
|
|||
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
||||
VertexEmbedding,
|
||||
)
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.watsonx.chat.handler import WatsonXChatHandler
|
||||
from .llms.watsonx.completion.handler import IBMWatsonXAI
|
||||
from .types.llms.openai import (
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAudioParam,
|
||||
ChatCompletionModality,
|
||||
ChatCompletionPredictionContentParam,
|
||||
ChatCompletionUserMessage,
|
||||
HttpxBinaryResponseContent,
|
||||
)
|
||||
|
@@ -211,6 +214,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
bedrock_image_generation = BedrockImageGeneration()
vertex_chat_completion = VertexLLM()
vertex_embedding = VertexEmbedding()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
@@ -220,6 +224,7 @@ vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
####### COMPLETION ENDPOINTS ################
@@ -304,6 +309,7 @@ async def acompletion(
    max_tokens: Optional[int] = None,
    max_completion_tokens: Optional[int] = None,
    modalities: Optional[List[ChatCompletionModality]] = None,
    prediction: Optional[ChatCompletionPredictionContentParam] = None,
    audio: Optional[ChatCompletionAudioParam] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
@@ -346,6 +352,7 @@ async def acompletion(
        max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
        max_completion_tokens (integer, optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
        modalities (List[ChatCompletionModality], optional): Output types that you would like the model to generate for this request. You can use `["text", "audio"]`
        prediction (ChatCompletionPredictionContentParam, optional): Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content.
        audio (ChatCompletionAudioParam, optional): Parameters for audio output. Required when audio output is requested with modalities: ["audio"]
        presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
        frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far.
@@ -387,6 +394,7 @@ async def acompletion(
        "max_tokens": max_tokens,
        "max_completion_tokens": max_completion_tokens,
        "modalities": modalities,
        "prediction": prediction,
        "audio": audio,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
@@ -693,6 +701,7 @@ def completion(  # type: ignore # noqa: PLR0915
    max_completion_tokens: Optional[int] = None,
    max_tokens: Optional[int] = None,
    modalities: Optional[List[ChatCompletionModality]] = None,
    prediction: Optional[ChatCompletionPredictionContentParam] = None,
    audio: Optional[ChatCompletionAudioParam] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
@@ -737,6 +746,7 @@ def completion(  # type: ignore # noqa: PLR0915
        max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
        max_completion_tokens (integer, optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
        modalities (List[ChatCompletionModality], optional): Output types that you would like the model to generate for this request. You can use `["text", "audio"]`
        prediction (ChatCompletionPredictionContentParam, optional): Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content.
        audio (ChatCompletionAudioParam, optional): Parameters for audio output. Required when audio output is requested with modalities: ["audio"]
        presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
        frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far.
@@ -843,6 +853,7 @@ def completion(  # type: ignore # noqa: PLR0915
        "stop",
        "max_completion_tokens",
        "modalities",
        "prediction",
        "audio",
        "max_tokens",
        "presence_penalty",
@@ -914,6 +925,9 @@ def completion(  # type: ignore # noqa: PLR0915
            "aws_region_name", None
        )  # support region-based pricing for bedrock

        ### VALIDATE USER MESSAGES ###
        validate_chat_completion_user_messages(messages=messages)

        ### TIMEOUT LOGIC ###
        timeout = timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default
@@ -994,6 +1008,7 @@ def completion(  # type: ignore # noqa: PLR0915
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            modalities=modalities,
            prediction=prediction,
            audio=audio,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
@@ -2607,6 +2622,26 @@ def completion(  # type: ignore # noqa: PLR0915
            ## RESPONSE OBJECT
            response = response
        elif custom_llm_provider == "watsonx":
            response = watsonx_chat_completion.completion(
                model=model,
                messages=messages,
                headers=headers,
                model_response=model_response,
                print_verbose=print_verbose,
                api_key=api_key,
                api_base=api_base,
                acompletion=acompletion,
                logging_obj=logging,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                timeout=timeout,  # type: ignore
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
                encoding=encoding,
                custom_llm_provider="watsonx",
            )
        elif custom_llm_provider == "watsonx_text":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
            response = watsonxai.completion(
                model=model,
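With the new branch above, chat requests for watsonx models are routed through `watsonx_chat_completion` (the OpenAI-compatible chat path), while `watsonx_text` keeps the legacy text-generation handler. A hedged call sketch, assuming watsonx credentials are configured in the environment and using an illustrative model id:

# Hypothetical call exercising the "watsonx" branch; credentials and model id are placeholders.
import os
import litellm

os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "placeholder-api-key"
os.environ["WATSONX_PROJECT_ID"] = "placeholder-project-id"

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # illustrative watsonx model id
    messages=[{"role": "user", "content": "Say hello"}],
)
print(response.choices[0].message.content)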
@@ -3867,34 +3902,17 @@ async def atext_completion(
                custom_llm_provider=custom_llm_provider,
            )
        else:
            transformed_logprobs = None
            # only supported for TGI models
            try:
                raw_response = response._hidden_params.get("original_response", None)
                transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
            except Exception as e:
                print_verbose(f"LiteLLM non blocking exception: {e}")

        ## TRANSLATE CHAT TO TEXT FORMAT ##
        ## OpenAI / Azure Text Completion Returns here
        if isinstance(response, TextCompletionResponse):
            return response
        elif asyncio.iscoroutine(response):
            response = await response

        text_completion_response = TextCompletionResponse()
        text_completion_response["id"] = response.get("id", None)
        text_completion_response["object"] = "text_completion"
        text_completion_response["created"] = response.get("created", None)
        text_completion_response["model"] = response.get("model", None)
        text_choices = TextChoices()
        text_choices["text"] = response["choices"][0]["message"]["content"]
        text_choices["index"] = response["choices"][0]["index"]
        text_choices["logprobs"] = transformed_logprobs
        text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
        text_completion_response["choices"] = [text_choices]
        text_completion_response["usage"] = response.get("usage", None)
        text_completion_response._hidden_params = HiddenParams(
            **response._hidden_params
        text_completion_response = litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
            text_completion_response=text_completion_response,
            response=response,
            custom_llm_provider=custom_llm_provider,
        )
        return text_completion_response
    except Exception as e:
@@ -4156,29 +4174,17 @@ def text_completion(  # noqa: PLR0915
        return response
    elif isinstance(response, TextCompletionStreamWrapper):
        return response
    transformed_logprobs = None
    # only supported for TGI models
    try:
        raw_response = response._hidden_params.get("original_response", None)
        transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
    except Exception as e:
        verbose_logger.exception(f"LiteLLM non blocking exception: {e}")

    # OpenAI Text / Azure Text will return here
    if isinstance(response, TextCompletionResponse):
        return response

    text_completion_response["id"] = response.get("id", None)
    text_completion_response["object"] = "text_completion"
    text_completion_response["created"] = response.get("created", None)
    text_completion_response["model"] = response.get("model", None)
    text_choices = TextChoices()
    text_choices["text"] = response["choices"][0]["message"]["content"]
    text_choices["index"] = response["choices"][0]["index"]
    text_choices["logprobs"] = transformed_logprobs
    text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
    text_completion_response["choices"] = [text_choices]
    text_completion_response["usage"] = response.get("usage", None)
    text_completion_response._hidden_params = HiddenParams(**response._hidden_params)
    text_completion_response = (
        litellm.utils.LiteLLMResponseObjectHandler.convert_chat_to_text_completion(
            response=response,
            text_completion_response=text_completion_response,
        )
    )

    return text_completion_response
@@ -4314,9 +4320,9 @@ async def amoderation(
    else:
        _openai_client = openai_client
    if model is not None:
        response = await openai_client.moderations.create(input=input, model=model)
        response = await _openai_client.moderations.create(input=input, model=model)
    else:
        response = await openai_client.moderations.create(input=input)
        response = await _openai_client.moderations.create(input=input)
    return response
@@ -4442,6 +4448,7 @@ def image_generation(  # noqa: PLR0915
        k: v for k, v in kwargs.items() if k not in default_params
    }  # model-specific params - pass them straight to the model/provider
    optional_params = get_optional_params_image_gen(
        model=model,
        n=n,
        quality=quality,
        response_format=response_format,
@@ -4534,7 +4541,7 @@ def image_generation(  # noqa: PLR0915
    elif custom_llm_provider == "bedrock":
        if model is None:
            raise Exception("Model needs to be set for bedrock")
        model_response = bedrock_image_generation.image_generation(
        model_response = bedrock_image_generation.image_generation(  # type: ignore
            model=model,
            prompt=prompt,
            timeout=timeout,
@ -80,6 +80,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
|
@ -94,6 +95,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
|
@ -108,7 +110,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"o1-mini-2024-09-12": {
|
||||
|
@ -122,7 +124,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"o1-preview": {
|
||||
|
@ -136,7 +138,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"o1-preview-2024-09-12": {
|
||||
|
@ -150,7 +152,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"chatgpt-4o-latest": {
|
||||
|
@ -190,6 +192,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
|
@ -461,6 +464,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"ft:gpt-4o-mini-2024-07-18": {
|
||||
|
@ -473,6 +477,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"ft:davinci-002": {
|
||||
|
@ -652,7 +657,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"azure/o1-mini-2024-09-12": {
|
||||
|
@ -666,7 +671,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"azure/o1-preview": {
|
||||
|
@ -680,7 +685,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"azure/o1-preview-2024-09-12": {
|
||||
|
@ -694,7 +699,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
"azure/gpt-4o": {
|
||||
|
@ -721,6 +726,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"azure/gpt-4o-2024-05-13": {
|
||||
|
@ -746,6 +752,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"azure/global-standard/gpt-4o-mini": {
|
||||
|
@ -758,6 +765,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"azure/gpt-4o-mini": {
|
||||
|
@ -771,6 +779,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
|
@ -785,6 +794,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
},
|
||||
|
@ -1109,6 +1119,52 @@
|
|||
"supports_function_calling": true,
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure_ai/mistral-large-2407": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000002,
|
||||
"output_cost_per_token": 0.000006,
|
||||
"litellm_provider": "azure_ai",
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview"
|
||||
},
|
||||
"azure_ai/ministral-3b": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000004,
|
||||
"output_cost_per_token": 0.00000004,
|
||||
"litellm_provider": "azure_ai",
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.ministral-3b-2410-offer?tab=Overview"
|
||||
},
|
||||
"azure_ai/Llama-3.2-11B-Vision-Instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 2048,
|
||||
"input_cost_per_token": 0.00000037,
|
||||
"output_cost_per_token": 0.00000037,
|
||||
"litellm_provider": "azure_ai",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"mode": "chat",
|
||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/metagenai.meta-llama-3-2-11b-vision-instruct-offer?tab=Overview"
|
||||
},
|
||||
"azure_ai/Llama-3.2-90B-Vision-Instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 2048,
|
||||
"input_cost_per_token": 0.00000204,
|
||||
"output_cost_per_token": 0.00000204,
|
||||
"litellm_provider": "azure_ai",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"mode": "chat",
|
||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/metagenai.meta-llama-3-2-90b-vision-instruct-offer?tab=Overview"
|
||||
},
|
||||
"azure_ai/Meta-Llama-3-70B-Instruct": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
|
@ -1148,6 +1204,105 @@
|
|||
"mode": "chat",
|
||||
"source":"https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-405b-instruct-offer?tab=PlansAndPrice"
|
||||
},
|
||||
"azure_ai/Phi-3.5-mini-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000013,
|
||||
"output_cost_per_token": 0.00000052,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3.5-vision-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000013,
|
||||
"output_cost_per_token": 0.00000052,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": true,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3.5-MoE-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000016,
|
||||
"output_cost_per_token": 0.00000064,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-mini-4k-instruct": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000013,
|
||||
"output_cost_per_token": 0.00000052,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-mini-128k-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000013,
|
||||
"output_cost_per_token": 0.00000052,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-small-8k-instruct": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.0000006,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-small-128k-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.0000006,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-medium-4k-instruct": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000017,
|
||||
"output_cost_per_token": 0.00000068,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/Phi-3-medium-128k-instruct": {
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000017,
|
||||
"output_cost_per_token": 0.00000068,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_vision": false,
|
||||
"source": "https://azure.microsoft.com/en-us/pricing/details/phi-3/"
|
||||
},
|
||||
"azure_ai/cohere-rerank-v3-multilingual": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
@@ -1730,6 +1885,22 @@
        "supports_assistant_prefill": true,
        "supports_prompt_caching": true
    },
    "claude-3-5-haiku-20241022": {
        "max_tokens": 8192,
        "max_input_tokens": 200000,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.000001,
        "output_cost_per_token": 0.000005,
        "cache_creation_input_token_cost": 0.00000125,
        "cache_read_input_token_cost": 0.0000001,
        "litellm_provider": "anthropic",
        "mode": "chat",
        "supports_function_calling": true,
        "tool_use_system_prompt_tokens": 264,
        "supports_assistant_prefill": true,
        "supports_prompt_caching": true,
        "supports_pdf_input": true
    },
    "claude-3-opus-20240229": {
        "max_tokens": 4096,
        "max_input_tokens": 200000,
@ -2195,16 +2366,16 @@
|
|||
"input_cost_per_image": 0.00032875,
|
||||
"input_cost_per_audio_per_second": 0.00003125,
|
||||
"input_cost_per_video_per_second": 0.00032875,
|
||||
"input_cost_per_token": 0.000000078125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_image_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
|
||||
"input_cost_per_token_above_128k_tokens": 0.00000015625,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.0000003125,
|
||||
"input_cost_per_token_above_128k_tokens": 0.0000025,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"output_cost_per_character": 0.00000125,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token_above_128k_tokens": 0.00001,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000025,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
|
@ -2221,16 +2392,16 @@
|
|||
"input_cost_per_image": 0.00032875,
|
||||
"input_cost_per_audio_per_second": 0.00003125,
|
||||
"input_cost_per_video_per_second": 0.00032875,
|
||||
"input_cost_per_token": 0.000000078125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_image_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
|
||||
"input_cost_per_token_above_128k_tokens": 0.00000015625,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.0000003125,
|
||||
"input_cost_per_token_above_128k_tokens": 0.0000025,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"output_cost_per_character": 0.00000125,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token_above_128k_tokens": 0.00001,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000025,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
|
@ -2247,16 +2418,16 @@
|
|||
"input_cost_per_image": 0.00032875,
|
||||
"input_cost_per_audio_per_second": 0.00003125,
|
||||
"input_cost_per_video_per_second": 0.00032875,
|
||||
"input_cost_per_token": 0.000000078125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_character": 0.0000003125,
|
||||
"input_cost_per_image_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.0006575,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.0000625,
|
||||
"input_cost_per_token_above_128k_tokens": 0.00000015625,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.0000003125,
|
||||
"input_cost_per_token_above_128k_tokens": 0.0000025,
|
||||
"input_cost_per_character_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"output_cost_per_character": 0.00000125,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000625,
|
||||
"output_cost_per_token_above_128k_tokens": 0.00001,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000025,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
|
@ -2356,17 +2527,17 @@
|
|||
"input_cost_per_image": 0.00002,
|
||||
"input_cost_per_video_per_second": 0.00002,
|
||||
"input_cost_per_audio_per_second": 0.000002,
|
||||
"input_cost_per_token": 0.000000004688,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"input_cost_per_character": 0.00000001875,
|
||||
"input_cost_per_token_above_128k_tokens": 0.000001,
|
||||
"input_cost_per_character_above_128k_tokens": 0.00000025,
|
||||
"input_cost_per_image_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
|
||||
"output_cost_per_token": 0.0000000046875,
|
||||
"output_cost_per_character": 0.00000001875,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000009375,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000000375,
|
||||
"output_cost_per_token": 0.0000003,
|
||||
"output_cost_per_character": 0.000000075,
|
||||
"output_cost_per_token_above_128k_tokens": 0.0000006,
|
||||
"output_cost_per_character_above_128k_tokens": 0.00000015,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
|
@ -2420,17 +2591,17 @@
|
|||
"input_cost_per_image": 0.00002,
|
||||
"input_cost_per_video_per_second": 0.00002,
|
||||
"input_cost_per_audio_per_second": 0.000002,
|
||||
"input_cost_per_token": 0.000000004688,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"input_cost_per_character": 0.00000001875,
|
||||
"input_cost_per_token_above_128k_tokens": 0.000001,
|
||||
"input_cost_per_character_above_128k_tokens": 0.00000025,
|
||||
"input_cost_per_image_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
|
||||
"output_cost_per_token": 0.0000000046875,
|
||||
"output_cost_per_character": 0.00000001875,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000009375,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000000375,
|
||||
"output_cost_per_token": 0.0000003,
|
||||
"output_cost_per_character": 0.000000075,
|
||||
"output_cost_per_token_above_128k_tokens": 0.0000006,
|
||||
"output_cost_per_character_above_128k_tokens": 0.00000015,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
|
@ -2452,17 +2623,17 @@
|
|||
"input_cost_per_image": 0.00002,
|
||||
"input_cost_per_video_per_second": 0.00002,
|
||||
"input_cost_per_audio_per_second": 0.000002,
|
||||
"input_cost_per_token": 0.000000004688,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"input_cost_per_character": 0.00000001875,
|
||||
"input_cost_per_token_above_128k_tokens": 0.000001,
|
||||
"input_cost_per_character_above_128k_tokens": 0.00000025,
|
||||
"input_cost_per_image_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0.00004,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0.000004,
|
||||
"output_cost_per_token": 0.0000000046875,
|
||||
"output_cost_per_character": 0.00000001875,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000000009375,
|
||||
"output_cost_per_character_above_128k_tokens": 0.0000000375,
|
||||
"output_cost_per_token": 0.0000003,
|
||||
"output_cost_per_character": 0.000000075,
|
||||
"output_cost_per_token_above_128k_tokens": 0.0000006,
|
||||
"output_cost_per_character_above_128k_tokens": 0.00000015,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
|
@ -2484,7 +2655,7 @@
|
|||
"input_cost_per_image": 0.00002,
|
||||
"input_cost_per_video_per_second": 0.00002,
|
||||
"input_cost_per_audio_per_second": 0.000002,
|
||||
"input_cost_per_token": 0.000000004688,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"input_cost_per_character": 0.00000001875,
|
||||
"input_cost_per_token_above_128k_tokens": 0.000001,
|
||||
"input_cost_per_character_above_128k_tokens": 0.00000025,
|
||||
|
@ -2643,6 +2814,17 @@
|
|||
"supports_vision": true,
|
||||
"supports_assistant_prefill": true
|
||||
},
|
||||
"vertex_ai/claude-3-5-haiku@20241022": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "vertex_ai-anthropic_models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true
|
||||
},
|
||||
"vertex_ai/claude-3-opus@20240229": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -2686,14 +2868,15 @@
|
|||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
|
||||
},
|
||||
"vertex_ai/meta/llama-3.2-90b-vision-instruct-maas": {
|
||||
"max_tokens": 8192,
|
||||
"max_tokens": 128000,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_output_tokens": 2048,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "vertex_ai-llama_models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://console.cloud.google.com/vertex-ai/publishers/meta/model-garden/llama-3.2-90b-vision-instruct-maas"
|
||||
},
|
||||
"vertex_ai/mistral-large@latest": {
|
||||
|
@ -3615,6 +3798,14 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/anthropic/claude-3-5-haiku": {
|
||||
"max_tokens": 200000,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku-20240307": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -3627,6 +3818,17 @@
|
|||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 264
|
||||
},
|
||||
"openrouter/anthropic/claude-3-5-haiku-20241022": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"tool_use_system_prompt_tokens": 264
|
||||
},
|
||||
"anthropic/claude-3-5-sonnet-20241022": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -3747,7 +3949,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
"supports_vision": false
|
||||
},
|
||||
"openrouter/openai/o1-mini-2024-09-12": {
|
||||
"max_tokens": 65536,
|
||||
|
@ -3759,7 +3961,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
"supports_vision": false
|
||||
},
|
||||
"openrouter/openai/o1-preview": {
|
||||
"max_tokens": 32768,
|
||||
|
@ -3771,7 +3973,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
"supports_vision": false
|
||||
},
|
||||
"openrouter/openai/o1-preview-2024-09-12": {
|
||||
"max_tokens": 32768,
|
||||
|
@ -3783,7 +3985,7 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
"supports_vision": false
|
||||
},
|
||||
"openrouter/openai/gpt-4o": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -4330,9 +4532,9 @@
|
|||
"supports_vision": true
|
||||
},
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
|
@ -4352,6 +4554,17 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -4386,9 +4599,9 @@
|
|||
"supports_vision": true
|
||||
},
|
||||
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
|
@ -4408,6 +4621,17 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"us.anthropic.claude-3-opus-20240229-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -4442,9 +4666,9 @@
|
|||
"supports_vision": true
|
||||
},
|
||||
"eu.anthropic.claude-3-5-sonnet-20241022-v2:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
|
@ -4464,6 +4688,16 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000005,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"eu.anthropic.claude-3-opus-20240229-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 200000,
|
||||
|
@ -5378,6 +5612,13 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"stability.sd3-large-v1:0": {
|
||||
"max_tokens": 77,
|
||||
"max_input_tokens": 77,
|
||||
"output_cost_per_image": 0.08,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "image_generation"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
model_list:
|
||||
- model_name: claude-3-5-sonnet-20240620
|
||||
- model_name: "*"
|
||||
litellm_params:
|
||||
model: claude-3-5-sonnet-20240620
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
|
@ -10,63 +10,113 @@ model_list:
|
|||
output_cost_per_token: 0.000015 # 15$/M
|
||||
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
|
||||
api_key: my-fake-key
|
||||
- model_name: gemini-1.5-flash-002
|
||||
- model_name: fake-openai-endpoint-2
|
||||
litellm_params:
|
||||
model: gemini/gemini-1.5-flash-002
|
||||
|
||||
# litellm_settings:
|
||||
# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
|
||||
# callbacks: ["otel", "prometheus"]
|
||||
# default_redis_batch_cache_expiry: 10
|
||||
model: openai/my-fake-model
|
||||
api_key: my-fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
stream_timeout: 0.001
|
||||
timeout: 1
|
||||
rpm: 1
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/my-fake-model
|
||||
api_key: my-fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
## bedrock chat completions
|
||||
- model_name: "*anthropic.claude*"
|
||||
litellm_params:
|
||||
model: bedrock/*anthropic.claude*
|
||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
guardrailConfig:
|
||||
"guardrailIdentifier": "h4dsqwhp6j66"
|
||||
"guardrailVersion": "2"
|
||||
"trace": "enabled"
|
||||
|
||||
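With the wildcard route above, any request whose model name matches `*anthropic.claude*` is forwarded to Bedrock with the configured guardrail attached. A minimal client-side sketch, assuming a locally running proxy (the base_url and virtual key below are placeholders, not values from this config):

```python
# Hypothetical call against a proxy started with the config above.
# base_url and api_key are placeholders; adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",  # matches *anthropic.claude*
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```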
## bedrock embeddings
|
||||
- model_name: "*amazon.titan-embed-*"
|
||||
litellm_params:
|
||||
model: bedrock/amazon.titan-embed-*
|
||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
- model_name: "*cohere.embed-*"
|
||||
litellm_params:
|
||||
model: bedrock/cohere.embed-*
|
||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
|
||||
- model_name: "bedrock/*"
|
||||
litellm_params:
|
||||
model: bedrock/*
|
||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
||||
aws_region_name: os.environ/AWS_REGION_NAME
|
||||
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
||||
rpm: 480
|
||||
timeout: 300
|
||||
stream_timeout: 60
|
||||
|
||||
litellm_settings:
|
||||
cache: True
|
||||
cache_params:
|
||||
type: redis
|
||||
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
|
||||
callbacks: ["otel", "prometheus"]
|
||||
default_redis_batch_cache_expiry: 10
|
||||
# default_team_settings:
|
||||
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
|
||||
# success_callback: ["langfuse"]
|
||||
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
|
||||
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
|
||||
|
||||
# disable caching on the actual API call
|
||||
supported_call_types: []
|
||||
# litellm_settings:
|
||||
# cache: True
|
||||
# cache_params:
|
||||
# type: redis
|
||||
|
||||
# see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
|
||||
host: os.environ/REDIS_HOST
|
||||
port: os.environ/REDIS_PORT
|
||||
password: os.environ/REDIS_PASSWORD
|
||||
# # disable caching on the actual API call
|
||||
# supported_call_types: []
|
||||
|
||||
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
||||
# see https://docs.litellm.ai/docs/proxy/prometheus
|
||||
callbacks: ['prometheus', 'otel']
|
||||
# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
|
||||
# host: os.environ/REDIS_HOST
|
||||
# port: os.environ/REDIS_PORT
|
||||
# password: os.environ/REDIS_PASSWORD
|
||||
|
||||
# # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
|
||||
failure_callback: ['sentry']
|
||||
service_callback: ['prometheus_system']
|
||||
|
||||
# redact_user_api_key_info: true
|
||||
# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
||||
# # see https://docs.litellm.ai/docs/proxy/prometheus
|
||||
# callbacks: ['otel']
|
||||
|
||||
|
||||
router_settings:
|
||||
routing_strategy: latency-based-routing
|
||||
routing_strategy_args:
|
||||
# only assign 40% of traffic to the fastest deployment to avoid overloading it
|
||||
lowest_latency_buffer: 0.4
|
||||
# # router_settings:
|
||||
# # routing_strategy: latency-based-routing
|
||||
# # routing_strategy_args:
|
||||
# # # only assign 40% of traffic to the fastest deployment to avoid overloading it
|
||||
# # lowest_latency_buffer: 0.4
|
||||
|
||||
# consider last five minutes of calls for latency calculation
|
||||
ttl: 300
|
||||
redis_host: os.environ/REDIS_HOST
|
||||
redis_port: os.environ/REDIS_PORT
|
||||
redis_password: os.environ/REDIS_PASSWORD
|
||||
# # # consider last five minutes of calls for latency calculation
|
||||
# # ttl: 300
|
||||
# # redis_host: os.environ/REDIS_HOST
|
||||
# # redis_port: os.environ/REDIS_PORT
|
||||
# # redis_password: os.environ/REDIS_PASSWORD
|
||||
|
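The `lowest_latency_buffer: 0.4` setting above means the router does not push all traffic to the single fastest deployment; any deployment whose measured latency is within 40% of the fastest one stays in the candidate pool. A standalone sketch of the idea (illustrative only, not litellm's routing implementation; deployment names and latencies are made up):

```python
# Illustrative only: keep every deployment within 40% of the fastest latency.
import random

avg_latency = {"deploy-a": 0.42, "deploy-b": 0.51, "deploy-c": 1.90}  # seconds
lowest_latency_buffer = 0.4

fastest = min(avg_latency.values())
candidates = [
    name
    for name, latency in avg_latency.items()
    if latency <= fastest * (1 + lowest_latency_buffer)
]
print(candidates)                 # ['deploy-a', 'deploy-b']
print(random.choice(candidates))  # spread traffic instead of overloading the fastest
```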
||||
# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
|
||||
# # general_settings:
|
||||
# # master_key: os.environ/LITELLM_MASTER_KEY
|
||||
# # database_url: os.environ/DATABASE_URL
|
||||
# # disable_master_key_return: true
|
||||
# # # alerting: ['slack', 'email']
|
||||
# # alerting: ['email']
|
||||
|
||||
# see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
|
||||
general_settings:
|
||||
master_key: os.environ/LITELLM_MASTER_KEY
|
||||
database_url: os.environ/DATABASE_URL
|
||||
disable_master_key_return: true
|
||||
# alerting: ['slack', 'email']
|
||||
alerting: ['email']
|
||||
# # # Batch write spend updates every 60s
|
||||
# # proxy_batch_write_at: 60
|
||||
|
||||
# Batch write spend updates every 60s
|
||||
proxy_batch_write_at: 60
|
||||
|
||||
# see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
|
||||
# our api keys rarely change
|
||||
user_api_key_cache_ttl: 3600
|
||||
# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
|
||||
# # # our api keys rarely change
|
||||
# # user_api_key_cache_ttl: 3600
|
||||
|
|
|
@ -436,15 +436,7 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
|||
"""
|
||||
|
||||
admin_jwt_scope: str = "litellm_proxy_admin"
|
||||
admin_allowed_routes: List[
|
||||
Literal[
|
||||
"openai_routes",
|
||||
"info_routes",
|
||||
"management_routes",
|
||||
"spend_tracking_routes",
|
||||
"global_spend_tracking_routes",
|
||||
]
|
||||
] = [
|
||||
admin_allowed_routes: List[str] = [
|
||||
"management_routes",
|
||||
"spend_tracking_routes",
|
||||
"global_spend_tracking_routes",
|
||||
|
@ -1902,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
|
|||
auth_error = "auth_error"
|
||||
internal_server_error = "internal_server_error"
|
||||
bad_request_error = "bad_request_error"
|
||||
not_found_error = "not_found_error"
|
||||
|
||||
|
||||
class SSOUserDefinedValues(TypedDict):
|
||||
|
|
|
@ -13,11 +13,13 @@ import traceback
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
|
@ -30,7 +32,7 @@ from litellm.proxy._types import (
|
|||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
from .auth_checks_organization import organization_role_based_access_check
|
||||
|
@ -42,6 +44,10 @@ if TYPE_CHECKING:
|
|||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
|
||||
db_cache_expiry = 5 # refresh every 5s
|
||||
|
||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
||||
|
@ -284,7 +290,7 @@ def get_actual_routes(allowed_routes: list) -> list:
|
|||
return actual_routes
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
@log_db_metrics
|
||||
async def get_end_user_object(
|
||||
end_user_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -383,7 +389,33 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
def _should_check_db(
|
||||
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
|
||||
) -> bool:
|
||||
"""
|
||||
Prevent calling db repeatedly for items that don't exist in the db.
|
||||
"""
|
||||
current_time = time.time()
|
||||
# if key doesn't exist in last_db_access_time -> check db
|
||||
if key not in last_db_access_time:
|
||||
return True
|
||||
elif (
|
||||
last_db_access_time[key][0] is not None
|
||||
): # check db for non-null values (for refresh operations)
|
||||
return True
|
||||
elif last_db_access_time[key][0] is None:
|
||||
if current_time - last_db_access_time[key][1] >= db_cache_expiry:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _update_last_db_access_time(
|
||||
key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
|
||||
):
|
||||
last_db_access_time[key] = (value, time.time())
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
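`_should_check_db` and `_update_last_db_access_time` add a small negative cache: each lookup key stores a `(value, timestamp)` tuple, and a key that was last seen as missing is not re-queried until `db_cache_expiry` seconds have passed. A standalone sketch of the same idea, with simplified names (litellm stores the tuples in a `LimitedSizeOrderedDict` capped at 100 entries):

```python
# Simplified sketch of the negative-cache check used above.
import time
from collections import OrderedDict

last_db_access_time: "OrderedDict[str, tuple]" = OrderedDict()
DB_CACHE_EXPIRY = 5  # seconds, mirrors db_cache_expiry above


def should_check_db(key: str) -> bool:
    entry = last_db_access_time.get(key)
    if entry is None:
        return True               # never looked up -> hit the DB
    value, checked_at = entry
    if value is not None:
        return True               # found last time -> allow a refresh
    return time.time() - checked_at >= DB_CACHE_EXPIRY  # known-missing: back off


def record_db_access(key: str, value) -> None:
    last_db_access_time[key] = (value, time.time())


record_db_access("team_id:unknown-team", None)
print(should_check_db("team_id:unknown-team"))  # False until ~5s have passed
```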
@ -412,11 +444,20 @@ async def get_user_object(
|
|||
if prisma_client is None:
|
||||
raise Exception("No db connected")
|
||||
try:
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
db_access_time_key = "user_id:{}".format(user_id)
|
||||
should_check_db = _should_check_db(
|
||||
key=db_access_time_key,
|
||||
last_db_access_time=last_db_access_time,
|
||||
db_cache_expiry=db_cache_expiry,
|
||||
)
|
||||
|
||||
if should_check_db:
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
if user_id_upsert:
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
|
@ -444,6 +485,13 @@ async def get_user_object(
|
|||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
|
||||
|
||||
# save to db access time
|
||||
_update_last_db_access_time(
|
||||
key=db_access_time_key,
|
||||
value=response_dict,
|
||||
last_db_access_time=last_db_access_time,
|
||||
)
|
||||
|
||||
return _response
|
||||
except Exception as e: # if user not in db
|
||||
raise ValueError(
|
||||
|
@ -514,7 +562,13 @@ async def _delete_cache_key_object(
|
|||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
@log_db_metrics
|
||||
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
|
||||
async def get_team_object(
|
||||
team_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -544,7 +598,7 @@ async def get_team_object(
|
|||
):
|
||||
cached_team_obj = (
|
||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
||||
key=key
|
||||
key=key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -564,9 +618,18 @@ async def get_team_object(
|
|||
|
||||
# else, check db
|
||||
try:
|
||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
db_access_time_key = "team_id:{}".format(team_id)
|
||||
should_check_db = _should_check_db(
|
||||
key=db_access_time_key,
|
||||
last_db_access_time=last_db_access_time,
|
||||
db_cache_expiry=db_cache_expiry,
|
||||
)
|
||||
if should_check_db:
|
||||
response = await _get_team_db_check(
|
||||
team_id=team_id, prisma_client=prisma_client
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
|
@ -580,6 +643,14 @@ async def get_team_object(
|
|||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# save to db access time
|
||||
_update_last_db_access_time(
|
||||
key=db_access_time_key,
|
||||
value=_response,
|
||||
last_db_access_time=last_db_access_time,
|
||||
)
|
||||
|
||||
return _response
|
||||
except Exception:
|
||||
raise Exception(
|
||||
|
@ -587,7 +658,7 @@ async def get_team_object(
|
|||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
@log_db_metrics
|
||||
async def get_key_object(
|
||||
hashed_token: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -608,16 +679,16 @@ async def get_key_object(
|
|||
|
||||
# check if in cache
|
||||
key = hashed_token
|
||||
cached_team_obj: Optional[UserAPIKeyAuth] = None
|
||||
|
||||
if cached_team_obj is None:
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||
cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
|
||||
key=key
|
||||
)
|
||||
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return UserAPIKeyAuth(**cached_team_obj)
|
||||
elif isinstance(cached_team_obj, UserAPIKeyAuth):
|
||||
return cached_team_obj
|
||||
if cached_key_obj is not None:
|
||||
if isinstance(cached_key_obj, dict):
|
||||
return UserAPIKeyAuth(**cached_key_obj)
|
||||
elif isinstance(cached_key_obj, UserAPIKeyAuth):
|
||||
return cached_key_obj
|
||||
|
||||
if check_cache_only:
|
||||
raise Exception(
|
||||
|
@ -647,13 +718,55 @@ async def get_key_object(
|
|||
)
|
||||
|
||||
return _response
|
||||
except httpx.ConnectError as e:
|
||||
return await _handle_failed_db_connection_for_get_key_object(e=e)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
|
||||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def _handle_failed_db_connection_for_get_key_object(
|
||||
e: Exception,
|
||||
) -> UserAPIKeyAuth:
|
||||
"""
|
||||
Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
|
||||
|
||||
Use this if you don't want failed DB queries to block LLM API requests
|
||||
|
||||
Returns:
|
||||
- UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
|
||||
|
||||
Raises:
|
||||
- Original Exception in all other cases
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
litellm_proxy_admin_name,
|
||||
proxy_logging_obj,
|
||||
)
|
||||
|
||||
# If this flag is on, requests failing to connect to the DB will be allowed
|
||||
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
|
||||
# log this as a DB failure on prometheus
|
||||
proxy_logging_obj.service_logging_obj.service_failure_hook(
|
||||
service=ServiceTypes.DB,
|
||||
call_type="get_key_object",
|
||||
error=e,
|
||||
duration=0.0,
|
||||
)
|
||||
|
||||
return UserAPIKeyAuth(
|
||||
key_name="failed-to-connect-to-db",
|
||||
token="failed-to-connect-to-db",
|
||||
user_id=litellm_proxy_admin_name,
|
||||
)
|
||||
else:
|
||||
# raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
|
||||
raise e
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_org_object(
|
||||
org_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
|
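`_handle_failed_db_connection_for_get_key_object` lets the proxy keep serving traffic when a key lookup cannot reach the DB, provided `general_settings.allow_requests_on_db_unavailable` is enabled. A rough sketch of that decision (the settings dict and returned fields are illustrative stand-ins, not the real `UserAPIKeyAuth` object):

```python
# Illustrative fallback: degrade to a permissive identity instead of failing
# the request when the DB is unreachable and the flag is enabled.
import httpx

general_settings = {"allow_requests_on_db_unavailable": True}  # placeholder config


def handle_failed_db_connection(e: Exception) -> dict:
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # record the DB failure for metrics, then let the request through
        return {"key_name": "failed-to-connect-to-db", "token": "failed-to-connect-to-db"}
    raise e  # default: surface the original connection error


print(handle_failed_db_connection(httpx.ConnectError("db unreachable")))
```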
|
@ -5,6 +5,9 @@ import json
|
|||
import os
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
@ -21,7 +24,7 @@ class LicenseCheck:
|
|||
def __init__(self) -> None:
|
||||
self.license_str = os.getenv("LITELLM_LICENSE", None)
|
||||
verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))
|
||||
self.http_handler = HTTPHandler()
|
||||
self.http_handler = HTTPHandler(timeout=15)
|
||||
self.public_key = None
|
||||
self.read_public_key()
|
||||
|
||||
|
@ -44,23 +47,46 @@ class LicenseCheck:
|
|||
verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
|
||||
|
||||
def _verify(self, license_str: str) -> bool:
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
|
||||
self.base_url, license_str
|
||||
)
|
||||
)
|
||||
url = "{}/verify_license/{}".format(self.base_url, license_str)
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
try: # don't impact user, if call fails
|
||||
response = self.http_handler.get(url=url)
|
||||
num_retries = 3
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
response = self.http_handler.get(url=url)
|
||||
if response is None:
|
||||
raise Exception("No response from license server")
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
if i == num_retries - 1:
|
||||
raise
|
||||
|
||||
response.raise_for_status()
|
||||
if response is None:
|
||||
raise Exception("No response from license server")
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
premium = response_json["verify"]
|
||||
|
||||
assert isinstance(premium, bool)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
|
||||
license_str, premium
|
||||
)
|
||||
)
|
||||
return premium
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License via api. - {}".format(
|
||||
str(e)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
|
||||
license_str, str(e)
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
@ -72,7 +98,7 @@ class LicenseCheck:
|
|||
"""
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - {}".format(
|
||||
"litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
|
||||
self.license_str
|
||||
)
|
||||
)
|
||||
|
|
|
@ -44,14 +44,8 @@ class RouteChecks:
|
|||
route in LiteLLMRoutes.info_routes.value
|
||||
): # check if user allowed to call an info route
|
||||
if route == "/key/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
key = query_params.get("key")
|
||||
if key is not None and hash_token(token=key) != api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
# handled by function itself
|
||||
pass
|
||||
elif route == "/user/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
|
|
|
@ -58,7 +58,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
log_to_opentelemetry,
|
||||
log_db_metrics,
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import (
|
||||
_get_request_ip_address,
|
||||
|
@ -703,12 +703,17 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if is_master_key_valid:
|
||||
_user_api_key_obj = UserAPIKeyAuth(
|
||||
api_key=master_key,
|
||||
_user_api_key_obj = _return_user_api_key_auth_obj(
|
||||
user_obj=None,
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_id=litellm_proxy_admin_name,
|
||||
api_key=master_key,
|
||||
parent_otel_span=parent_otel_span,
|
||||
**end_user_params,
|
||||
valid_token_dict={
|
||||
**end_user_params,
|
||||
"user_id": litellm_proxy_admin_name,
|
||||
},
|
||||
route=route,
|
||||
start_time=start_time,
|
||||
)
|
||||
await _cache_key_object(
|
||||
hashed_token=hash_token(master_key),
|
||||
|
@ -1127,11 +1132,13 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
api_key = valid_token.token
|
||||
|
||||
# Add hashed token to cache
|
||||
await _cache_key_object(
|
||||
hashed_token=api_key,
|
||||
user_api_key_obj=valid_token,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
asyncio.create_task(
|
||||
_cache_key_object(
|
||||
hashed_token=api_key,
|
||||
user_api_key_obj=valid_token,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
|
||||
valid_token_dict = valid_token.model_dump(exclude_none=True)
|
||||
|
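The change above stops awaiting the cache write for a validated key: `_cache_key_object` is scheduled with `asyncio.create_task`, so auth returns without waiting on the cache backend. A self-contained sketch of the pattern (the cache dict and auth function here are stand-ins):

```python
# Fire-and-forget cache write: the request path does not block on the cache.
import asyncio

cache: dict = {}


async def cache_key_object(hashed_token: str, obj: dict) -> None:
    await asyncio.sleep(0.01)  # stand-in for a Redis / dual-cache write
    cache[hashed_token] = obj


async def user_api_key_auth(hashed_token: str) -> dict:
    valid_token = {"token": hashed_token, "spend": 0.0}
    asyncio.create_task(cache_key_object(hashed_token, valid_token))  # not awaited
    return valid_token  # auth result is returned immediately


async def main() -> None:
    await user_api_key_auth("hashed-abc123")
    await asyncio.sleep(0.05)  # give the background write time to land
    print(cache)


asyncio.run(main())
```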
@ -1227,6 +1234,7 @@ def _return_user_api_key_auth_obj(
|
|||
valid_token_dict: dict,
|
||||
route: str,
|
||||
start_time: datetime,
|
||||
user_role: Optional[LitellmUserRoles] = None,
|
||||
) -> UserAPIKeyAuth:
|
||||
end_time = datetime.now()
|
||||
user_api_key_service_logger_obj.service_success_hook(
|
||||
|
@ -1238,7 +1246,7 @@ def _return_user_api_key_auth_obj(
|
|||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
retrieved_user_role = (
|
||||
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
|
||||
user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
|
||||
)
|
||||
|
||||
user_api_key_kwargs = {
|
||||
|
|
|
@ -3,18 +3,25 @@ import os
|
|||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
LITELLM_SALT_KEY = os.getenv("LITELLM_SALT_KEY", None)
|
||||
verbose_proxy_logger.debug(
|
||||
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
|
||||
)
|
||||
|
||||
def _get_salt_key():
|
||||
from litellm.proxy.proxy_server import master_key
|
||||
|
||||
salt_key = os.getenv("LITELLM_SALT_KEY", None)
|
||||
|
||||
if salt_key is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
|
||||
)
|
||||
|
||||
salt_key = master_key
|
||||
|
||||
return salt_key
|
||||
|
||||
|
||||
def encrypt_value_helper(value: str):
|
||||
from litellm.proxy.proxy_server import master_key
|
||||
|
||||
signing_key = LITELLM_SALT_KEY
|
||||
if LITELLM_SALT_KEY is None:
|
||||
signing_key = master_key
|
||||
signing_key = _get_salt_key()
|
||||
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
|
@ -35,9 +42,7 @@ def encrypt_value_helper(value: str):
|
|||
def decrypt_value_helper(value: str):
|
||||
from litellm.proxy.proxy_server import master_key
|
||||
|
||||
signing_key = LITELLM_SALT_KEY
|
||||
if LITELLM_SALT_KEY is None:
|
||||
signing_key = master_key
|
||||
signing_key = _get_salt_key()
|
||||
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
|
|
138  litellm/proxy/db/log_db_metrics.py  Normal file
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
Handles logging DB success/failure to ServiceLogger()
|
||||
|
||||
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
from litellm._service_logger import ServiceTypes
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def log_db_metrics(func):
|
||||
"""
|
||||
Decorator to log the duration of a DB related function to ServiceLogger()
|
||||
|
||||
Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
|
||||
|
||||
When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure
|
||||
|
||||
Args:
|
||||
func: The function to be decorated
|
||||
|
||||
Returns:
|
||||
Result from the decorated function
|
||||
|
||||
Raises:
|
||||
Exception: If the decorated function raises an exception
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
start_time: datetime = datetime.now()
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
end_time: datetime = datetime.now()
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
if "PROXY" not in func.__name__:
|
||||
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=kwargs.get("parent_otel_span", None),
|
||||
duration=(end_time - start_time).total_seconds(),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"function_name": func.__name__,
|
||||
"function_kwargs": kwargs,
|
||||
"function_args": args,
|
||||
},
|
||||
)
|
||||
elif (
|
||||
# in litellm custom callbacks kwargs is passed as arg[0]
|
||||
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
|
||||
args is not None
|
||||
and len(args) > 0
|
||||
and isinstance(args[0], dict)
|
||||
):
|
||||
passed_kwargs = args[0]
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(
|
||||
kwargs=passed_kwargs
|
||||
)
|
||||
if parent_otel_span is not None:
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
|
||||
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.BATCH_WRITE_TO_DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=parent_otel_span,
|
||||
duration=0.0,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
# end of logging to otel
|
||||
return result
|
||||
except Exception as e:
|
||||
end_time: datetime = datetime.now()
|
||||
await _handle_logging_db_exception(
|
||||
e=e,
|
||||
func=func,
|
||||
kwargs=kwargs,
|
||||
args=args,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _is_exception_related_to_db(e: Exception) -> bool:
|
||||
"""
|
||||
Returns True if the exception is related to the DB
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
|
||||
|
||||
|
||||
async def _handle_logging_db_exception(
|
||||
e: Exception,
|
||||
func: Callable,
|
||||
kwargs: Dict,
|
||||
args: Tuple,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> None:
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
# don't log this as a DB Service Failure, if the DB did not raise an exception
|
||||
if _is_exception_related_to_db(e) is not True:
|
||||
return
|
||||
|
||||
await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
error=e,
|
||||
service=ServiceTypes.DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=kwargs.get("parent_otel_span"),
|
||||
duration=(end_time - start_time).total_seconds(),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"function_name": func.__name__,
|
||||
"function_kwargs": kwargs,
|
||||
"function_args": args,
|
||||
},
|
||||
)
|
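This new module factors the duration-logging decorator out of `litellm/proxy/utils.py`. A stripped-down, standalone version of the same pattern (the `ServiceLogger` below is a stand-in, not litellm's `ServiceLogging`, and it omits the DB-only failure filtering shown above):

```python
# Minimal sketch of an async duration-logging decorator; not the real
# log_db_metrics, which reports into proxy_logging_obj's ServiceLogging.
import asyncio
import time
from functools import wraps


class ServiceLogger:  # stand-in for the proxy's service logger
    async def success(self, call_type: str, duration: float) -> None:
        print(f"DB success: {call_type} took {duration:.4f}s")

    async def failure(self, call_type: str, error: Exception) -> None:
        print(f"DB failure: {call_type} raised {error!r}")


service_logger = ServiceLogger()


def log_db_metrics(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start = time.monotonic()
        try:
            result = await func(*args, **kwargs)
            await service_logger.success(func.__name__, time.monotonic() - start)
            return result
        except Exception as e:
            await service_logger.failure(func.__name__, e)
            raise

    return wrapper


@log_db_metrics
async def get_user_row(user_id: str) -> dict:
    await asyncio.sleep(0.01)  # stand-in for a Prisma query
    return {"user_id": user_id}


asyncio.run(get_user_row("user-123"))
```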
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
@ -6,6 +7,7 @@ from starlette.datastructures import Headers
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.proxy._types import (
|
||||
AddTeamCallback,
|
||||
CommonProxyErrors,
|
||||
|
@ -16,11 +18,15 @@ from litellm.proxy._types import (
|
|||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import get_request_route
|
||||
from litellm.types.services import ServiceTypes
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
SupportedCacheControls,
|
||||
)
|
||||
|
||||
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||
|
||||
|
@ -471,7 +477,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
### END-USER SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.allowed_model_region is not None:
|
||||
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
|
||||
|
||||
start_time = time.time()
|
||||
## [Enterprise Only]
|
||||
# Add User-IP Address
|
||||
requester_ip_address = ""
|
||||
|
@ -539,6 +545,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
verbose_proxy_logger.debug(
|
||||
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
await service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.PROXY_PRE_CALL,
|
||||
duration=end_time - start_time,
|
||||
call_type="add_litellm_data_to_request",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds
|
||||
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -734,13 +734,37 @@ async def info_key_fn(
|
|||
raise Exception(
|
||||
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
if key is None:
|
||||
key = user_api_key_dict.api_key
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
|
||||
# default to using Auth token if no key is passed in
|
||||
key = key or user_api_key_dict.api_key
|
||||
hashed_key: Optional[str] = key
|
||||
if key is not None:
|
||||
hashed_key = _hash_token_if_needed(token=key)
|
||||
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_key}, # type: ignore
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
if key_info is None:
|
||||
raise ProxyException(
|
||||
message="Key not found in database",
|
||||
type=ProxyErrorTypes.not_found_error,
|
||||
param="key",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
_can_user_query_key_info(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
key=key,
|
||||
key_info=key_info,
|
||||
)
|
||||
is not True
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You are not allowed to access this key's info. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
),
|
||||
)
|
||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||
try:
|
||||
|
@ -1540,6 +1564,27 @@ async def key_health(
|
|||
)
|
||||
|
||||
|
||||
def _can_user_query_key_info(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
key: Optional[str],
|
||||
key_info: LiteLLM_VerificationToken,
|
||||
) -> bool:
|
||||
"""
|
||||
Helper to check if the user has access to the key's info
|
||||
"""
|
||||
if (
|
||||
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
|
||||
):
|
||||
return True
|
||||
elif user_api_key_dict.api_key == key:
|
||||
return True
|
||||
# user can query their own key info
|
||||
elif key_info.user_id == user_api_key_dict.user_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def test_key_logging(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
request: Request,
|
||||
|
@ -1599,7 +1644,9 @@ async def test_key_logging(
|
|||
details=f"Logging test failed: {str(e)}",
|
||||
)
|
||||
|
||||
await asyncio.sleep(1) # wait for callbacks to run
|
||||
await asyncio.sleep(
|
||||
2
|
||||
) # wait for callbacks to run, callbacks use batching so wait for the flush event
|
||||
|
||||
# Check if any logger exceptions were triggered
|
||||
log_contents = log_capture_string.getvalue()
|
||||
|
|
|
@ -1281,12 +1281,20 @@ async def list_team(
|
|||
where={"team_id": team.team_id}
|
||||
)
|
||||
|
||||
returned_responses.append(
|
||||
TeamListResponseObject(
|
||||
**team.model_dump(),
|
||||
team_memberships=_team_memberships,
|
||||
keys=keys,
|
||||
try:
|
||||
returned_responses.append(
|
||||
TeamListResponseObject(
|
||||
**team.model_dump(),
|
||||
team_memberships=_team_memberships,
|
||||
keys=keys,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
team_exception = """Invalid team object for team_id: {}. team_object={}.
|
||||
Error: {}
|
||||
""".format(
|
||||
team.team_id, team.model_dump(), str(e)
|
||||
)
|
||||
raise HTTPException(status_code=400, detail={"error": team_exception})
|
||||
|
||||
return returned_responses
|
||||
|
|
|
@ -694,6 +694,9 @@ def run_server( # noqa: PLR0915
|
|||
|
||||
import litellm
|
||||
|
||||
if detailed_debug is True:
|
||||
litellm._turn_on_debug()
|
||||
|
||||
# DO NOT DELETE - enables global variables to work across files
|
||||
from litellm.proxy.proxy_server import app # noqa
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
model_list:
|
||||
- model_name: gpt-4o
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/gpt-5
|
||||
model: openai/fake
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
|
||||
general_settings:
|
||||
alerting: ["slack"]
|
||||
alerting_threshold: 0.001
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["gcs_bucket"]
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ from litellm.proxy._types import *
|
|||
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||
router as analytics_router,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
|
||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
|
@ -1198,7 +1198,7 @@ async def update_cache( # noqa: PLR0915
|
|||
await _update_team_cache()
|
||||
|
||||
asyncio.create_task(
|
||||
user_api_key_cache.async_batch_set_cache(
|
||||
user_api_key_cache.async_set_cache_pipeline(
|
||||
cache_list=values_to_update_in_cache,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=parent_otel_span,
|
||||
|
@ -1257,7 +1257,7 @@ class ProxyConfig:
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
self.config: Dict[str, Any] = {}
|
||||
|
||||
def is_yaml(self, config_file_path: str) -> bool:
|
||||
if not os.path.isfile(config_file_path):
|
||||
|
@ -1271,9 +1271,6 @@ class ProxyConfig:
|
|||
) -> dict:
|
||||
"""
|
||||
Given a config file path, load the config from the file.
|
||||
|
||||
If `store_model_in_db` is True, then read the DB and update the config with the DB values.
|
||||
|
||||
Args:
|
||||
config_file_path (str): path to the config file
|
||||
Returns:
|
||||
|
@ -1299,40 +1296,6 @@ class ProxyConfig:
|
|||
"litellm_settings": {},
|
||||
}
|
||||
|
||||
## DB
|
||||
if prisma_client is not None and (
|
||||
general_settings.get("store_model_in_db", False) is True
|
||||
or store_model_in_db is True
|
||||
):
|
||||
_tasks = []
|
||||
keys = [
|
||||
"general_settings",
|
||||
"router_settings",
|
||||
"litellm_settings",
|
||||
"environment_variables",
|
||||
]
|
||||
for k in keys:
|
||||
response = prisma_client.get_generic_data(
|
||||
key="param_name", value=k, table_name="config"
|
||||
)
|
||||
_tasks.append(response)
|
||||
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
for response in responses:
|
||||
if response is not None:
|
||||
param_name = getattr(response, "param_name", None)
|
||||
param_value = getattr(response, "param_value", None)
|
||||
if param_name is not None and param_value is not None:
|
||||
# check if param_name is already in the config
|
||||
if param_name in config:
|
||||
if isinstance(config[param_name], dict):
|
||||
config[param_name].update(param_value)
|
||||
else:
|
||||
config[param_name] = param_value
|
||||
else:
|
||||
# if it's not in the config - then add it
|
||||
config[param_name] = param_value
|
||||
|
||||
return config
|
||||
|
||||
async def save_config(self, new_config: dict):
|
||||
|
@ -1398,8 +1361,10 @@ class ProxyConfig:
|
|||
- for a given team id
|
||||
- return the relevant completion() call params
|
||||
"""
|
||||
|
||||
# load existing config
|
||||
config = await self.get_config()
|
||||
config = self.config
|
||||
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get("litellm_settings", {})
|
||||
all_teams_config = litellm_settings.get("default_team_settings", None)
|
||||
|
@ -1451,7 +1416,9 @@ class ProxyConfig:
|
|||
dict: config
|
||||
|
||||
"""
|
||||
global prisma_client, store_model_in_db
|
||||
# Load existing config
|
||||
|
||||
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
||||
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
|
||||
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
|
||||
|
@ -1473,12 +1440,21 @@ class ProxyConfig:
|
|||
else:
|
||||
# default to file
|
||||
config = await self._get_config_from_file(config_file_path=config_file_path)
|
||||
## UPDATE CONFIG WITH DB
|
||||
if prisma_client is not None:
|
||||
config = await self._update_config_from_db(
|
||||
config=config,
|
||||
prisma_client=prisma_client,
|
||||
store_model_in_db=store_model_in_db,
|
||||
)
|
||||
|
||||
## PRINT YAML FOR CONFIRMING IT WORKS
|
||||
printed_yaml = copy.deepcopy(config)
|
||||
printed_yaml.pop("environment_variables", None)
|
||||
|
||||
config = self._check_for_os_environ_vars(config=config)
|
||||
|
||||
self.config = config
|
||||
return config
|
||||
|
||||
async def load_config( # noqa: PLR0915
|
||||
|
@ -2290,6 +2266,55 @@ class ProxyConfig:
|
|||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
|
||||
async def _update_config_from_db(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
config: dict,
|
||||
store_model_in_db: Optional[bool],
|
||||
):
|
||||
|
||||
if store_model_in_db is not True:
|
||||
verbose_proxy_logger.info(
|
||||
"'store_model_in_db' is not True, skipping db updates"
|
||||
)
|
||||
return config
|
||||
|
||||
_tasks = []
|
||||
keys = [
|
||||
"general_settings",
|
||||
"router_settings",
|
||||
"litellm_settings",
|
||||
"environment_variables",
|
||||
]
|
||||
for k in keys:
|
||||
response = prisma_client.get_generic_data(
|
||||
key="param_name", value=k, table_name="config"
|
||||
)
|
||||
_tasks.append(response)
|
||||
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
for response in responses:
|
||||
if response is not None:
|
||||
param_name = getattr(response, "param_name", None)
|
||||
verbose_proxy_logger.info(f"loading {param_name} settings from db")
|
||||
if param_name == "litellm_settings":
|
||||
verbose_proxy_logger.info(
|
||||
f"litellm_settings: {response.param_value}"
|
||||
)
|
||||
param_value = getattr(response, "param_value", None)
|
||||
if param_name is not None and param_value is not None:
|
||||
# check if param_name is already in the config
|
||||
if param_name in config:
|
||||
if isinstance(config[param_name], dict):
|
||||
config[param_name].update(param_value)
|
||||
else:
|
||||
config[param_name] = param_value
|
||||
else:
|
||||
# if it's not in the config - then add it
|
||||
config[param_name] = param_value
|
||||
|
||||
return config
|
||||
|
||||
async def add_deployment(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
|
@ -2843,7 +2868,7 @@ class ProxyStartupEvent:
|
|||
|
||||
if (
|
||||
proxy_logging_obj is not None
|
||||
and proxy_logging_obj.slack_alerting_instance is not None
|
||||
and proxy_logging_obj.slack_alerting_instance.alerting is not None
|
||||
and prisma_client is not None
|
||||
):
|
||||
print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa
|
||||
|
@ -2891,7 +2916,7 @@ class ProxyStartupEvent:
|
|||
scheduler.start()
|
||||
|
||||
@classmethod
|
||||
def _setup_prisma_client(
|
||||
async def _setup_prisma_client(
|
||||
cls,
|
||||
database_url: Optional[str],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
|
@ -2910,11 +2935,15 @@ class ProxyStartupEvent:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
await prisma_client.connect()
|
||||
|
||||
## Add necessary views to proxy ##
|
||||
asyncio.create_task(
|
||||
prisma_client.check_view_exists()
|
||||
) # check if all necessary views exist. Don't block execution
|
||||
|
||||
# run a health check to ensure the DB is ready
|
||||
await prisma_client.health_check()
|
||||
return prisma_client
|
||||
|
||||
|
||||
|
@ -2931,12 +2960,21 @@ async def startup_event():
|
|||
# check if DATABASE_URL in environment - load from there
|
||||
if prisma_client is None:
|
||||
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
|
||||
prisma_client = ProxyStartupEvent._setup_prisma_client(
|
||||
prisma_client = await ProxyStartupEvent._setup_prisma_client(
|
||||
database_url=_db_url,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
## CHECK PREMIUM USER
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
|
||||
premium_user
|
||||
)
|
||||
)
|
||||
if premium_user is False:
|
||||
premium_user = _license_check.is_premium()
|
||||
|
||||
### LOAD CONFIG ###
|
||||
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
||||
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
|
||||
|
@ -2984,21 +3022,6 @@ async def startup_event():
|
|||
if isinstance(worker_config, dict):
|
||||
await initialize(**worker_config)
|
||||
|
||||
## CHECK PREMIUM USER
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
|
||||
premium_user
|
||||
)
|
||||
)
|
||||
if premium_user is False:
|
||||
premium_user = _license_check.is_premium()
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
|
||||
premium_user
|
||||
)
|
||||
)
|
||||
|
||||
ProxyStartupEvent._initialize_startup_logging(
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
|
@ -3021,9 +3044,6 @@ async def startup_event():
|
|||
prompt_injection_detection_obj.update_environment(router=llm_router)
|
||||
|
||||
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
|
||||
if prisma_client is not None:
|
||||
await prisma_client.connect()
|
||||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
ProxyStartupEvent._add_master_key_hash_to_db(
|
||||
master_key=master_key,
|
||||
|
@ -8723,7 +8743,7 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
|||
if k == "alert_to_webhook_url":
|
||||
# check if slack is already enabled. if not, enable it
|
||||
if "alerting" not in _existing_settings:
|
||||
_existing_settings["alerting"].append("slack")
|
||||
_existing_settings = {"alerting": ["slack"]}
|
||||
elif isinstance(_existing_settings["alerting"], list):
|
||||
if "slack" not in _existing_settings["alerting"]:
|
||||
_existing_settings["alerting"].append("slack")
|
||||
|
|
|
@ -65,6 +65,7 @@ async def route_request(
|
|||
Common helper to route the request
|
||||
|
||||
"""
|
||||
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
if "api_key" in data or "api_base" in data:
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
|
|
|
@ -55,10 +55,6 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import (
|
||||
|
@ -77,6 +73,7 @@ from litellm.proxy.db.create_views import (
|
|||
create_missing_views,
|
||||
should_create_missing_views,
|
||||
)
|
||||
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
||||
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
|
@ -137,83 +134,6 @@ def safe_deep_copy(data):
|
|||
return new_data
|
||||
|
||||
|
||||
def log_to_opentelemetry(func):
|
||||
"""
|
||||
Decorator to log the duration of a DB related function to ServiceLogger()
|
||||
|
||||
Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time: datetime = datetime.now()
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
end_time: datetime = datetime.now()
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
if "PROXY" not in func.__name__:
|
||||
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=kwargs.get("parent_otel_span", None),
|
||||
duration=(end_time - start_time).total_seconds(),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"function_name": func.__name__,
|
||||
"function_kwargs": kwargs,
|
||||
"function_args": args,
|
||||
},
|
||||
)
|
||||
elif (
|
||||
# in litellm custom callbacks kwargs is passed as arg[0]
|
||||
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
|
||||
args is not None
|
||||
and len(args) > 0
|
||||
and isinstance(args[0], dict)
|
||||
):
|
||||
passed_kwargs = args[0]
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(
|
||||
kwargs=passed_kwargs
|
||||
)
|
||||
if parent_otel_span is not None:
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
|
||||
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
|
||||
service=ServiceTypes.BATCH_WRITE_TO_DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=parent_otel_span,
|
||||
duration=0.0,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
# end of logging to otel
|
||||
return result
|
||||
except Exception as e:
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
end_time: datetime = datetime.now()
|
||||
await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
|
||||
error=e,
|
||||
service=ServiceTypes.DB,
|
||||
call_type=func.__name__,
|
||||
parent_otel_span=kwargs.get("parent_otel_span"),
|
||||
duration=(end_time - start_time).total_seconds(),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata={
|
||||
"function_name": func.__name__,
|
||||
"function_kwargs": kwargs,
|
||||
"function_args": args,
|
||||
},
|
||||
)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class InternalUsageCache:
|
||||
def __init__(self, dual_cache: DualCache):
|
||||
self.dual_cache: DualCache = dual_cache
|
||||
|
@ -255,7 +175,7 @@ class InternalUsageCache:
|
|||
local_only: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
return await self.dual_cache.async_batch_set_cache(
|
||||
return await self.dual_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list,
|
||||
local_only=local_only,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
|
@ -1083,19 +1003,16 @@ class PrismaClient:
|
|||
proxy_logging_obj: ProxyLogging,
|
||||
http_client: Optional[Any] = None,
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||
)
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
self.iam_token_db_auth: Optional[bool] = str_to_bool(
|
||||
os.getenv("IAM_TOKEN_DB_AUTH")
|
||||
)
|
||||
verbose_proxy_logger.debug("Creating Prisma Client..")
|
||||
try:
|
||||
from prisma import Prisma # type: ignore
|
||||
except Exception:
|
||||
raise Exception("Unable to find Prisma binaries.")
|
||||
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
|
||||
if http_client is not None:
|
||||
self.db = PrismaWrapper(
|
||||
original_prisma=Prisma(http=http_client),
|
||||
|
@ -1114,7 +1031,7 @@ class PrismaClient:
|
|||
else False
|
||||
),
|
||||
) # Client to connect to Prisma db
|
||||
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
|
||||
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
|
@ -1400,6 +1317,7 @@ class PrismaClient:
|
|||
|
||||
return
|
||||
|
||||
@log_db_metrics
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
|
@ -1465,7 +1383,7 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
@log_to_opentelemetry
|
||||
@log_db_metrics
|
||||
async def get_data( # noqa: PLR0915
|
||||
self,
|
||||
token: Optional[Union[str, list]] = None,
|
||||
|
@ -1506,9 +1424,7 @@ class PrismaClient:
|
|||
# check if plain text or hash
|
||||
if token is not None:
|
||||
if isinstance(token, str):
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
hashed_token = _hash_token_if_needed(token=token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||
)
|
||||
|
@ -1575,8 +1491,7 @@ class PrismaClient:
|
|||
if token is not None:
|
||||
where_filter["token"] = {}
|
||||
if isinstance(token, str):
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
token = _hash_token_if_needed(token=token)
|
||||
where_filter["token"]["in"] = [token]
|
||||
elif isinstance(token, list):
|
||||
hashed_tokens = []
|
||||
|
@ -1712,9 +1627,7 @@ class PrismaClient:
|
|||
# check if plain text or hash
|
||||
if token is not None:
|
||||
if isinstance(token, str):
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
hashed_token = _hash_token_if_needed(token=token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||
)
|
||||
|
@ -1994,8 +1907,7 @@ class PrismaClient:
|
|||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
token = _hash_token_if_needed(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
|
@ -2347,11 +2259,7 @@ class PrismaClient:
|
|||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
sql_query = """
|
||||
SELECT 1
|
||||
FROM "LiteLLM_VerificationToken"
|
||||
LIMIT 1
|
||||
"""
|
||||
sql_query = "SELECT 1"
|
||||
|
||||
# Execute the raw query
|
||||
# The asterisk before `user_id_list` unpacks the list into separate arguments
|
||||
|
@ -2510,6 +2418,18 @@ def hash_token(token: str):
|
|||
return hashed_token
|
||||
|
||||
|
||||
def _hash_token_if_needed(token: str) -> str:
|
||||
"""
|
||||
Hash the token if it's a string and starts with "sk-"
|
||||
|
||||
Else return the token as is
|
||||
"""
|
||||
if token.startswith("sk-"):
|
||||
return hash_token(token=token)
|
||||
else:
|
||||
return token
|
||||
|
||||
|
||||
def _extract_from_regex(duration: str) -> Tuple[int, str]:
|
||||
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
|
||||
|
||||
|
|
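`_hash_token_if_needed` centralizes the convention that raw virtual keys (prefixed `sk-`) are hashed before any DB lookup, while already-hashed tokens pass through unchanged. A self-contained sketch; the `hash_token` helper here mirrors the SHA-256 hexdigest used by the proxy's own `hash_token`:

```python
# Sketch of the "sk-" hashing convention used for DB lookups.
import hashlib


def hash_token(token: str) -> str:
    return hashlib.sha256(token.encode()).hexdigest()


def _hash_token_if_needed(token: str) -> str:
    return hash_token(token) if token.startswith("sk-") else token


raw_key = "sk-my-virtual-key"
hashed = hash_token(raw_key)
print(_hash_token_if_needed(raw_key) == hashed)  # True: raw keys get hashed
print(_hash_token_if_needed(hashed) == hashed)   # True: hashed tokens pass through
```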
|
@ -339,11 +339,7 @@ class Router:
|
|||
cache_config: Dict[str, Any] = {}
|
||||
|
||||
self.client_ttl = client_ttl
|
||||
if redis_url is not None or (
|
||||
redis_host is not None
|
||||
and redis_port is not None
|
||||
and redis_password is not None
|
||||
):
|
||||
if redis_url is not None or (redis_host is not None and redis_port is not None):
|
||||
cache_type = "redis"
|
||||
|
||||
if redis_url is not None:
|
||||
|
@ -556,6 +552,10 @@ class Router:
|
|||
|
||||
self.initialize_assistants_endpoint()
|
||||
|
||||
self.amoderation = self.factory_function(
|
||||
litellm.amoderation, call_type="moderation"
|
||||
)
|
||||
|
||||
def initialize_assistants_endpoint(self):
|
||||
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
|
||||
self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
|
||||
|
@ -585,6 +585,7 @@ class Router:
|
|||
def routing_strategy_init(
|
||||
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
||||
):
|
||||
verbose_router_logger.info(f"Routing strategy: {routing_strategy}")
|
||||
if (
|
||||
routing_strategy == RoutingStrategy.LEAST_BUSY.value
|
||||
or routing_strategy == RoutingStrategy.LEAST_BUSY
|
||||
|
@ -912,6 +913,7 @@ class Router:
|
|||
logging_obj=logging_obj,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
response = await _response
|
||||
|
||||
## CHECK CONTENT FILTER ERROR ##
|
||||
|
@@ -1681,78 +1683,6 @@ class Router:
            )
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        try:
            kwargs["model"] = model
            kwargs["input"] = input
            kwargs["original_function"] = self._amoderation
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            response = await self.async_function_with_fallbacks(**kwargs)

            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    async def _amoderation(self, model: str, input: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            model_client = self._get_async_openai_model_client(
                deployment=deployment,
                kwargs=kwargs,
            )
            self.total_calls[model_name] += 1

            timeout: Optional[Union[float, int]] = self._get_timeout(
                kwargs=kwargs,
                data=data,
            )

            response = await litellm.amoderation(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": timeout,
                    **kwargs,
                }
            )

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def arerank(self, model: str, **kwargs):
        try:
            kwargs["model"] = model
@@ -2608,20 +2538,46 @@ class Router:
        return final_results

    #### ASSISTANTS API ####
    #### PASSTHROUGH API ####

    def factory_function(self, original_function: Callable):
    async def _pass_through_moderation_endpoint_factory(
        self,
        original_function: Callable,
        **kwargs,
    ):
        if (
            "model" in kwargs
            and self.get_model_list(model_name=kwargs["model"]) is not None
        ):
            deployment = await self.async_get_available_deployment(
                model=kwargs["model"]
            )
            kwargs["model"] = deployment["litellm_params"]["model"]
        return await original_function(**kwargs)

    def factory_function(
        self,
        original_function: Callable,
        call_type: Literal["assistants", "moderation"] = "assistants",
    ):
        async def new_function(
            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
            client: Optional["AsyncOpenAI"] = None,
            **kwargs,
        ):
            return await self._pass_through_assistants_endpoint_factory(
                original_function=original_function,
                custom_llm_provider=custom_llm_provider,
                client=client,
                **kwargs,
            )
            if call_type == "assistants":
                return await self._pass_through_assistants_endpoint_factory(
                    original_function=original_function,
                    custom_llm_provider=custom_llm_provider,
                    client=client,
                    **kwargs,
                )
            elif call_type == "moderation":

                return await self._pass_through_moderation_endpoint_factory(  # type: ignore
                    original_function=original_function,
                    **kwargs,
                )

        return new_function
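
Stripped of router details, the dispatch idea is a factory that closes over the wrapped coroutine and branches on call_type. A minimal standalone sketch; the names here are illustrative, not litellm's:

import asyncio
from typing import Any, Callable, Coroutine, Literal

def make_passthrough(
    original_function: Callable[..., Coroutine[Any, Any, Any]],
    call_type: Literal["assistants", "moderation"] = "assistants",
):
    # Returns an async wrapper bound to one call_type, like Router.factory_function.
    async def new_function(**kwargs):
        if call_type == "moderation":
            # Moderation path: the router would rewrite kwargs["model"] to the
            # deployment's underlying model before passing through.
            kwargs.setdefault("model", "omni-moderation-latest")
        return await original_function(**kwargs)

    return new_function

async def fake_moderation(**kwargs):
    return {"flagged": False, "model": kwargs.get("model")}

amoderation = make_passthrough(fake_moderation, call_type="moderation")
print(asyncio.run(amoderation(input="hi")))
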
@@ -2961,14 +2917,14 @@ class Router:
            raise

        # decides how long to sleep before retry
        _timeout = self._time_to_sleep_before_retry(
        retry_after = self._time_to_sleep_before_retry(
            e=original_exception,
            remaining_retries=num_retries,
            num_retries=num_retries,
            healthy_deployments=_healthy_deployments,
        )
        # sleeps for the length of the timeout
        await asyncio.sleep(_timeout)

        await asyncio.sleep(retry_after)
        for current_attempt in range(num_retries):
            try:
                # if the function call is successful, no exception will be raised and we'll break out of the loop
@@ -3598,6 +3554,15 @@ class Router:
        # Catch all - if any exceptions default to cooling down
        return True

    def _has_default_fallbacks(self) -> bool:
        if self.fallbacks is None:
            return False
        for fallback in self.fallbacks:
            if isinstance(fallback, dict):
                if "*" in fallback:
                    return True
        return False

    def _should_raise_content_policy_error(
        self, model: str, response: ModelResponse, kwargs: dict
    ) -> bool:
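
The "*" key is the wildcard used for default fallbacks; a sketch of a fallbacks list that the new helper would report as having defaults (deployment names are placeholders):

# Any dict entry keyed by "*" counts as a default fallback for every model group.
fallbacks = [
    {"gpt-4o": ["azure-gpt-4o-deployment"]},  # model-specific fallback
    {"*": ["gpt-3.5-turbo"]},                 # default fallback
]

def has_default_fallbacks(fallbacks) -> bool:
    # Mirrors the added helper: scan for a dict that contains the "*" key.
    return any(isinstance(f, dict) and "*" in f for f in (fallbacks or []))

assert has_default_fallbacks(fallbacks) is True
assert has_default_fallbacks(None) is False
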
@@ -3614,6 +3579,7 @@ class Router:
        content_policy_fallbacks = kwargs.get(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )

        ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
        if content_policy_fallbacks is not None:
            fallback_model_group = None
@@ -3624,6 +3590,8 @@ class Router:

            if fallback_model_group is not None:
                return True
            elif self._has_default_fallbacks():  # default fallbacks set
                return True

        verbose_router_logger.info(
            "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
@@ -4178,7 +4146,9 @@ class Router:
            model = _model

        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
        model_info = litellm.get_model_info(model=model)
        model_info = litellm.get_model_info(
            model="{}/{}".format(custom_llm_provider, model)
        )

        ## CHECK USER SET MODEL INFO
        user_model_info = deployment.get("model_info", {})
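
The lookup now prefixes the provider, so a deployment resolved to provider "azure" and model "gpt-4o" is queried as "azure/gpt-4o". A hedged sketch; the values are illustrative, and get_model_info raises if the combined name is not in litellm's model map:

import litellm

custom_llm_provider, model = "azure", "gpt-4o"  # illustrative values
model_info = litellm.get_model_info(model="{}/{}".format(custom_llm_provider, model))
print(model_info.get("max_input_tokens"))
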
@@ -4849,7 +4819,7 @@ class Router:
                )
                continue
            except Exception as e:
                verbose_router_logger.error("An error occurs - {}".format(str(e)))
                verbose_router_logger.exception("An error occurs - {}".format(str(e)))

            _litellm_params = deployment.get("litellm_params", {})
            model_id = deployment.get("model_info", {}).get("id", "")
@@ -5048,10 +5018,12 @@ class Router:
            )

        if len(healthy_deployments) == 0:
            raise ValueError(
                "{}. You passed in model={}. There is no 'model_name' with this string ".format(
                    RouterErrors.no_deployments_available.value, model
                )
            raise litellm.BadRequestError(
                message="You passed in model={}. There is no 'model_name' with this string ".format(
                    model
                ),
                model=model,
                llm_provider="",
            )

        if litellm.model_alias_map and model in litellm.model_alias_map:
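
Because the unknown-model case now raises litellm.BadRequestError rather than a bare ValueError, callers can catch a typed 400. A sketch of the caller side; the model group name is deliberately bogus, and the exact wrapping may differ once retries and fallbacks are involved:

import asyncio
import litellm
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
    ]
)

async def main():
    try:
        await router.acompletion(
            model="no-such-model-group",
            messages=[{"role": "user", "content": "hi"}],
        )
    except litellm.BadRequestError as e:
        # Previously this surfaced as a plain ValueError.
        print("bad request:", e.message)

asyncio.run(main())
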
@@ -180,7 +180,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
            deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")

            if local_result is not None and local_result >= deployment_rpm:
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -195,7 +194,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                            deployment_rpm,
                            local_result,
                        ),
                        headers={"retry-after": 60},  # type: ignore
                        headers={"retry-after": str(60)},  # type: ignore
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
@@ -221,7 +220,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                            deployment_rpm,
                            result,
                        ),
                        headers={"retry-after": 60},  # type: ignore
                        headers={"retry-after": str(60)},  # type: ignore
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
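
str(60) matters because httpx header values are expected to be str or bytes; an int fails when the headers are normalized (as a TypeError or AttributeError depending on the httpx version). A small check:

import httpx

request = httpx.Request(method="GET", url="https://github.com/BerriAI/litellm")

ok = httpx.Response(status_code=429, headers={"retry-after": str(60)}, request=request)
print(ok.headers["retry-after"])  # "60"

try:
    httpx.Response(status_code=429, headers={"retry-after": 60}, request=request)
except (TypeError, AttributeError) as exc:
    # httpx normalizes header values to bytes and does not accept ints.
    print("int header rejected:", type(exc).__name__)
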
@@ -61,6 +61,24 @@ class PatternMatchRouter:
        # return f"^{regex}$"
        return re.escape(pattern).replace(r"\*", "(.*)")

    def _return_pattern_matched_deployments(
        self, matched_pattern: Match, deployments: List[Dict]
    ) -> List[Dict]:
        new_deployments = []
        for deployment in deployments:
            new_deployment = copy.deepcopy(deployment)
            new_deployment["litellm_params"]["model"] = (
                PatternMatchRouter.set_deployment_model_name(
                    matched_pattern=matched_pattern,
                    litellm_deployment_litellm_model=deployment["litellm_params"][
                        "model"
                    ],
                )
            )
            new_deployments.append(new_deployment)

        return new_deployments

    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern
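
How a wildcard pattern becomes a regex, and what the captured groups look like, in a standalone sketch that mirrors the escape-then-replace trick above:

import re

def pattern_to_regex(pattern: str) -> str:
    # Same trick as PatternMatchRouter: escape everything, then turn "\*" back into "(.*)".
    return re.escape(pattern).replace(r"\*", "(.*)")

regex = pattern_to_regex("llmengine/*")
print(regex)  # llmengine/(.*)
m = re.match(regex, "llmengine/gpt-4o")
assert m is not None and m.groups() == ("gpt-4o",)
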
@@ -79,8 +97,11 @@ class PatternMatchRouter:
            if request is None:
                return None
            for pattern, llm_deployments in self.patterns.items():
                if re.match(pattern, request):
                    return llm_deployments
                pattern_match = re.match(pattern, request)
                if pattern_match:
                    return self._return_pattern_matched_deployments(
                        matched_pattern=pattern_match, deployments=llm_deployments
                    )
        except Exception as e:
            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
@@ -96,12 +117,28 @@ class PatternMatchRouter:

        E.g.:

        Case 1:
        model_name: llmengine/* (can be any regex pattern or wildcard pattern)
        litellm_params:
            model: openai/*

        if model_name = "llmengine/foo" -> model = "openai/foo"

        Case 2:
        model_name: llmengine/fo::*::static::*
        litellm_params:
            model: openai/fo::*::static::*

        if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"

        Case 3:
        model_name: *meta.llama3*
        litellm_params:
            model: bedrock/meta.llama3*

        if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
        """

        ## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
        if "*" not in litellm_deployment_litellm_model:
            return litellm_deployment_litellm_model
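
A standalone sketch of the substitution the docstring describes (Case 1), replacing each literal "*" in the deployment model with the captured segment; the helper name is mine:

import re

def set_model_name_sketch(matched: re.Match, deployment_model: str) -> str:
    # No wildcard in the deployment model: return it unchanged.
    if "*" not in deployment_model:
        return deployment_model
    # Otherwise substitute each captured segment for a "*", left to right.
    for segment in matched.groups():
        deployment_model = deployment_model.replace("*", segment, 1)
    return deployment_model

pattern = re.escape("llmengine/*").replace(r"\*", "(.*)")
m = re.match(pattern, "llmengine/foo")
assert m is not None
print(set_model_name_sketch(m, "openai/*"))  # openai/foo
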
@@ -112,10 +149,9 @@ class PatternMatchRouter:
        dynamic_segments = matched_pattern.groups()

        if len(dynamic_segments) > wildcard_count:
            raise ValueError(
                f"More wildcards in the deployment model name than the pattern. Wildcard count: {wildcard_count}, dynamic segments count: {len(dynamic_segments)}"
            )

            return (
                matched_pattern.string
            )  # default to the user input, if unable to map based on wildcards.
        # Replace the corresponding wildcards in the litellm model pattern with extracted segments
        for segment in dynamic_segments:
            litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
@@ -165,12 +201,7 @@ class PatternMatchRouter:
        """
        pattern_match = self.get_pattern(model, custom_llm_provider)
        if pattern_match:
            provider_deployments = []
            for deployment in pattern_match:
                dep = copy.deepcopy(deployment)
                dep["litellm_params"]["model"] = model
                provider_deployments.append(dep)
            return provider_deployments
            return pattern_match
        return []