Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Commit 779179b4da: Merge branch 'main' into CakeCrusher/developer_mode

317 changed files with 11086 additions and 3134 deletions
@@ -72,6 +72,7 @@ jobs:
          pip install "jsonschema==4.22.0"
          pip install "pytest-xdist==3.6.1"
          pip install "websockets==10.4"
          pip uninstall posthog -y
    - save_cache:
        paths:
          - ./venv
@@ -1517,6 +1518,117 @@ jobs:
    - store_test_results:
        path: test-results

  proxy_multi_instance_tests:
    machine:
      image: ubuntu-2204:2023.10.1
    resource_class: xlarge
    working_directory: ~/project
    steps:
      - checkout
      - run:
          name: Install Docker CLI (In case it's not already installed)
          command: |
            sudo apt-get update
            sudo apt-get install -y docker-ce docker-ce-cli containerd.io
      - run:
          name: Install Python 3.9
          command: |
            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
            bash miniconda.sh -b -p $HOME/miniconda
            export PATH="$HOME/miniconda/bin:$PATH"
            conda init bash
            source ~/.bashrc
            conda create -n myenv python=3.9 -y
            conda activate myenv
            python --version
      - run:
          name: Install Dependencies
          command: |
            pip install "pytest==7.3.1"
            pip install "pytest-asyncio==0.21.1"
            pip install aiohttp
            python -m pip install --upgrade pip
            python -m pip install -r requirements.txt
            pip install "pytest==7.3.1"
            pip install "pytest-retry==1.6.3"
            pip install "pytest-mock==3.12.0"
            pip install "pytest-asyncio==0.21.1"
      - run:
          name: Build Docker image
          command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
      - run:
          name: Run Docker container 1
          # intentionally give bad redis credentials here
          # the OTEL test - should get this as a trace
          command: |
            docker run -d \
              -p 4000:4000 \
              -e DATABASE_URL=$PROXY_DATABASE_URL \
              -e REDIS_HOST=$REDIS_HOST \
              -e REDIS_PASSWORD=$REDIS_PASSWORD \
              -e REDIS_PORT=$REDIS_PORT \
              -e LITELLM_MASTER_KEY="sk-1234" \
              -e LITELLM_LICENSE=$LITELLM_LICENSE \
              -e USE_DDTRACE=True \
              -e DD_API_KEY=$DD_API_KEY \
              -e DD_SITE=$DD_SITE \
              --name my-app \
              -v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
              my-app:latest \
              --config /app/config.yaml \
              --port 4000 \
              --detailed_debug \
      - run:
          name: Run Docker container 2
          command: |
            docker run -d \
              -p 4001:4001 \
              -e DATABASE_URL=$PROXY_DATABASE_URL \
              -e REDIS_HOST=$REDIS_HOST \
              -e REDIS_PASSWORD=$REDIS_PASSWORD \
              -e REDIS_PORT=$REDIS_PORT \
              -e LITELLM_MASTER_KEY="sk-1234" \
              -e LITELLM_LICENSE=$LITELLM_LICENSE \
              -e USE_DDTRACE=True \
              -e DD_API_KEY=$DD_API_KEY \
              -e DD_SITE=$DD_SITE \
              --name my-app-2 \
              -v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
              my-app:latest \
              --config /app/config.yaml \
              --port 4001 \
              --detailed_debug
      - run:
          name: Install curl and dockerize
          command: |
            sudo apt-get update
            sudo apt-get install -y curl
            sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
            sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
            sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
      - run:
          name: Start outputting logs
          command: docker logs -f my-app
          background: true
      - run:
          name: Wait for instance 1 to be ready
          command: dockerize -wait http://localhost:4000 -timeout 5m
      - run:
          name: Wait for instance 2 to be ready
          command: dockerize -wait http://localhost:4001 -timeout 5m
      - run:
          name: Run tests
          command: |
            pwd
            ls
            python -m pytest -vv tests/multi_instance_e2e_tests -x --junitxml=test-results/junit.xml --durations=5
          no_output_timeout:
            120m
      # Clean up first container
      # Store test results
      - store_test_results:
          path: test-results

  proxy_store_model_in_db_tests:
    machine:
      image: ubuntu-2204:2023.10.1
@@ -1552,6 +1664,7 @@ jobs:
            pip install "pytest-retry==1.6.3"
            pip install "pytest-mock==3.12.0"
            pip install "pytest-asyncio==0.21.1"
            pip install "assemblyai==0.37.0"
      - run:
          name: Build Docker image
          command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@@ -1904,7 +2017,7 @@ jobs:
            circleci step halt
          fi
      - run:
          name: Trigger Github Action for new Docker Container + Trigger Stable Release Testing
          name: Trigger Github Action for new Docker Container + Trigger Load Testing
          command: |
            echo "Install TOML package."
            python3 -m pip install toml
@@ -1914,9 +2027,9 @@ jobs:
            -H "Accept: application/vnd.github.v3+json" \
            -H "Authorization: Bearer $GITHUB_TOKEN" \
            "https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
            -d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
            echo "triggering stable release server for version ${VERSION} and commit ${CIRCLE_SHA1}"
            curl -X POST "https://proxyloadtester-production.up.railway.app/start/load/test?version=${VERSION}&commit_hash=${CIRCLE_SHA1}"
            -d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}-nightly\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
            echo "triggering load testing server for version ${VERSION} and commit ${CIRCLE_SHA1}"
            curl -X POST "https://proxyloadtester-production.up.railway.app/start/load/test?version=${VERSION}&commit_hash=${CIRCLE_SHA1}&release_type=nightly"

  e2e_ui_testing:
    machine:
@@ -2171,6 +2284,12 @@ workflows:
            only:
              - main
              - /litellm_.*/
      - proxy_multi_instance_tests:
          filters:
            branches:
              only:
                - main
                - /litellm_.*/
      - proxy_store_model_in_db_tests:
          filters:
            branches:
@@ -2302,6 +2421,7 @@ workflows:
            - installing_litellm_on_python
            - installing_litellm_on_python_3_13
            - proxy_logging_guardrails_model_info_tests
            - proxy_multi_instance_tests
            - proxy_store_model_in_db_tests
            - proxy_build_from_pip_tests
            - proxy_pass_through_endpoint_tests
45  .github/workflows/interpret_load_test.py (vendored)
@@ -52,6 +52,39 @@ def interpret_results(csv_file):
    return markdown_table


def _get_docker_run_command_stable_release(release_version):
    return f"""
\n\n
## Docker Run LiteLLM Proxy

```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm_stable_release_branch-{release_version}
"""


def _get_docker_run_command(release_version):
    return f"""
\n\n
## Docker Run LiteLLM Proxy

```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
"""


def get_docker_run_command(release_version):
    if "stable" in release_version:
        return _get_docker_run_command_stable_release(release_version)
    else:
        return _get_docker_run_command(release_version)


if __name__ == "__main__":
    csv_file = "load_test_stats.csv"  # Change this to the path of your CSV file
    markdown_table = interpret_results(csv_file)
@ -79,17 +112,7 @@ if __name__ == "__main__":
|
|||
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
|
||||
existing_release_body = latest_release.body[:start_index]
|
||||
|
||||
docker_run_command = f"""
|
||||
\n\n
|
||||
## Docker Run LiteLLM Proxy
|
||||
|
||||
```
|
||||
docker run \\
|
||||
-e STORE_MODEL_IN_DB=True \\
|
||||
-p 4000:4000 \\
|
||||
ghcr.io/berriai/litellm:main-{release_version}
|
||||
```
|
||||
"""
|
||||
docker_run_command = get_docker_run_command(release_version)
|
||||
print("docker run command: ", docker_run_command)
|
||||
|
||||
new_release_body = (
|
||||
|
|
172  cookbook/logging_observability/LiteLLM_Arize.ipynb (vendored, new file)
@@ -0,0 +1,172 @@
New Jupyter notebook (nbformat 4, Python 3.11, ".venv" kernel). Cell contents:

## Use LiteLLM with Arize
https://docs.litellm.ai/docs/observability/arize_integration

This method uses the litellm proxy to send the data to Arize. The callback is set in the litellm config below, instead of using OpenInference tracing.

## Install Dependencies

!pip install litellm

(cell output: pip reports "Requirement already satisfied" for litellm 1.54.1 and its dependencies in the local .venv)

## Set Env Variables

import litellm
import os
from getpass import getpass

os.environ["ARIZE_SPACE_KEY"] = getpass("Enter your Arize space key: ")
os.environ["ARIZE_API_KEY"] = getpass("Enter your Arize API key: ")
os.environ['OPENAI_API_KEY']= getpass("Enter your OpenAI API key: ")

Let's run a completion call and see the traces in Arize

# set arize as a callback, litellm will send the data to arize
litellm.callbacks = ["arize"]

# openai call
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "user", "content": "Hi 👋 - i'm openai"}
    ]
)
print(response.choices[0].message.content)

(cell output: "Hello! Nice to meet you, OpenAI. How can I assist you today?")
252  cookbook/logging_observability/LiteLLM_Proxy_Langfuse.ipynb (vendored, new file)
@@ -0,0 +1,252 @@
New Jupyter notebook (nbformat 4). Cell contents:

## LLM Ops Stack - LiteLLM Proxy + Langfuse

This notebook demonstrates how to use LiteLLM Proxy with Langfuse:
- Use LiteLLM Proxy for calling 100+ LLMs in OpenAI format
- Use Langfuse for viewing request / response traces

In this notebook we will set up LiteLLM Proxy to make requests to OpenAI, Anthropic, Bedrock and automatically log traces to Langfuse.

## 1. Setup LiteLLM Proxy

### 1.1 Define .env variables
Define .env variables on the container that litellm proxy is running on.
```bash
## LLM API Keys
OPENAI_API_KEY=sk-proj-1234567890
ANTHROPIC_API_KEY=sk-ant-api03-1234567890
AWS_ACCESS_KEY_ID=1234567890
AWS_SECRET_ACCESS_KEY=1234567890

## Langfuse Logging
LANGFUSE_PUBLIC_KEY="pk-lf-xxxx9"
LANGFUSE_SECRET_KEY="sk-lf-xxxx9"
LANGFUSE_HOST="https://us.cloud.langfuse.com"
```

### 1.2 Setup LiteLLM Proxy Config yaml
```yaml
model_list:
  - model_name: gpt-4o
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
  - model_name: claude-3-5-sonnet-20241022
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: us.amazon.nova-micro-v1:0
    litellm_params:
      model: bedrock/us.amazon.nova-micro-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY

litellm_settings:
  callbacks: ["langfuse"]
```

## 2. Make LLM Requests to LiteLLM Proxy

Now we will make our first LLM request to LiteLLM Proxy

### 2.1 Setup Client Side Variables to point to LiteLLM Proxy
Set `LITELLM_PROXY_BASE_URL` to the base url of the LiteLLM Proxy and `LITELLM_VIRTUAL_KEY` to the virtual key you want to use for Authentication to LiteLLM Proxy.

LITELLM_PROXY_BASE_URL="http://0.0.0.0:4000"
LITELLM_VIRTUAL_KEY="sk-oXXRa1xxxxxxxxxxx"

import openai
client = openai.OpenAI(
    api_key=LITELLM_VIRTUAL_KEY,
    base_url=LITELLM_PROXY_BASE_URL
)

response = client.chat.completions.create(
    model="gpt-4o",
    messages = [
        {
            "role": "user",
            "content": "what is Langfuse?"
        }
    ],
)

response

(cell output: a ChatCompletion from gpt-4o-2024-08-06 answering "what is Langfuse?", 122 total tokens)

### 2.3 View Traces on Langfuse
LiteLLM will send the request / response, model, tokens (input + output), cost to Langfuse.

### 2.4 Call Anthropic, Bedrock models

Now we can call `us.amazon.nova-micro-v1:0` and `claude-3-5-sonnet-20241022` models defined on your config.yaml both in the OpenAI request / response format.

import openai
client = openai.OpenAI(
    api_key=LITELLM_VIRTUAL_KEY,
    base_url=LITELLM_PROXY_BASE_URL
)

response = client.chat.completions.create(
    model="us.amazon.nova-micro-v1:0",
    messages = [
        {
            "role": "user",
            "content": "what is Langfuse?"
        }
    ],
)

response

(cell output: a ChatCompletion from us.amazon.nova-micro-v1:0 describing Langfuse's monitoring, error-tracking and tracing features, 385 total tokens)

## 3. Advanced - Set Langfuse Trace ID, Tags, Metadata

Here is an example of how you can set Langfuse specific params on your client side request. See full list of supported langfuse params [here](https://docs.litellm.ai/docs/observability/langfuse_integration)

You can view the logged trace of this request [here](https://us.cloud.langfuse.com/project/clvlhdfat0007vwb74m9lvfvi/traces/567890?timestamp=2025-02-14T17%3A30%3A26.709Z)

import openai
client = openai.OpenAI(
    api_key=LITELLM_VIRTUAL_KEY,
    base_url=LITELLM_PROXY_BASE_URL
)

response = client.chat.completions.create(
    model="us.amazon.nova-micro-v1:0",
    messages = [
        {
            "role": "user",
            "content": "what is Langfuse?"
        }
    ],
    extra_body={
        "metadata": {
            "generation_id": "1234567890",
            "trace_id": "567890",
            "trace_user_id": "user_1234567890",
            "tags": ["tag1", "tag2"]
        }
    }
)

response

(cell output: a ChatCompletion from us.amazon.nova-micro-v1:0, logged to Langfuse with the trace id set above)
BIN  cookbook/logging_observability/litellm_proxy_langfuse.png (new file, 308 KiB; binary file not shown)
@@ -11,9 +11,7 @@ FROM $LITELLM_BUILD_IMAGE AS builder
WORKDIR /app

# Install build dependencies
RUN apk update && \
    apk add --no-cache gcc python3-dev musl-dev && \
    rm -rf /var/cache/apk/*
RUN apk add --no-cache gcc python3-dev musl-dev

RUN pip install --upgrade pip && \
    pip install build
@@ -1,5 +1,5 @@
# Local Debugging
There's 2 ways to do local debugging - `litellm.set_verbose=True` and by passing in a custom function `completion(...logger_fn=<your_local_function>)`. Warning: Make sure to not use `set_verbose` in production. It logs API keys, which might end up in log files.
There's 2 ways to do local debugging - `litellm._turn_on_debug()` and by passing in a custom function `completion(...logger_fn=<your_local_function>)`. Warning: Make sure to not use `_turn_on_debug()` in production. It logs API keys, which might end up in log files.
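The `logger_fn` option is not shown in this hunk; a minimal sketch, assuming the callback receives a single dict of model-call details (exact keys vary by provider):

```python
# Hypothetical illustration of the logger_fn option described above.
import litellm
from litellm import completion

def my_custom_logging_fn(model_call_dict):
    # Ship the raw call details to your own logger instead of stdout if preferred.
    print(f"model call details: {model_call_dict}")

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
    logger_fn=my_custom_logging_fn,  # 👈 custom local logging, no _turn_on_debug() needed
)
```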

## Set Verbose

@@ -8,7 +8,7 @@ This is good for getting print statements for everything litellm is doing.
import litellm
from litellm import completion

litellm.set_verbose=True # 👈 this is the 1-line change you need to make
litellm._turn_on_debug() # 👈 this is the 1-line change you need to make

## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
@@ -19,6 +19,7 @@ Make an account on [Arize AI](https://app.arize.com/auth/login)
## Quick Start
Use just 2 lines of code, to instantly log your responses **across all providers** with arize

You can also use the instrumentor option instead of the callback, which you can find [here](https://docs.arize.com/arize/llm-tracing/tracing-integrations-auto/litellm).

```python
litellm.callbacks = ["arize"]

@@ -28,7 +29,7 @@ import litellm
import os

os.environ["ARIZE_SPACE_KEY"] = ""
os.environ["ARIZE_API_KEY"] = "" # defaults to litellm-completion
os.environ["ARIZE_API_KEY"] = ""

# LLM API Keys
os.environ['OPENAI_API_KEY']=""
@@ -78,7 +78,7 @@ Following are the allowed fields in metadata, their types, and their descriptions.
* `context: Optional[Union[dict, str]]` - This is the context used as information for the prompt. For RAG applications, this is the "retrieved" data. You may log context as a string or as an object (dictionary).
* `expected_response: Optional[str]` - This is the reference response to compare against for evaluation purposes. This is useful for segmenting inference calls by expected response.
* `user_query: Optional[str]` - This is the user's query. For conversational applications, this is the user's last message.
* `custom_attributes: Optional[dict]` - This is a dictionary of custom attributes. This is useful for additional information about the inference.
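For illustration only, a sketch of attaching these fields to a request; it assumes the Athina callback is enabled and that `litellm.completion` forwards a `metadata` dict to it (all values are placeholders):

```python
# Hedged sketch: pass the metadata fields listed above alongside a completion call.
import litellm

litellm.success_callback = ["athina"]  # assumes the Athina callback is configured

response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "What does the retrieved document say?"}],
    metadata={
        "context": "Retrieved passage text",           # RAG context (str or dict)
        "expected_response": "Reference answer",        # used for evaluation
        "user_query": "What does the retrieved document say?",
        "custom_attributes": {"experiment": "v2"},      # any additional info
    },
)
```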

## Using a self hosted deployment of Athina
75  docs/my-website/docs/observability/phoenix_integration.md (new file)
@@ -0,0 +1,75 @@
import Image from '@theme/IdealImage';

# Phoenix OSS

Open source tracing and evaluation platform

:::tip

This is community maintained. Please make an issue if you run into a bug:
https://github.com/BerriAI/litellm

:::


## Pre-Requisites
Make an account on [Phoenix OSS](https://phoenix.arize.com)
OR self-host your own instance of [Phoenix](https://docs.arize.com/phoenix/deployment)

## Quick Start
Use just 2 lines of code, to instantly log your responses **across all providers** with Phoenix

You can also use the instrumentor option instead of the callback, which you can find [here](https://docs.arize.com/phoenix/tracing/integrations-tracing/litellm).

```python
litellm.callbacks = ["arize_phoenix"]
```
```python
import litellm
import os

os.environ["PHOENIX_API_KEY"] = "" # Necessary only when using Phoenix Cloud
os.environ["PHOENIX_COLLECTOR_HTTP_ENDPOINT"] = "" # The URL of your Phoenix OSS instance
# This defaults to https://app.phoenix.arize.com/v1/traces for Phoenix Cloud

# LLM API Keys
os.environ['OPENAI_API_KEY']=""

# set arize as a callback, litellm will send the data to arize
litellm.callbacks = ["phoenix"]

# openai call
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "user", "content": "Hi 👋 - i'm openai"}
    ]
)
```

### Using with LiteLLM Proxy


```yaml
model_list:
  - model_name: gpt-4o
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

litellm_settings:
  callbacks: ["arize_phoenix"]

environment_variables:
  PHOENIX_API_KEY: "d0*****"
  PHOENIX_COLLECTOR_ENDPOINT: "https://app.phoenix.arize.com/v1/traces" # OPTIONAL, for setting the GRPC endpoint
  PHOENIX_COLLECTOR_HTTP_ENDPOINT: "https://app.phoenix.arize.com/v1/traces" # OPTIONAL, for setting the HTTP endpoint
```

## Support & Talk to Founders

- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
@@ -12,6 +12,9 @@ Supports **ALL** Assembly AI Endpoints

[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference)


<iframe width="840" height="500" src="https://www.loom.com/embed/aac3f4d74592448992254bfa79b9f62d?sid=267cd0ab-d92b-42fa-b97a-9f385ef8930c" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>

## Quick Start

Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts)

@@ -35,6 +38,8 @@ litellm
Let's call the Assembly AI `/v2/transcripts` endpoint

```python
import assemblyai as aai

LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai

@@ -53,3 +58,28 @@ print(transcript)
print(transcript.id)
```

## Calling Assembly AI EU endpoints

If you want to send your request to the Assembly AI EU endpoint, you can do so by setting the `LITELLM_PROXY_BASE_URL` to `<your-proxy-base-url>/eu.assemblyai`


```python
import assemblyai as aai

LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/eu.assemblyai" # <your-proxy-base-url>/eu.assemblyai

aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}"
aai.settings.base_url = LITELLM_PROXY_BASE_URL

# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"

# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'

transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
```
@@ -987,6 +987,106 @@ curl http://0.0.0.0:4000/v1/chat/completions \
</TabItem>
</Tabs>

## [BETA] Citations API

Pass `citations: {"enabled": true}` to Anthropic, to get citations on your document responses.

Note: This interface is in BETA. If you have feedback on how citations should be returned, please [tell us here](https://github.com/BerriAI/litellm/issues/7970#issuecomment-2644437943)

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion

resp = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "The grass is green. The sky is blue.",
                    },
                    "title": "My Document",
                    "context": "This is a trustworthy document.",
                    "citations": {"enabled": True},
                },
                {
                    "type": "text",
                    "text": "What color is the grass and sky?",
                },
            ],
        }
    ],
)

citations = resp.choices[0].message.provider_specific_fields["citations"]

assert citations is not None
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

```yaml
model_list:
  - model_name: anthropic-claude
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
```

2. Start proxy

```bash
litellm --config /path/to/config.yaml

# RUNNING on http://0.0.0.0:4000
```

3. Test it!

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "anthropic-claude",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "document",
          "source": {
            "type": "text",
            "media_type": "text/plain",
            "data": "The grass is green. The sky is blue.",
          },
          "title": "My Document",
          "context": "This is a trustworthy document.",
          "citations": {"enabled": True},
        },
        {
          "type": "text",
          "text": "What color is the grass and sky?",
        },
      ],
    }
  ]
}'
```

</TabItem>
</Tabs>

## Usage - passing 'user_id' to Anthropic

LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
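A short sketch of that mapping from the caller's side (values are placeholders):

```python
# The OpenAI-style `user` field is what LiteLLM forwards as Anthropic's metadata[user_id].
from litellm import completion

response = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hello"}],
    user="user_12345",  # sent to Anthropic as metadata[user_id]
)
```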
@@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Supported
| Property | Details |
|-------|-------|
| Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |

@@ -1277,13 +1277,83 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
https://some-api-url/models
```

## Bedrock Imported Models (Deepseek)
## Bedrock Imported Models (Deepseek, Deepseek R1)

### Deepseek R1

This is a separate route, as the chat template is different.

| Property | Details |
|----------|---------|
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

response = completion(
    model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # bedrock/deepseek_r1/{your-model-arn}
    messages=[{"role": "user", "content": "Tell me a joke"}],
)
```

</TabItem>

<TabItem value="proxy" label="Proxy">


**1. Add to config**

```yaml
model_list:
  - model_name: DeepSeek-R1-Distill-Llama-70B
    litellm_params:
      model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n

```

**2. Start proxy**

```bash
litellm --config /path/to/config.yaml

# RUNNING at http://0.0.0.0:4000
```

**3. Test it!**

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
  "model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
  "messages": [
    {
      "role": "user",
      "content": "what llm are you"
    }
  ],
}'
```

</TabItem>
</Tabs>


### Deepseek (not R1)

| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |



Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
@@ -688,7 +688,9 @@ response = litellm.completion(
|-----------------------|--------------------------------------------------------|--------------------------------|
| gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash | `completion(model='gemini/gemini-2.0-flash', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-exp | `completion(model='gemini/gemini-2.0-flash-exp', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-lite-preview-02-05 | `completion(model='gemini/gemini-2.0-flash-lite-preview-02-05', messages)` | `os.environ['GEMINI_API_KEY']` |
@@ -64,71 +64,7 @@ All models listed here https://docs.perplexity.ai/docs/model-cards are supported



## Return citations

Perplexity supports returning citations via `return_citations=True`. [Perplexity Docs](https://docs.perplexity.ai/reference/post_chat_completions). Note: Perplexity has this feature in **closed beta**, so you need them to grant you access to get citations from their API.

If perplexity returns citations, LiteLLM will pass it straight through.

:::info

For passing more provider-specific, [go here](../completion/provider_specific_params.md)
For more information about passing provider-specific parameters, [go here](../completion/provider_specific_params.md)
:::

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

os.environ['PERPLEXITYAI_API_KEY'] = ""
response = completion(
    model="perplexity/mistral-7b-instruct",
    messages=messages,
    return_citations=True
)
print(response)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add perplexity to config.yaml

```yaml
model_list:
  - model_name: "perplexity-model"
    litellm_params:
      model: "llama-3.1-sonar-small-128k-online"
      api_key: os.environ/PERPLEXITY_API_KEY
```

2. Start proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "perplexity-model",
    "messages": [
      {
        "role": "user",
        "content": "Who won the world cup in 2022?"
      }
    ],
    "return_citations": true
}'
```

[**Call w/ OpenAI SDK, Langchain, Instructor, etc.**](../proxy/user_keys.md#chatcompletions)

</TabItem>
</Tabs>
@@ -488,12 +488,12 @@ router_settings:
| SLACK_DAILY_REPORT_FREQUENCY | Frequency of daily Slack reports (e.g., daily, weekly)
| SLACK_WEBHOOK_URL | Webhook URL for Slack integration
| SMTP_HOST | Hostname for the SMTP server
| SMTP_PASSWORD | Password for SMTP authentication
| SMTP_PASSWORD | Password for SMTP authentication (do not set if SMTP does not require auth)
| SMTP_PORT | Port number for SMTP server
| SMTP_SENDER_EMAIL | Email address used as the sender in SMTP transactions
| SMTP_SENDER_LOGO | Logo used in emails sent via SMTP
| SMTP_TLS | Flag to enable or disable TLS for SMTP connections
| SMTP_USERNAME | Username for SMTP authentication
| SMTP_USERNAME | Username for SMTP authentication (do not set if SMTP does not require auth)
| SPEND_LOGS_URL | URL for retrieving spend logs
| SSL_CERTIFICATE | Path to the SSL certificate file
| SSL_VERIFY | Flag to enable or disable SSL certificate verification
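Purely illustrative values for the SMTP-related settings above, set in the proxy's environment before startup (all values are placeholders; exact accepted formats may differ):

```python
# Hypothetical sketch of SMTP settings for the proxy environment.
import os

os.environ["SMTP_HOST"] = "smtp.example.com"
os.environ["SMTP_PORT"] = "587"
os.environ["SMTP_TLS"] = "True"
os.environ["SMTP_SENDER_EMAIL"] = "alerts@example.com"
# Only set these if your SMTP server requires authentication:
os.environ["SMTP_USERNAME"] = "alerts@example.com"
os.environ["SMTP_PASSWORD"] = "app-password"
```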
@@ -37,7 +37,7 @@ guardrails:
  - guardrail_name: aim-protected-app
    litellm_params:
      guardrail: aim
      mode: pre_call
      mode: pre_call # 'during_call' is also available
      api_key: os.environ/AIM_API_KEY
      api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
```
@@ -166,7 +166,7 @@ response = client.chat.completions.create(
        {"role": "user", "content": "what color is red"}
    ],
    logit_bias={12481: 100},
    timeout=1
    extra_body={"timeout": 1} # 👈 KEY CHANGE
)

print(response)
@@ -163,10 +163,12 @@ scope: "litellm-proxy-admin ..."

```yaml
general_settings:
  master_key: sk-1234
  enable_jwt_auth: True
  litellm_jwtauth:
    user_id_jwt_field: "sub"
    team_ids_jwt_field: "groups"
    user_id_upsert: true # add user_id to the db if they don't exist
    enforce_team_based_model_access: true # don't allow users to access models unless the team has access
```

This is assuming your token looks like this:

@@ -352,11 +354,11 @@ environment_variables:

### Example Token

```
```bash
{
  "aud": "api://LiteLLM_Proxy",
  "oid": "eec236bd-0135-4b28-9354-8fc4032d543e",
  "roles": ["litellm.api.consumer"]
  "roles": ["litellm.api.consumer"]
}
```

@@ -370,4 +372,68 @@ Supported internal roles:
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.

### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)

## [BETA] Control Model Access with Scopes

Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.

### 1. Setup config.yaml with scope mappings.


```yaml
model_list:
  - model_name: anthropic-claude
    litellm_params:
      model: anthropic/claude-3-5-sonnet
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: gpt-3.5-turbo-testing
    litellm_params:
      model: gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY

general_settings:
  enable_jwt_auth: True
  litellm_jwtauth:
    team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
    team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
    scope_mappings:
      - scope: litellm.api.consumer
        models: ["anthropic-claude"]
      - scope: litellm.api.gpt_3_5_turbo
        models: ["gpt-3.5-turbo-testing"]
    enforce_scope_based_access: true # 👈 enforce scope-based access control
    enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
```

#### Scope Mapping Spec

- `scope`: The scope to be used for the JWT token.
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.

### 2. Create a JWT with the correct scopes.

Expected Token:

```bash
{
  "scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"] # can be a list or a space-separated string
}
```

### 3. Test the flow.

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer eyJhbGci...' \
-d '{
    "model": "gpt-3.5-turbo-testing",
    "messages": [
      {
        "role": "user",
        "content": "Hey, how'\''s it going 1234?"
      }
    ]
}'
```
@@ -52,6 +52,7 @@ from litellm.constants import (
    open_ai_embedding_models,
    cohere_embedding_models,
    bedrock_embedding_models,
    known_tokenizer_config,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (

@@ -360,7 +361,15 @@ BEDROCK_CONVERSE_MODELS = [
    "meta.llama3-2-90b-instruct-v1:0",
]
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
    "cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
    "cohere",
    "anthropic",
    "mistral",
    "amazon",
    "meta",
    "llama",
    "ai21",
    "nova",
    "deepseek_r1",
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []

@@ -863,6 +872,9 @@ from .llms.bedrock.common_utils import (
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
    AmazonAI21Config,
)
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
    AmazonInvokeNovaConfig,
)
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
    AmazonAnthropicConfig,
)
@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
|||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
|
||||
"init_redis_cluster: startup nodes are being initialized."
|
||||
)
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
|
@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
|
|||
return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
def get_redis_async_client(**env_overrides) -> async_redis.Redis:
|
||||
def get_redis_async_client(
|
||||
**env_overrides,
|
||||
) -> async_redis.Redis:
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
||||
|
|
|
@ -4,5 +4,6 @@ from .dual_cache import DualCache
|
|||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
|
|
@ -41,6 +41,7 @@ from .dual_cache import DualCache # noqa
|
|||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
||||
|
@ -158,14 +159,23 @@ class Cache:
|
|||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == LiteLLMCacheType.REDIS:
|
||||
self.cache: BaseCache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
startup_nodes=redis_startup_nodes,
|
||||
**kwargs,
|
||||
)
|
||||
if redis_startup_nodes:
|
||||
self.cache: BaseCache = RedisClusterCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
startup_nodes=redis_startup_nodes,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.cache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||
self.cache = RedisSemanticCache(
|
||||
host=host,
|
||||
|
|
|
@ -14,7 +14,7 @@ import inspect
|
|||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
@ -26,15 +26,20 @@ from .base_cache import BaseCache
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
from redis.asyncio.cluster import ClusterPipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
cluster_pipeline = ClusterPipeline
|
||||
async_redis_client = Redis
|
||||
async_redis_cluster_client = RedisCluster
|
||||
Span = _Span
|
||||
else:
|
||||
pipeline = Any
|
||||
cluster_pipeline = Any
|
||||
async_redis_client = Any
|
||||
async_redis_cluster_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
|
@ -122,7 +127,9 @@ class RedisCache(BaseCache):
|
|||
else:
|
||||
super().__init__() # defaults to 60s
|
||||
|
||||
def init_async_client(self):
|
||||
def init_async_client(
|
||||
self,
|
||||
) -> Union[async_redis_client, async_redis_cluster_client]:
|
||||
from .._redis import get_redis_async_client
|
||||
|
||||
return get_redis_async_client(
|
||||
|
@ -345,8 +352,14 @@ class RedisCache(BaseCache):
|
|||
)
|
||||
|
||||
async def _pipeline_helper(
|
||||
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
|
||||
self,
|
||||
pipe: Union[pipeline, cluster_pipeline],
|
||||
cache_list: List[Tuple[Any, Any]],
|
||||
ttl: Optional[float],
|
||||
) -> List:
|
||||
"""
|
||||
Helper function for executing a pipeline of set operations on Redis
|
||||
"""
|
||||
ttl = self.get_ttl(ttl=ttl)
|
||||
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
||||
for cache_key, cache_value in cache_list:
|
||||
|
@ -359,7 +372,11 @@ class RedisCache(BaseCache):
|
|||
_td: Optional[timedelta] = None
|
||||
if ttl is not None:
|
||||
_td = timedelta(seconds=ttl)
|
||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||
pipe.set( # type: ignore
|
||||
name=cache_key,
|
||||
value=json_cache_value,
|
||||
ex=_td,
|
||||
)
|
||||
# Execute the pipeline and return the results.
|
||||
results = await pipe.execute()
|
||||
return results
|
||||
|
@ -373,9 +390,8 @@ class RedisCache(BaseCache):
|
|||
# don't waste a network request if there's nothing to set
|
||||
if len(cache_list) == 0:
|
||||
return
|
||||
from redis.asyncio import Redis
|
||||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
|
||||
print_verbose(
|
||||
|
@ -384,7 +400,7 @@ class RedisCache(BaseCache):
|
|||
cache_value: Any = None
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
async with redis_client.pipeline(transaction=False) as pipe:
|
||||
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
||||
|
||||
print_verbose(f"pipeline results: {results}")
|
||||
|
@ -730,7 +746,8 @@ class RedisCache(BaseCache):
|
|||
"""
|
||||
Use Redis for bulk read operations
|
||||
"""
|
||||
_redis_client = await self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
key_value_dict = {}
|
||||
start_time = time.time()
|
||||
try:
|
||||
|
@ -822,7 +839,8 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def ping(self) -> bool:
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
print_verbose("Pinging Async Redis Cache")
|
||||
|
@ -858,7 +876,8 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
# keys is a list, unpack it so it gets passed as individual elements to delete
|
||||
async with _redis_client as redis_client:
|
||||
await redis_client.delete(*keys)
|
||||
|
@ -881,7 +900,8 @@ class RedisCache(BaseCache):
|
|||
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
|
||||
|
||||
async def async_delete_cache(self, key: str):
|
||||
_redis_client = self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
# keys is str
|
||||
async with _redis_client as redis_client:
|
||||
await redis_client.delete(key)
|
||||
|
@ -936,7 +956,7 @@ class RedisCache(BaseCache):
|
|||
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
async with redis_client.pipeline(transaction=False) as pipe:
|
||||
results = await self._pipeline_increment_helper(
|
||||
pipe, increment_list
|
||||
)
|
||||
|
@ -991,7 +1011,8 @@ class RedisCache(BaseCache):
|
|||
Redis ref: https://redis.io/docs/latest/commands/ttl/
|
||||
"""
|
||||
try:
|
||||
_redis_client = await self.init_async_client()
|
||||
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
|
||||
_redis_client: Any = self.init_async_client()
|
||||
async with _redis_client as redis_client:
|
||||
ttl = await redis_client.ttl(key)
|
||||
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
|
||||
|
|
litellm/caching/redis_cluster_cache.py (new file, 44 lines)

@ -0,0 +1,44 @@
"""
Redis Cluster Cache implementation

Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""

from typing import TYPE_CHECKING, Any, Optional

from litellm.caching.redis_cache import RedisCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
    from redis.asyncio import Redis, RedisCluster
    from redis.asyncio.client import Pipeline

    pipeline = Pipeline
    async_redis_client = Redis
    Span = _Span
else:
    pipeline = Any
    async_redis_client = Any
    Span = Any


class RedisClusterCache(RedisCache):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.redis_cluster_client: Optional[RedisCluster] = None

    def init_async_client(self):
        from redis.asyncio import RedisCluster

        from .._redis import get_redis_async_client

        if self.redis_cluster_client:
            return self.redis_cluster_client

        _redis_client = get_redis_async_client(
            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
        )
        if isinstance(_redis_client, RedisCluster):
            self.redis_cluster_client = _redis_client
        return _redis_client
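For context on how this class gets used: the `Cache.__init__` change earlier in this diff routes to `RedisClusterCache` whenever `redis_startup_nodes` is supplied. A minimal sketch of enabling it (the node addresses are placeholders, and this assumes a reachable Redis Cluster):

```python
# Sketch: enable Redis Cluster-backed caching in litellm; host/port values are placeholders.
import litellm

litellm.cache = litellm.Cache(
    type="redis",
    redis_startup_nodes=[
        {"host": "127.0.0.1", "port": 7000},
        {"host": "127.0.0.1", "port": 7001},
    ],
)
```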
@ -335,6 +335,63 @@ bedrock_embedding_models: List = [
|
|||
"cohere.embed-multilingual-v3",
|
||||
]
|
||||
|
||||
known_tokenizer_config = {
|
||||
"mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
|
||||
"bos_token": "<|begin_of_text|>",
|
||||
"eos_token": "",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"deepseek-r1/deepseek-r1-7b-instruct": {
|
||||
"tokenizer": {
|
||||
"add_bos_token": True,
|
||||
"add_eos_token": False,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|begin▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"legacy": True,
|
||||
"model_max_length": 16384,
|
||||
"pad_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"sp_model_kwargs": {},
|
||||
"unk_token": None,
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
|
||||
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
|
||||
|
|
|
@ -183,6 +183,9 @@ def create_fine_tuning_job(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get(
|
||||
"client", None
|
||||
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
|
@ -388,6 +391,7 @@ def cancel_fine_tuning_job(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
|
@ -550,6 +554,7 @@ def list_fine_tuning_jobs(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
|
@ -701,6 +706,7 @@ def retrieve_fine_tuning_job(
|
|||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
|
|
litellm/integrations/additional_logging_utils.py (new file, 36 lines)

@ -0,0 +1,36 @@
"""
Base class for Additional Logging Utils for CustomLoggers

- Health Check for the logging util
- Get Request / Response Payload for the logging util
"""

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional

from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus


class AdditionalLoggingUtils(ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        """
        Check if the service is healthy
        """
        pass

    @abstractmethod
    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`
        """
        return None
@ -23,6 +23,7 @@ class AthinaLogger:
|
|||
"context",
|
||||
"expected_response",
|
||||
"user_query",
|
||||
"custom_attributes",
|
||||
]
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
"""
|
||||
Base class for health check integrations
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
|
||||
|
||||
class HealthCheckIntegration(ABC):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
||||
"""
|
||||
Check if the service is healthy
|
||||
"""
|
||||
pass
|
|
@ -38,14 +38,14 @@ from litellm.types.integrations.datadog import *
|
|||
from litellm.types.services import ServiceLoggerPayload
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
from ..base_health_check import HealthCheckIntegration
|
||||
from ..additional_logging_utils import AdditionalLoggingUtils
|
||||
|
||||
DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept
|
||||
|
||||
|
||||
class DataDogLogger(
|
||||
CustomBatchLogger,
|
||||
HealthCheckIntegration,
|
||||
AdditionalLoggingUtils,
|
||||
):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
|
@ -543,3 +543,13 @@ class DataDogLogger(
|
|||
status="unhealthy",
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
async def get_request_response_payload(
|
||||
self,
|
||||
request_id: str,
|
||||
start_time_utc: Optional[datetimeObj],
|
||||
end_time_utc: Optional[datetimeObj],
|
||||
) -> Optional[dict]:
|
||||
raise NotImplementedError(
|
||||
"Datdog Integration for getting request/response payloads not implemented as yet"
|
||||
)
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
@ -20,7 +24,7 @@ GCS_DEFAULT_BATCH_SIZE = 2048
|
|||
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
|
||||
|
||||
|
||||
class GCSBucketLogger(GCSBucketBase):
|
||||
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
|
||||
def __init__(self, bucket_name: Optional[str] = None) -> None:
|
||||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
|
@ -39,6 +43,7 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
batch_size=self.batch_size,
|
||||
flush_interval=self.flush_interval,
|
||||
)
|
||||
AdditionalLoggingUtils.__init__(self)
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
|
@ -150,11 +155,16 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
"""
|
||||
Get the object name to use for the current payload
|
||||
"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
|
||||
if logging_payload.get("error_str", None) is not None:
|
||||
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
|
||||
object_name = self._generate_failure_object_name(
|
||||
request_date_str=current_date,
|
||||
)
|
||||
else:
|
||||
object_name = f"{current_date}/{response_obj.get('id', '')}"
|
||||
object_name = self._generate_success_object_name(
|
||||
request_date_str=current_date,
|
||||
response_id=response_obj.get("id", ""),
|
||||
)
|
||||
|
||||
# used for testing
|
||||
_litellm_params = kwargs.get("litellm_params", None) or {}
|
||||
|
@ -163,3 +173,65 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
object_name = _metadata["gcs_log_id"]
|
||||
|
||||
return object_name
|
||||
|
||||
async def get_request_response_payload(
|
||||
self,
|
||||
request_id: str,
|
||||
start_time_utc: Optional[datetime],
|
||||
end_time_utc: Optional[datetime],
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the request and response payload for a given `request_id`
|
||||
Tries current day, next day, and previous day until it finds the payload
|
||||
"""
|
||||
if start_time_utc is None:
|
||||
raise ValueError(
|
||||
"start_time_utc is required for getting a payload from GCS Bucket"
|
||||
)
|
||||
|
||||
# Try current day, next day, and previous day
|
||||
dates_to_try = [
|
||||
start_time_utc,
|
||||
start_time_utc + timedelta(days=1),
|
||||
start_time_utc - timedelta(days=1),
|
||||
]
|
||||
date_str = None
|
||||
for date in dates_to_try:
|
||||
try:
|
||||
date_str = self._get_object_date_from_datetime(datetime_obj=date)
|
||||
object_name = self._generate_success_object_name(
|
||||
request_date_str=date_str,
|
||||
response_id=request_id,
|
||||
)
|
||||
encoded_object_name = quote(object_name, safe="")
|
||||
response = await self.download_gcs_object(encoded_object_name)
|
||||
|
||||
if response is not None:
|
||||
loaded_response = json.loads(response)
|
||||
return loaded_response
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Failed to fetch payload for date {date_str}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _generate_success_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
response_id: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/{response_id}"
|
||||
|
||||
def _generate_failure_object_name(
|
||||
self,
|
||||
request_date_str: str,
|
||||
) -> str:
|
||||
return f"{request_date_str}/failure-{uuid.uuid4().hex}"
|
||||
|
||||
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
|
||||
return datetime_obj.strftime("%Y-%m-%d")
|
||||
|
||||
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
||||
raise NotImplementedError("GCS Bucket does not support health check")
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
import copy
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
|
@ -13,9 +14,16 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
|
|||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.langfuse import *
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
RerankResponse,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPromptManagementMetadata,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -150,19 +158,29 @@ class LangFuseLogger:
|
|||
|
||||
return metadata
|
||||
|
||||
def _old_log_event( # noqa: PLR0915
|
||||
def log_event_on_langfuse(
|
||||
self,
|
||||
kwargs,
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
user_id,
|
||||
print_verbose,
|
||||
level="DEFAULT",
|
||||
status_message=None,
|
||||
kwargs: dict,
|
||||
response_obj: Union[
|
||||
None,
|
||||
dict,
|
||||
EmbeddingResponse,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
RerankResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
user_id: Optional[str] = None,
|
||||
level: str = "DEFAULT",
|
||||
status_message: Optional[str] = None,
|
||||
) -> dict:
|
||||
# Method definition
|
||||
|
||||
"""
|
||||
Logs a success or error event on Langfuse
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Logging - Enters logging function for model {kwargs}"
|
||||
|
@ -198,66 +216,13 @@ class LangFuseLogger:
|
|||
# if casting value to str fails don't block logging
|
||||
pass
|
||||
|
||||
# end of processing langfuse ########################
|
||||
if (
|
||||
level == "ERROR"
|
||||
and status_message is not None
|
||||
and isinstance(status_message, str)
|
||||
):
|
||||
input = prompt
|
||||
output = status_message
|
||||
elif response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
input = prompt
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.HttpxBinaryResponseContent
|
||||
):
|
||||
input = prompt
|
||||
output = "speech-output"
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.choices[0].text
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj["data"]
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TranscriptionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj["text"]
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.RerankResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.results
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "_arealtime"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, list)
|
||||
):
|
||||
input = kwargs.get("input")
|
||||
output = response_obj
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "pass_through_endpoint"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, dict)
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("response", "")
|
||||
input, output = self._get_langfuse_input_output_content(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
prompt=prompt,
|
||||
level=level,
|
||||
status_message=status_message,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
|
||||
)
|
||||
|
@ -265,31 +230,30 @@ class LangFuseLogger:
|
|||
generation_id = None
|
||||
if self._is_langfuse_v2():
|
||||
trace_id, generation_id = self._log_langfuse_v2(
|
||||
user_id,
|
||||
metadata,
|
||||
litellm_params,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
kwargs,
|
||||
optional_params,
|
||||
input,
|
||||
response_obj,
|
||||
level,
|
||||
print_verbose,
|
||||
litellm_call_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
litellm_params=litellm_params,
|
||||
output=output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
kwargs=kwargs,
|
||||
optional_params=optional_params,
|
||||
input=input,
|
||||
response_obj=response_obj,
|
||||
level=level,
|
||||
litellm_call_id=litellm_call_id,
|
||||
)
|
||||
elif response_obj is not None:
|
||||
self._log_langfuse_v1(
|
||||
user_id,
|
||||
metadata,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
kwargs,
|
||||
optional_params,
|
||||
input,
|
||||
response_obj,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
output=output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
kwargs=kwargs,
|
||||
optional_params=optional_params,
|
||||
input=input,
|
||||
response_obj=response_obj,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||
|
@ -303,11 +267,108 @@ class LangFuseLogger:
|
|||
)
|
||||
return {"trace_id": None, "generation_id": None}
|
||||
|
||||
def _get_langfuse_input_output_content(
|
||||
self,
|
||||
kwargs: dict,
|
||||
response_obj: Union[
|
||||
None,
|
||||
dict,
|
||||
EmbeddingResponse,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
RerankResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
],
|
||||
prompt: dict,
|
||||
level: str,
|
||||
status_message: Optional[str],
|
||||
) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]:
|
||||
"""
|
||||
Get the input and output content for Langfuse logging
|
||||
|
||||
Args:
|
||||
kwargs: The keyword arguments passed to the function
|
||||
response_obj: The response object returned by the function
|
||||
prompt: The prompt used to generate the response
|
||||
level: The level of the log message
|
||||
status_message: The status message of the log message
|
||||
|
||||
Returns:
|
||||
input: The input content for Langfuse logging
|
||||
output: The output content for Langfuse logging
|
||||
"""
|
||||
input = None
|
||||
output: Optional[Union[str, dict, List[Any]]] = None
|
||||
if (
|
||||
level == "ERROR"
|
||||
and status_message is not None
|
||||
and isinstance(status_message, str)
|
||||
):
|
||||
input = prompt
|
||||
output = status_message
|
||||
elif response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
input = prompt
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
input = prompt
|
||||
output = self._get_chat_content_for_langfuse(response_obj)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.HttpxBinaryResponseContent
|
||||
):
|
||||
input = prompt
|
||||
output = "speech-output"
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = self._get_text_completion_content_for_langfuse(response_obj)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("data", None)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TranscriptionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("text", None)
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.RerankResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.results
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "_arealtime"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, list)
|
||||
):
|
||||
input = kwargs.get("input")
|
||||
output = response_obj
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "pass_through_endpoint"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, dict)
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.get("response", "")
|
||||
return input, output
|
||||
|
||||
async def _async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
|
||||
self, kwargs, response_obj, start_time, end_time, user_id
|
||||
):
|
||||
"""
|
||||
TODO: support async calls when langfuse is truly async
|
||||
Langfuse SDK uses a background thread to log events
|
||||
|
||||
This approach does not impact latency and runs in the background
|
||||
"""
|
||||
|
||||
def _is_langfuse_v2(self):
|
||||
|
@ -361,19 +422,18 @@ class LangFuseLogger:
|
|||
|
||||
def _log_langfuse_v2( # noqa: PLR0915
|
||||
self,
|
||||
user_id,
|
||||
metadata,
|
||||
litellm_params,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
kwargs,
|
||||
optional_params,
|
||||
input,
|
||||
user_id: Optional[str],
|
||||
metadata: dict,
|
||||
litellm_params: dict,
|
||||
output: Optional[Union[str, dict, list]],
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime],
|
||||
kwargs: dict,
|
||||
optional_params: dict,
|
||||
input: Optional[dict],
|
||||
response_obj,
|
||||
level,
|
||||
print_verbose,
|
||||
litellm_call_id,
|
||||
level: str,
|
||||
litellm_call_id: Optional[str],
|
||||
) -> tuple:
|
||||
verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
|
||||
|
||||
|
@ -657,6 +717,31 @@ class LangFuseLogger:
|
|||
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_content_for_langfuse(
|
||||
response_obj: ModelResponse,
|
||||
):
|
||||
"""
|
||||
Get the chat content for Langfuse logging
|
||||
"""
|
||||
if response_obj.choices and len(response_obj.choices) > 0:
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
return output
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_text_completion_content_for_langfuse(
|
||||
response_obj: TextCompletionResponse,
|
||||
):
|
||||
"""
|
||||
Get the text completion content for Langfuse logging
|
||||
"""
|
||||
if response_obj.choices and len(response_obj.choices) > 0:
|
||||
return response_obj.choices[0].text
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_langfuse_tags(
|
||||
standard_logging_object: Optional[StandardLoggingPayload],
|
||||
|
|
|
@ -247,13 +247,12 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
|
|||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
langfuse_logger_to_use._old_log_event(
|
||||
langfuse_logger_to_use.log_event_on_langfuse(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=None,
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -271,12 +270,11 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
|
|||
)
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
langfuse_logger_to_use._old_log_event(
|
||||
langfuse_logger_to_use.log_event_on_langfuse(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response_obj=None,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=None,
|
||||
status_message=standard_logging_object["error_str"],
|
||||
level="ERROR",
|
||||
kwargs=kwargs,
|
||||
|
|
|
@ -118,6 +118,7 @@ class PagerDutyAlerting(SlackAlerting):
|
|||
user_api_key_user_id=_meta.get("user_api_key_user_id"),
|
||||
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
|
||||
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
|
||||
user_api_key_user_email=_meta.get("user_api_key_user_email"),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -195,6 +196,7 @@ class PagerDutyAlerting(SlackAlerting):
|
|||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -423,6 +423,7 @@ class PrometheusLogger(CustomLogger):
|
|||
team=user_api_team,
|
||||
team_alias=user_api_team_alias,
|
||||
user=user_id,
|
||||
user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
|
||||
status_code="200",
|
||||
model=model,
|
||||
litellm_model_name=model,
|
||||
|
@ -806,6 +807,7 @@ class PrometheusLogger(CustomLogger):
|
|||
enum_values = UserAPIKeyLabelValues(
|
||||
end_user=user_api_key_dict.end_user_id,
|
||||
user=user_api_key_dict.user_id,
|
||||
user_email=user_api_key_dict.user_email,
|
||||
hashed_api_key=user_api_key_dict.api_key,
|
||||
api_key_alias=user_api_key_dict.key_alias,
|
||||
team=user_api_key_dict.team_id,
|
||||
|
@ -853,6 +855,7 @@ class PrometheusLogger(CustomLogger):
|
|||
team=user_api_key_dict.team_id,
|
||||
team_alias=user_api_key_dict.team_alias,
|
||||
user=user_api_key_dict.user_id,
|
||||
user_email=user_api_key_dict.user_email,
|
||||
status_code="200",
|
||||
)
|
||||
_labels = prometheus_label_factory(
|
||||
|
|
|
@ -223,6 +223,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
"Request Timeout Error" in error_str
|
||||
or "Request timed out" in error_str
|
||||
or "Timed out generating response" in error_str
|
||||
or "The read operation timed out" in error_str
|
||||
):
|
||||
exception_mapping_worked = True
|
||||
|
||||
|
|
|
@ -121,21 +121,26 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
)
|
||||
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
|
||||
if request_type == "chat_completion":
|
||||
if model.startswith("meta/"):
|
||||
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
||||
if model.startswith("mistral"):
|
||||
return litellm.MistralConfig().get_supported_openai_params(model=model)
|
||||
if model.startswith("codestral"):
|
||||
elif model.startswith("codestral"):
|
||||
return (
|
||||
litellm.CodestralTextCompletionConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
)
|
||||
if model.startswith("claude"):
|
||||
elif model.startswith("claude"):
|
||||
return litellm.VertexAIAnthropicConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
return litellm.VertexGeminiConfig().get_supported_openai_params(model=model)
|
||||
elif model.startswith("gemini"):
|
||||
return litellm.VertexGeminiConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.VertexAILlama3Config().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
|
|
|
@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
dynamic_async_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
_input: Optional[str] = messages # save original value of messages
|
||||
|
@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
"litellm_call_id": litellm_call_id,
|
||||
"input": _input,
|
||||
"litellm_params": litellm_params,
|
||||
"applied_guardrails": applied_guardrails,
|
||||
}
|
||||
|
||||
def process_dynamic_callbacks(self):
|
||||
|
@ -1247,13 +1249,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
if langfuse_logger_to_use is not None:
|
||||
_response = langfuse_logger_to_use._old_log_event(
|
||||
_response = langfuse_logger_to_use.log_event_on_langfuse(
|
||||
kwargs=kwargs,
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if _response is not None and isinstance(_response, dict):
|
||||
_trace_id = _response.get("trace_id", None)
|
||||
|
@ -1957,12 +1958,11 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
standard_callback_dynamic_params=self.standard_callback_dynamic_params,
|
||||
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
|
||||
)
|
||||
_response = langfuse_logger_to_use._old_log_event(
|
||||
_response = langfuse_logger_to_use.log_event_on_langfuse(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response_obj=None,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
status_message=str(exception),
|
||||
level="ERROR",
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -2854,6 +2854,7 @@ class StandardLoggingPayloadSetup:
|
|||
metadata: Optional[Dict[str, Any]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
prompt_integration: Optional[str] = None,
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
) -> StandardLoggingMetadata:
|
||||
"""
|
||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||
|
@ -2868,6 +2869,7 @@ class StandardLoggingPayloadSetup:
|
|||
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
|
||||
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
|
||||
"""
|
||||
|
||||
prompt_management_metadata: Optional[
|
||||
StandardLoggingPromptManagementMetadata
|
||||
] = None
|
||||
|
@ -2892,11 +2894,13 @@ class StandardLoggingPayloadSetup:
|
|||
user_api_key_org_id=None,
|
||||
user_api_key_user_id=None,
|
||||
user_api_key_team_alias=None,
|
||||
user_api_key_user_email=None,
|
||||
spend_logs_metadata=None,
|
||||
requester_ip_address=None,
|
||||
requester_metadata=None,
|
||||
user_api_key_end_user_id=None,
|
||||
prompt_management_metadata=prompt_management_metadata,
|
||||
applied_guardrails=applied_guardrails,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
@ -3195,6 +3199,7 @@ def get_standard_logging_object_payload(
|
|||
metadata=metadata,
|
||||
litellm_params=litellm_params,
|
||||
prompt_integration=kwargs.get("prompt_integration", None),
|
||||
applied_guardrails=kwargs.get("applied_guardrails", None),
|
||||
)
|
||||
|
||||
_request_body = proxy_server_request.get("body", {})
|
||||
|
@ -3324,12 +3329,14 @@ def get_standard_logging_metadata(
|
|||
user_api_key_team_id=None,
|
||||
user_api_key_org_id=None,
|
||||
user_api_key_user_id=None,
|
||||
user_api_key_user_email=None,
|
||||
user_api_key_team_alias=None,
|
||||
spend_logs_metadata=None,
|
||||
requester_ip_address=None,
|
||||
requester_metadata=None,
|
||||
user_api_key_end_user_id=None,
|
||||
prompt_management_metadata=None,
|
||||
applied_guardrails=None,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Union
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
@ -221,6 +222,27 @@ def _handle_invalid_parallel_tool_calls(
    return tool_calls


def _parse_content_for_reasoning(
    message_text: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse the content for reasoning

    Returns:
    - reasoning_content: The content of the reasoning
    - content: The content of the message
    """
    if not message_text:
        return None, message_text

    reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)

    if reasoning_match:
        return reasoning_match.group(1), reasoning_match.group(2)

    return None, message_text
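A quick illustration of what the helper above returns; the sample strings are made up, and this assumes it runs in the same module so `_parse_content_for_reasoning` is in scope:

```python
# Illustration only: split a <think> block from the visible answer text.
raw = "<think>The user wants a greeting.</think>Hello there!"
reasoning, content = _parse_content_for_reasoning(raw)
assert reasoning == "The user wants a greeting."
assert content == "Hello there!"

# Without a <think> block, the text passes through unchanged.
assert _parse_content_for_reasoning("Hello") == (None, "Hello")
```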
class LiteLLMResponseObjectHandler:
|
||||
|
||||
@staticmethod
|
||||
|
@ -432,8 +454,20 @@ def convert_to_model_response_object( # noqa: PLR0915
|
|||
for field in choice["message"].keys():
|
||||
if field not in message_keys:
|
||||
provider_specific_fields[field] = choice["message"][field]
|
||||
|
||||
# Handle reasoning models that display `reasoning_content` within `content`
|
||||
|
||||
reasoning_content, content = _parse_content_for_reasoning(
|
||||
choice["message"].get("content")
|
||||
)
|
||||
|
||||
if reasoning_content:
|
||||
provider_specific_fields["reasoning_content"] = (
|
||||
reasoning_content
|
||||
)
|
||||
|
||||
message = Message(
|
||||
content=choice["message"].get("content", None),
|
||||
content=content,
|
||||
role=choice["message"]["role"] or "assistant",
|
||||
function_call=choice["message"].get("function_call", None),
|
||||
tool_calls=tool_calls,
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Set, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
|
@ -85,6 +86,21 @@ class LoggingCallbackManager:
|
|||
callback=callback, parent_list=litellm._async_failure_callback
|
||||
)
|
||||
|
||||
def remove_callback_from_list_by_object(
|
||||
self, callback_list, obj
|
||||
):
|
||||
"""
|
||||
Remove callbacks that are methods of a particular object (e.g., router cleanup)
|
||||
"""
|
||||
if not isinstance(callback_list, list): # Not list -> do nothing
|
||||
return
|
||||
|
||||
remove_list=[c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj]
|
||||
|
||||
for c in remove_list:
|
||||
callback_list.remove(c)
|
||||
|
||||
|
||||
def _add_string_callback_to_list(
|
||||
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||
):
|
||||
|
@ -205,3 +221,36 @@ class LoggingCallbackManager:
|
|||
litellm._async_success_callback = []
|
||||
litellm._async_failure_callback = []
|
||||
litellm.callbacks = []
|
||||
|
||||
def _get_all_callbacks(self) -> List[Union[CustomLogger, Callable, str]]:
|
||||
"""
|
||||
Get all callbacks from litellm.callbacks, litellm.success_callback, litellm.failure_callback, litellm._async_success_callback, litellm._async_failure_callback
|
||||
"""
|
||||
return (
|
||||
litellm.callbacks
|
||||
+ litellm.success_callback
|
||||
+ litellm.failure_callback
|
||||
+ litellm._async_success_callback
|
||||
+ litellm._async_failure_callback
|
||||
)
|
||||
|
||||
def get_active_additional_logging_utils_from_custom_logger(
|
||||
self,
|
||||
) -> Set[AdditionalLoggingUtils]:
|
||||
"""
|
||||
Get all custom loggers that are instances of the given class type
|
||||
|
||||
Args:
|
||||
class_type: The class type to match against (e.g., AdditionalLoggingUtils)
|
||||
|
||||
Returns:
|
||||
Set[CustomLogger]: Set of custom loggers that are instances of the given class type
|
||||
"""
|
||||
all_callbacks = self._get_all_callbacks()
|
||||
matched_callbacks: Set[AdditionalLoggingUtils] = set()
|
||||
for callback in all_callbacks:
|
||||
if isinstance(callback, CustomLogger) and isinstance(
|
||||
callback, AdditionalLoggingUtils
|
||||
):
|
||||
matched_callbacks.add(callback)
|
||||
return matched_callbacks
|
||||
|
|
|
@ -325,26 +325,6 @@ def phind_codellama_pt(messages):
|
|||
return prompt
|
||||
|
||||
|
||||
known_tokenizer_config = {
|
||||
"mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
|
||||
"bos_token": "<|begin_of_text|>",
|
||||
"eos_token": "",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def hf_chat_template( # noqa: PLR0915
|
||||
model: str, messages: list, chat_template: Optional[Any] = None
|
||||
):
|
||||
|
@ -378,11 +358,11 @@ def hf_chat_template( # noqa: PLR0915
|
|||
else:
|
||||
return {"status": "failure"}
|
||||
|
||||
if model in known_tokenizer_config:
|
||||
tokenizer_config = known_tokenizer_config[model]
|
||||
if model in litellm.known_tokenizer_config:
|
||||
tokenizer_config = litellm.known_tokenizer_config[model]
|
||||
else:
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
known_tokenizer_config.update({model: tokenizer_config})
|
||||
litellm.known_tokenizer_config.update({model: tokenizer_config})
|
||||
|
||||
if (
|
||||
tokenizer_config["status"] == "failure"
|
||||
|
@ -475,6 +455,12 @@ def hf_chat_template( # noqa: PLR0915
|
|||
) # don't use verbose_logger.exception, if exception is raised
|
||||
|
||||
|
||||
def deepseek_r1_pt(messages):
|
||||
return hf_chat_template(
|
||||
model="deepseek-r1/deepseek-r1-7b-instruct", messages=messages
|
||||
)
|
||||
|
||||
|
||||
# Anthropic template
|
||||
def claude_2_1_pt(
|
||||
messages: list,
|
||||
|
@ -1421,6 +1407,8 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
)
|
||||
|
||||
user_content.append(_content_element)
|
||||
elif m.get("type", "") == "document":
|
||||
user_content.append(cast(AnthropicMessagesDocumentParam, m))
|
||||
elif isinstance(user_message_types_block["content"], str):
|
||||
_anthropic_content_text_element: AnthropicMessagesTextParam = {
|
||||
"type": "text",
|
||||
|
|
litellm/litellm_core_utils/sensitive_data_masker.py (new file, 81 lines)
|
@ -0,0 +1,81 @@
|
|||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
|
||||
class SensitiveDataMasker:
|
||||
def __init__(
|
||||
self,
|
||||
sensitive_patterns: Optional[Set[str]] = None,
|
||||
visible_prefix: int = 4,
|
||||
visible_suffix: int = 4,
|
||||
mask_char: str = "*",
|
||||
):
|
||||
self.sensitive_patterns = sensitive_patterns or {
|
||||
"password",
|
||||
"secret",
|
||||
"key",
|
||||
"token",
|
||||
"auth",
|
||||
"credential",
|
||||
"access",
|
||||
"private",
|
||||
"certificate",
|
||||
}
|
||||
|
||||
self.visible_prefix = visible_prefix
|
||||
self.visible_suffix = visible_suffix
|
||||
self.mask_char = mask_char
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
if not value or len(str(value)) < (self.visible_prefix + self.visible_suffix):
|
||||
return value
|
||||
|
||||
value_str = str(value)
|
||||
masked_length = len(value_str) - (self.visible_prefix + self.visible_suffix)
|
||||
return f"{value_str[:self.visible_prefix]}{self.mask_char * masked_length}{value_str[-self.visible_suffix:]}"
|
||||
|
||||
def is_sensitive_key(self, key: str) -> bool:
|
||||
key_lower = str(key).lower()
|
||||
result = any(pattern in key_lower for pattern in self.sensitive_patterns)
|
||||
return result
|
||||
|
||||
def mask_dict(
|
||||
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
if depth >= max_depth:
|
||||
return data
|
||||
|
||||
masked_data: Dict[str, Any] = {}
|
||||
for k, v in data.items():
|
||||
try:
|
||||
if isinstance(v, dict):
|
||||
masked_data[k] = self.mask_dict(v, depth + 1)
|
||||
elif hasattr(v, "__dict__") and not isinstance(v, type):
|
||||
masked_data[k] = self.mask_dict(vars(v), depth + 1)
|
||||
elif self.is_sensitive_key(k):
|
||||
str_value = str(v) if v is not None else ""
|
||||
masked_data[k] = self._mask_value(str_value)
|
||||
else:
|
||||
masked_data[k] = (
|
||||
v if isinstance(v, (int, float, bool, str)) else str(v)
|
||||
)
|
||||
except Exception:
|
||||
masked_data[k] = "<unable to serialize>"
|
||||
|
||||
return masked_data
|
||||
|
||||
|
||||
# Usage example:
|
||||
"""
|
||||
masker = SensitiveDataMasker()
|
||||
data = {
|
||||
"api_key": "sk-1234567890abcdef",
|
||||
"redis_password": "very_secret_pass",
|
||||
"port": 6379
|
||||
}
|
||||
masked = masker.mask_dict(data)
|
||||
# Result: {
|
||||
# "api_key": "sk-1****cdef",
|
||||
# "redis_password": "very****pass",
|
||||
# "port": 6379
|
||||
# }
|
||||
"""
|
|
@ -809,7 +809,10 @@ class CustomStreamWrapper:
|
|||
if self.sent_first_chunk is False:
|
||||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
|
||||
if response_obj.get("provider_specific_fields") is not None:
|
||||
completion_obj["provider_specific_fields"] = response_obj[
|
||||
"provider_specific_fields"
|
||||
]
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
_index: Optional[int] = completion_obj.get("index")
|
||||
if _index is not None:
|
||||
|
|
|
@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
|
|||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
|
@ -506,6 +506,29 @@ class ModelResponseIterator:
|
|||
|
||||
return usage_block
|
||||
|
||||
def _content_block_delta_helper(self, chunk: dict):
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
provider_specific_fields = {}
|
||||
content_block = ContentBlockDelta(**chunk) # type: ignore
|
||||
self.content_blocks.append(content_block)
|
||||
if "text" in content_block["delta"]:
|
||||
text = content_block["delta"]["text"]
|
||||
elif "partial_json" in content_block["delta"]:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": content_block["delta"]["partial_json"],
|
||||
},
|
||||
"index": self.tool_index,
|
||||
}
|
||||
elif "citation" in content_block["delta"]:
|
||||
provider_specific_fields["citation"] = content_block["delta"]["citation"]
|
||||
|
||||
return text, tool_use, provider_specific_fields
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
type_chunk = chunk.get("type", "") or ""
|
||||
|
@ -515,6 +538,7 @@ class ModelResponseIterator:
|
|||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields: Dict[str, Any] = {}
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
if type_chunk == "content_block_delta":
|
||||
|
@ -522,20 +546,9 @@ class ModelResponseIterator:
|
|||
Anthropic content chunk
|
||||
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
||||
"""
|
||||
content_block = ContentBlockDelta(**chunk) # type: ignore
|
||||
self.content_blocks.append(content_block)
|
||||
if "text" in content_block["delta"]:
|
||||
text = content_block["delta"]["text"]
|
||||
elif "partial_json" in content_block["delta"]:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": content_block["delta"]["partial_json"],
|
||||
},
|
||||
"index": self.tool_index,
|
||||
}
|
||||
text, tool_use, provider_specific_fields = (
|
||||
self._content_block_delta_helper(chunk=chunk)
|
||||
)
|
||||
elif type_chunk == "content_block_start":
|
||||
"""
|
||||
event: content_block_start
|
||||
|
@ -628,6 +641,9 @@ class ModelResponseIterator:
|
|||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=(
|
||||
provider_specific_fields if provider_specific_fields else None
|
||||
),
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
|
|
@ -70,7 +70,7 @@ class AnthropicConfig(BaseConfig):
        metadata: Optional[dict] = None,
        system: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
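This `locals()` to `locals().copy()` swap repeats across many config `__init__` methods below. A minimal standalone sketch of the pattern (illustrative class, not litellm code): copying takes a stable snapshot of the call's arguments before iterating, so the loop can never be affected by the mapping changing underneath it.

from typing import Optional

class ExampleConfig:
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> None:
        locals_ = locals().copy()  # snapshot of the constructor arguments
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

print(ExampleConfig(temperature=0.2).temperature)  # 0.2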
@ -628,6 +628,7 @@ class AnthropicConfig(BaseConfig):
|
|||
)
|
||||
else:
|
||||
text_content = ""
|
||||
citations: List[Any] = []
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
for idx, content in enumerate(completion_response["content"]):
|
||||
if content["type"] == "text":
|
||||
|
@ -645,10 +646,14 @@ class AnthropicConfig(BaseConfig):
|
|||
index=idx,
|
||||
)
|
||||
)
|
||||
## CITATIONS
|
||||
if content.get("citations", None) is not None:
|
||||
citations.append(content["citations"])
|
||||
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=text_content or None,
|
||||
provider_specific_fields={"citations": citations},
|
||||
)
|
||||
|
||||
## HANDLE JSON MODE - anthropic returns single function call
|
||||
|
|
|
@ -72,7 +72,7 @@ class AnthropicTextConfig(BaseConfig):
|
|||
top_k: Optional[int] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -5,10 +5,11 @@ import time
|
|||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.constants import DEFAULT_MAX_RETRIES
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
|
@ -98,14 +99,6 @@ class AzureOpenAIAssistantsAPIConfig:
|
|||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
# azure_client_params = {
|
||||
# "api_version": api_version,
|
||||
# "azure_endpoint": api_base,
|
||||
# "azure_deployment": model,
|
||||
# "http_client": litellm.client_session,
|
||||
# "max_retries": max_retries,
|
||||
# "timeout": timeout,
|
||||
# }
|
||||
azure_endpoint = azure_client_params.get("azure_endpoint", None)
|
||||
if azure_endpoint is not None:
|
||||
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
|
||||
|
@ -312,6 +305,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
|
||||
- call chat.completions.create by default
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
raw_response = await azure_client.chat.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
|
@ -320,6 +314,11 @@ class AzureChatCompletion(BaseLLM):
            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except APITimeoutError as e:
            end_time = time.time()
            time_delta = round(end_time - start_time, 2)
            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
            raise e
        except Exception as e:
            raise e

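A small standalone sketch of the timing pattern used above (generic exception class, illustrative only): measure elapsed time around the call and append it to the error message before re-raising, so a timeout in the logs also shows how long the request actually ran.

import time

class FakeTimeoutError(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message

def call_with_timing(timeout: float) -> None:
    start_time = time.time()
    try:
        raise FakeTimeoutError("Request timed out")  # stand-in for the real API call
    except FakeTimeoutError as e:
        time_delta = round(time.time() - start_time, 2)
        e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
        raise e

try:
    call_with_timing(timeout=600.0)
except FakeTimeoutError as e:
    print(e.message)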
@ -353,7 +352,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
status_code=422, message="Missing model or messages"
|
||||
)
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
max_retries = optional_params.pop("max_retries", None)
|
||||
if max_retries is None:
|
||||
max_retries = DEFAULT_MAX_RETRIES
|
||||
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
|
||||
|
||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||
|
@ -415,6 +416,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(
|
||||
|
@ -430,6 +432,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
max_retries=max_retries,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
|
@ -445,6 +448,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -553,6 +557,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
dynamic_params: bool,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
convert_tool_call_to_json_mode: Optional[bool] = None,
|
||||
|
@ -560,12 +565,6 @@ class AzureChatCompletion(BaseLLM):
|
|||
):
|
||||
response = None
|
||||
try:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -649,6 +648,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||
except Exception as e:
|
||||
message = getattr(e, "message", str(e))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=data["messages"],
|
||||
|
@ -659,7 +659,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
if hasattr(e, "status_code"):
|
||||
raise e
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||
raise AzureOpenAIError(status_code=500, message=message)
|
||||
|
||||
def streaming(
|
||||
self,
|
||||
|
@ -671,15 +671,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
client=None,
|
||||
):
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -742,6 +738,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[Callable] = None,
|
||||
client=None,
|
||||
|
@ -753,7 +750,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"http_client": litellm.aclient_session,
|
||||
"max_retries": data.pop("max_retries", 2),
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
|
@ -807,10 +804,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
message = getattr(e, "message", str(e))
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AzureOpenAIError(
|
||||
status_code=status_code, message=str(e), headers=error_headers
|
||||
status_code=status_code, message=message, headers=error_headers
|
||||
)
|
||||
|
||||
async def aembedding(
|
||||
|
|
|
@ -98,6 +98,7 @@ class AzureOpenAIConfig(BaseConfig):
|
|||
"seed",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"prediction",
|
||||
]
|
||||
|
||||
def _is_response_format_supported_model(self, model: str) -> bool:
|
||||
|
@ -113,6 +114,17 @@ class AzureOpenAIConfig(BaseConfig):

        return False

    def _is_response_format_supported_api_version(
        self, api_version_year: str, api_version_month: str
    ) -> bool:
        """
        - check if api_version is supported for response_format
        """

        is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8

        return is_supported

    def map_openai_params(
        self,
        non_default_params: dict,
@ -171,13 +183,20 @@ class AzureOpenAIConfig(BaseConfig):
            _is_response_format_supported_model = (
                self._is_response_format_supported_model(model)
            )
            should_convert_response_format_to_tool = (
                api_version_year <= "2024" and api_version_month < "08"
            ) or not _is_response_format_supported_model

            is_response_format_supported_api_version = (
                self._is_response_format_supported_api_version(
                    api_version_year, api_version_month
                )
            )
            is_response_format_supported = (
                is_response_format_supported_api_version
                and _is_response_format_supported_model
            )
            optional_params = self._add_response_format_to_tools(
                optional_params=optional_params,
                value=value,
                should_convert_response_format_to_tool=should_convert_response_format_to_tool,
                is_response_format_supported=is_response_format_supported,
            )
        elif param == "tools" and isinstance(value, list):
            optional_params.setdefault("tools", [])
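Restated as a runnable check (illustrative inputs; `api_version_year` and `api_version_month` are the year and month substrings of the Azure `api-version`, e.g. "2024-08-01-preview" splits into "2024" and "08"): the helper returns True only when the year is 2024 or earlier and the month is 08 or later, and anything else still goes through the tool-call conversion.

# Standalone restatement of the check above
def is_response_format_supported_api_version(api_version_year: str, api_version_month: str) -> bool:
    return int(api_version_year) <= 2024 and int(api_version_month) >= 8

print(is_response_format_supported_api_version("2024", "08"))  # True  -> response_format passed through
print(is_response_format_supported_api_version("2024", "06"))  # False -> converted to a tool call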
|
@ -131,6 +131,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
|
@ -236,17 +237,12 @@ class AzureTextCompletion(BaseLLM):
|
|||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Any,
|
||||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
|
|
@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
|
|||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_base_model(model: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the base model name from the given model name.
|
||||
|
||||
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
|
||||
This function will return `anthropic.claude-3-opus-20240229-v1:0`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _dict_to_response_format_helper(
|
||||
response_format: dict, ref_template: Optional[str] = None
|
||||
|
|
|
@ -20,6 +20,7 @@ from pydantic import BaseModel
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
|
@ -27,9 +28,6 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
|
@ -163,7 +161,7 @@ class BaseConfig(ABC):
|
|||
self,
|
||||
optional_params: dict,
|
||||
value: dict,
|
||||
should_convert_response_format_to_tool: bool,
|
||||
is_response_format_supported: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
@ -183,7 +181,8 @@ class BaseConfig(ABC):
|
|||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
|
||||
if json_schema and should_convert_response_format_to_tool:
|
||||
if json_schema and not is_response_format_supported:
|
||||
|
||||
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolChoiceFunctionParam(
|
||||
|
|
|
@ -52,6 +52,7 @@ class BaseAWSLLM:
|
|||
"aws_role_name",
|
||||
"aws_web_identity_token",
|
||||
"aws_sts_endpoint",
|
||||
"aws_bedrock_runtime_endpoint",
|
||||
]
|
||||
|
||||
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||
|
|
|
@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
|
|||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import add_dummy_tool, has_tool_call_blocks
|
||||
|
||||
from ..common_utils import (
|
||||
AmazonBedrockGlobalConfig,
|
||||
BedrockError,
|
||||
get_bedrock_tool_name,
|
||||
)
|
||||
|
||||
global_config = AmazonBedrockGlobalConfig()
|
||||
all_global_regions = global_config.get_all_regions()
|
||||
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
|
||||
|
||||
|
||||
class AmazonConverseConfig(BaseConfig):
|
||||
|
@ -63,7 +56,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
topP: Optional[int] = None,
|
||||
topK: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
]
|
||||
|
||||
## Filter out 'cross-region' from model name
|
||||
base_model = self._get_base_model(model)
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
if (
|
||||
base_model.startswith("anthropic")
|
||||
|
@ -112,6 +105,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
or base_model.startswith("cohere")
|
||||
or base_model.startswith("meta.llama3-1")
|
||||
or base_model.startswith("meta.llama3-2")
|
||||
or base_model.startswith("meta.llama3-3")
|
||||
or base_model.startswith("amazon.nova")
|
||||
):
|
||||
supported_params.append("tools")
|
||||
|
@ -341,9 +335,9 @@ class AmazonConverseConfig(BaseConfig):
        if "top_k" in inference_params:
            inference_params["topK"] = inference_params.pop("top_k")
        return InferenceConfig(**inference_params)


    def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
        base_model = self._get_base_model(model)
        base_model = BedrockModelInfo.get_base_model(model)

        val_top_k = None
        if "topK" in inference_params:
@ -352,11 +346,11 @@ class AmazonConverseConfig(BaseConfig):
            val_top_k = inference_params.pop("top_k")

        if val_top_k:
            if (base_model.startswith("anthropic")):
            if base_model.startswith("anthropic"):
                return {"top_k": val_top_k}
            if base_model.startswith("amazon.nova"):
                return {'inferenceConfig': {"topK": val_top_k}}

                return {"inferenceConfig": {"topK": val_top_k}}

        return {}

    def _transform_request_helper(
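In plain terms, the helper routes a user-supplied `top_k` to the field each Bedrock model family expects: Anthropic models get a top-level `top_k`, Nova models get it nested under `inferenceConfig.topK`, and other families drop it. A simplified standalone sketch (model names illustrative, litellm's model-name normalization omitted):

def handle_top_k_value(base_model: str, inference_params: dict) -> dict:
    # mirrors _handle_top_k_value above
    val_top_k = inference_params.pop("topK", None) or inference_params.pop("top_k", None)
    if val_top_k:
        if base_model.startswith("anthropic"):
            return {"top_k": val_top_k}
        if base_model.startswith("amazon.nova"):
            return {"inferenceConfig": {"topK": val_top_k}}
    return {}

print(handle_top_k_value("anthropic.claude-3-5-sonnet-20240620-v1:0", {"top_k": 40}))
# {'top_k': 40}
print(handle_top_k_value("amazon.nova-pro-v1:0", {"top_k": 40}))
# {'inferenceConfig': {'topK': 40}}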
@ -393,15 +387,25 @@ class AmazonConverseConfig(BaseConfig):
        ) + ["top_k"]
        supported_tool_call_params = ["tools", "tool_choice"]
        supported_guardrail_params = ["guardrailConfig"]
        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
        total_supported_params = (
            supported_converse_params
            + supported_tool_call_params
            + supported_guardrail_params
        )
        inference_params.pop("json_mode", None)  # used for handling json_schema

        # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
        additional_request_params = {
            k: v for k, v in inference_params.items() if k not in total_supported_params
        }
        inference_params = {
            k: v for k, v in inference_params.items() if k in total_supported_params
        }

        # Only set the topK value for models that support it
        additional_request_params.update(self._handle_top_k_value(model, inference_params))
        additional_request_params.update(
            self._handle_top_k_value(model, inference_params)
        )

        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
            inference_params.pop("tools", [])
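The split above keeps what the Converse API understands natively in `inference_params` and moves everything vendor-specific into `additional_request_params`. A simplified standalone rendition (the supported-param list here is abbreviated and the extra flag is hypothetical):

total_supported_params = [
    "maxTokens", "temperature", "topP", "stopSequences",  # abbreviated converse params
    "tools", "tool_choice", "guardrailConfig", "top_k",
]
inference_params = {"temperature": 0.2, "top_k": 40, "some_vendor_only_flag": True}

additional_request_params = {
    k: v for k, v in inference_params.items() if k not in total_supported_params
}
inference_params = {
    k: v for k, v in inference_params.items() if k in total_supported_params
}

print(additional_request_params)  # {'some_vendor_only_flag': True}
print(inference_params)           # {'temperature': 0.2, 'top_k': 40}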
@ -679,41 +683,6 @@ class AmazonConverseConfig(BaseConfig):
|
|||
|
||||
return model_response
|
||||
|
||||
def _supported_cross_region_inference_region(self) -> List[str]:
|
||||
"""
|
||||
Abbreviations of regions AWS Bedrock supports for cross region inference
|
||||
"""
|
||||
return ["us", "eu", "apac"]
|
||||
|
||||
def _get_base_model(self, model: str) -> str:
|
||||
"""
|
||||
Get the base model from the given model name.
|
||||
|
||||
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
|
||||
"""
|
||||
|
||||
if model.startswith("bedrock/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if model.startswith("converse/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
potential_region = model.split(".", 1)[0]
|
||||
|
||||
alt_potential_region = model.split("/", 1)[
|
||||
0
|
||||
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
|
||||
|
||||
if potential_region in self._supported_cross_region_inference_region():
|
||||
return model.split(".", 1)[1]
|
||||
elif (
|
||||
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
|
||||
):
|
||||
return model.split("/", 1)[1]
|
||||
|
||||
return model
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Manages calling Bedrock's `/converse` API + `/invoke` API
|
||||
TODO: DELETE FILE. Bedrock LLM is no longer used. Goto `litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py`
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
|
|||
parse_xml_params,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.anthropic.chat.handler import (
|
||||
ModelResponseIterator as AnthropicModelResponseIterator,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
|
@ -103,7 +106,7 @@ class AmazonCohereChatConfig:
|
|||
stop_sequences: Optional[str] = None,
|
||||
raw_prompting: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
@ -177,6 +180,7 @@ async def make_call(
|
|||
logging_obj: Logging,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
|
||||
):
|
||||
try:
|
||||
if client is None:
|
||||
|
@ -214,6 +218,14 @@ async def make_call(
|
|||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
elif bedrock_invoke_provider == "anthropic":
|
||||
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
|
||||
model=model,
|
||||
sync_stream=False,
|
||||
)
|
||||
completion_stream = decoder.aiter_bytes(
|
||||
response.aiter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.aiter_bytes(
|
||||
|
@ -248,6 +260,7 @@ def make_sync_call(
|
|||
logging_obj: Logging,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
|
||||
):
|
||||
try:
|
||||
if client is None:
|
||||
|
@ -283,6 +296,12 @@ def make_sync_call(
|
|||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
elif bedrock_invoke_provider == "anthropic":
|
||||
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
|
||||
model=model,
|
||||
sync_stream=True,
|
||||
)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
|
|||
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
||||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
######## /bedrock/converse mappings ###############
|
||||
elif (
|
||||
"contentBlockIndex" in chunk_data
|
||||
or "stopReason" in chunk_data
|
||||
|
@ -1331,6 +1350,11 @@ class AWSEventStreamDecoder:
|
|||
or "trace" in chunk_data
|
||||
):
|
||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||
######### /bedrock/invoke nova mappings ###############
|
||||
elif "contentBlockDelta" in chunk_data:
|
||||
# when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
|
||||
_chunk_data = chunk_data.get("contentBlockDelta", None)
|
||||
return self.converse_chunk_parser(chunk_data=_chunk_data)
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
if (
|
||||
|
@ -1429,6 +1453,27 @@ class AWSEventStreamDecoder:
|
|||
return chunk.decode() # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
sync_stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
|
||||
|
||||
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
|
||||
streaming_response=None,
|
||||
sync_stream=sync_stream,
|
||||
)
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GChunk:
|
||||
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
|
||||
|
||||
|
||||
class MockResponseIterator: # for returning ai21 streaming responses
|
||||
def __init__(self, model_response, json_mode: Optional[bool] = False):
|
||||
self.model_response = model_response
|
||||
|
|
|
@ -46,7 +46,7 @@ class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
|
|||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -28,7 +28,7 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
|
|||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -28,7 +28,7 @@ class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
|
|||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -33,7 +33,7 @@ class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
|
|||
top_k: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -0,0 +1,70 @@
|
"""
Handles transforming requests for `bedrock/invoke/{nova} models`

Inherits from `AmazonConverseConfig`

Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""

from typing import List

import litellm
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues


class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        _transformed_nova_request = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        _bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
            **_transformed_nova_request
        )
        self._remove_empty_system_messages(_bedrock_invoke_nova_request)
        bedrock_invoke_nova_request = self._filter_allowed_fields(
            _bedrock_invoke_nova_request
        )
        return bedrock_invoke_nova_request

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` dataclass.
        """
        allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
        return {
            k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove empty `system` messages from the request.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        _system_message = bedrock_invoke_nova_request.get("system", None)
        if isinstance(_system_message, list) and len(_system_message) == 0:
            bedrock_invoke_nova_request.pop("system", None)
        return
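A rough sketch of the two clean-up steps above on a hypothetical request dict (the allowed-field set here is an illustrative subset; the real one comes from `BedrockInvokeNovaRequest.__annotations__`): an empty `system` list is dropped, and keys Nova's invoke schema does not accept are filtered out.

allowed_fields = {"messages", "system", "inferenceConfig", "toolConfig"}  # illustrative subset

request = {
    "messages": [{"role": "user", "content": [{"text": "hi"}]}],
    "system": [],                    # empty -> removed
    "inferenceConfig": {"maxTokens": 100},
    "someUnsupportedField": {},      # hypothetical key outside the schema -> filtered out
}

if isinstance(request.get("system"), list) and len(request["system"]) == 0:
    request.pop("system", None)
request = {k: v for k, v in request.items() if k in allowed_fields}

print(request)
# {'messages': [...], 'inferenceConfig': {'maxTokens': 100}}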
@ -33,7 +33,7 @@ class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
|
|||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -34,7 +34,7 @@ class AmazonAnthropicConfig:
|
|||
top_p: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -1,61 +1,34 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaude3Config:
|
||||
class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude 3 models:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens. Default is 4096
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
|
||||
anthropic_version: Optional[str] = "bedrock-2023-05-31"
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
anthropic_version: str = "bedrock-2023-05-31"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
def get_supported_openai_params(self, model: str):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
|
@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
|
|||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
|
@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
|
|||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_anthropic_request = litellm.AnthropicConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
_anthropic_request.pop("model", None)
|
||||
if "anthropic_version" not in _anthropic_request:
|
||||
_anthropic_request["anthropic_version"] = self.anthropic_version
|
||||
|
||||
return _anthropic_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return litellm.AnthropicConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
|
|
@ -2,22 +2,19 @@ import copy
|
|||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
cohere_message_pt,
|
||||
construct_tool_use_system_prompt,
|
||||
contains_tag,
|
||||
custom_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
deepseek_r1_pt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
|
@ -91,7 +88,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
optional_params=optional_params,
|
||||
)
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
aws_bedrock_runtime_endpoint = optional_params.get(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
|
@ -129,15 +126,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
extra_headers = optional_params.pop("extra_headers", None)
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
extra_headers = optional_params.get("extra_headers", None)
|
||||
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.get("aws_session_token", None)
|
||||
aws_role_name = optional_params.get("aws_role_name", None)
|
||||
aws_session_name = optional_params.get("aws_session_name", None)
|
||||
aws_profile_name = optional_params.get("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params)
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
|
@ -171,7 +168,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
return dict(request.headers)
|
||||
|
||||
def transform_request( # noqa: PLR0915
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
|
@ -182,11 +179,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
|
||||
hf_model_name = litellm_params.get("hf_model_name", None)
|
||||
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
|
||||
prompt, chat_history = self.convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
model=hf_model_name or model,
|
||||
messages=messages,
|
||||
provider=provider,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params = {
|
||||
|
@ -194,7 +195,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
for k, v in inference_params.items()
|
||||
if k not in self.aws_authentication_params
|
||||
}
|
||||
json_schemas: dict = {}
|
||||
request_data: dict = {}
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
|
@ -223,57 +223,21 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
)
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_idx: list[int] = []
|
||||
system_messages: list[str] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system" and isinstance(
|
||||
message["content"], str
|
||||
):
|
||||
system_messages.append(message["content"])
|
||||
system_prompt_idx.append(idx)
|
||||
if len(system_prompt_idx) > 0:
|
||||
inference_params["system"] = "\n".join(system_messages)
|
||||
messages = [
|
||||
i for j, i in enumerate(messages) if j not in system_prompt_idx
|
||||
]
|
||||
# Format rest of message according to anthropic guidelines
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic_xml"
|
||||
) # type: ignore
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicClaude3Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
## Handle Tool Calling
|
||||
if "tools" in inference_params:
|
||||
_is_function_call = True
|
||||
for tool in inference_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||
"parameters", None
|
||||
)
|
||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||
tools=inference_params["tools"]
|
||||
)
|
||||
inference_params["system"] = (
|
||||
inference_params.get("system", "\n")
|
||||
+ tool_calling_system_prompt
|
||||
) # add the anthropic tool calling prompt to the system prompt
|
||||
inference_params.pop("tools")
|
||||
request_data = {"messages": messages, **inference_params}
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif provider == "nova":
|
||||
return litellm.AmazonInvokeNovaConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
|
@ -307,7 +271,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
elif provider == "meta" or provider == "llama":
|
||||
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -347,6 +311,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
raise BedrockError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"bedrock invoke response % s",
|
||||
json.dumps(completion_response, indent=4, default=str),
|
||||
)
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
outputText: Optional[str] = None
|
||||
try:
|
||||
|
@ -359,71 +327,36 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
completion_response["generations"][0]["finish_reason"]
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
json_schemas: dict = {}
|
||||
_is_function_call = False
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
for tool in optional_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool[
|
||||
"function"
|
||||
].get("parameters", None)
|
||||
outputText = completion_response.get("content")[0].get("text", None)
|
||||
if outputText is not None and contains_tag(
|
||||
"invoke", outputText
|
||||
): # OUTPUT PARSE FUNCTION CALL
|
||||
function_name = extract_between_tags("tool_name", outputText)[0]
|
||||
function_arguments_str = extract_between_tags(
|
||||
"invoke", outputText
|
||||
)[0].strip()
|
||||
function_arguments_str = (
|
||||
f"<invoke>{function_arguments_str}</invoke>"
|
||||
)
|
||||
function_arguments = parse_xml_params(
|
||||
function_arguments_str,
|
||||
json_schema=json_schemas.get(
|
||||
function_name, None
|
||||
), # check if we have a json schema for this function name)
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = (
|
||||
outputText # allow user to access raw anthropic tool calling response
|
||||
)
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response.get("stop_reason", "")
|
||||
)
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=completion_response["usage"]["input_tokens"],
|
||||
completion_tokens=completion_response["usage"]["output_tokens"],
|
||||
total_tokens=completion_response["usage"]["input_tokens"]
|
||||
+ completion_response["usage"]["output_tokens"],
|
||||
)
|
||||
setattr(model_response, "usage", _usage)
|
||||
else:
|
||||
outputText = completion_response["completion"]
|
||||
|
||||
model_response.choices[0].finish_reason = completion_response[
|
||||
"stop_reason"
|
||||
]
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
elif provider == "nova":
|
||||
return litellm.AmazonInvokeNovaConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
)
|
||||
elif provider == "meta" or provider == "llama":
|
||||
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
|
||||
outputText = completion_response["generation"]
|
||||
elif provider == "mistral":
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
|
@ -536,6 +469,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=True if "ai21" in api_base else False,
|
||||
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
|
@ -569,6 +503,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=True if "ai21" in api_base else False,
|
||||
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
|
@ -594,10 +529,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
        """
        Helper function to get the bedrock provider from the model

        handles 2 scenarions:
        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        handles 4 scenarios:
        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
        """
        if model.startswith("invoke/"):
            model = model.replace("invoke/", "", 1)

        _split_model = model.split(".")[0]
        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
@ -606,6 +546,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
        if provider is not None:
            return provider

        # check if provider == "nova"
        if "nova" in model:
            return "nova"
        return None

    @staticmethod
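For reference, a simplified rendition of the resolution order above (prefix match, then model-path match, then the "nova" fallback), using the model strings from the docstring; the provider tuple here is an illustrative subset of `BEDROCK_INVOKE_PROVIDERS_LITERAL`:

from typing import Optional

BEDROCK_INVOKE_PROVIDERS = ("anthropic", "llama", "cohere", "mistral", "nova")  # illustrative subset

def get_bedrock_invoke_provider(model: str) -> Optional[str]:
    if model.startswith("invoke/"):
        model = model.replace("invoke/", "", 1)
    _split_model = model.split(".")[0]
    if _split_model in BEDROCK_INVOKE_PROVIDERS:
        return _split_model
    if "/" in model and model.split("/")[0] in BEDROCK_INVOKE_PROVIDERS:
        return model.split("/")[0]
    if "nova" in model:
        return "nova"
    return None

print(get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"))  # anthropic
print(get_bedrock_invoke_provider("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"))  # llama
print(get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0"))  # nova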
@ -640,16 +584,16 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
else:
|
||||
modelId = model
|
||||
|
||||
modelId = modelId.replace("invoke/", "", 1)
|
||||
if provider == "llama" and "llama/" in modelId:
|
||||
modelId = self._get_model_id_for_llama_like_model(modelId)
|
||||
|
||||
return modelId
|
||||
|
||||
def _get_aws_region_name(self, optional_params: dict) -> str:
|
||||
"""
|
||||
Get the AWS region name from the environment variables
|
||||
"""
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
|
@ -725,6 +669,8 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
)
|
||||
elif provider == "cohere":
|
||||
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||
elif provider == "deepseek_r1":
|
||||
prompt = deepseek_r1_pt(messages=messages)
|
||||
else:
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
|
|
|
@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
|
|||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
            response_tool_name
        ]
    return response_tool_name


class BedrockModelInfo(BaseLLMModelInfo):

    global_config = AmazonBedrockGlobalConfig()
    all_global_regions = global_config.get_all_regions()

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model from the given model name.

        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        """
        if model.startswith("bedrock/"):
            model = model.split("/", 1)[1]

        if model.startswith("converse/"):
            model = model.split("/", 1)[1]

        if model.startswith("invoke/"):
            model = model.split("/", 1)[1]

        potential_region = model.split(".", 1)[0]

        alt_potential_region = model.split("/", 1)[
            0
        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`

        if (
            potential_region
            in BedrockModelInfo._supported_cross_region_inference_region()
        ):
            return model.split(".", 1)[1]
        elif (
            alt_potential_region in BedrockModelInfo.all_global_regions
            and len(model.split("/", 1)) > 1
        ):
            return model.split("/", 1)[1]

        return model

    @staticmethod
    def _supported_cross_region_inference_region() -> List[str]:
        """
        Abbreviations of regions AWS Bedrock supports for cross region inference
        """
        return ["us", "eu", "apac"]

    @staticmethod
    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
        """
        Get the bedrock route for the given model.
        """
        base_model = BedrockModelInfo.get_base_model(model)
        if "invoke/" in model:
            return "invoke"
        elif "converse_like" in model:
            return "converse_like"
        elif "converse/" in model:
            return "converse"
        elif base_model in litellm.bedrock_converse_models:
            return "converse"
        return "invoke"
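A quick usage sketch of the two lookups above, with model strings taken from the docstrings in this diff (import path assumed from the file being patched here; the contents of `litellm.bedrock_converse_models` are not shown, so treat the route results as illustrative):

from litellm.llms.bedrock.common_utils import BedrockModelInfo

print(BedrockModelInfo.get_base_model("us.meta.llama3-2-11b-instruct-v1:0"))
# meta.llama3-2-11b-instruct-v1:0  (cross-region "us." prefix stripped)
print(BedrockModelInfo.get_base_model("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"))
# anthropic.claude-3-5-sonnet-20240620-v1:0
print(BedrockModelInfo.get_bedrock_route("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"))
# invoke  (an explicit invoke/ prefix wins)
print(BedrockModelInfo.get_bedrock_route("converse/anthropic.claude-3-5-sonnet-20240620-v1:0"))
# converse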
|
|
@ -27,7 +27,7 @@ class AmazonTitanG1Config:
|
|||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -33,7 +33,7 @@ class AmazonTitanV2Config:
|
|||
def __init__(
|
||||
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -49,7 +49,7 @@ class AmazonStabilityConfig:
|
|||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -45,7 +45,7 @@ class ClarifaiConfig(BaseConfig):
|
|||
temperature: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -44,7 +44,7 @@ class CloudflareChatConfig(BaseConfig):
|
|||
max_tokens: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -104,7 +104,7 @@ class CohereChatConfig(BaseConfig):
|
|||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -86,7 +86,7 @@ class CohereTextConfig(BaseConfig):
|
|||
return_likelihoods: Optional[str] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
@ -179,6 +180,7 @@ class AsyncHTTPHandler:
|
|||
stream: bool = False,
|
||||
logging_obj: Optional[LiteLLMLoggingObject] = None,
|
||||
):
|
||||
start_time = time.time()
|
||||
try:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
@ -207,6 +209,8 @@ class AsyncHTTPHandler:
|
|||
finally:
|
||||
await new_client.aclose()
|
||||
except httpx.TimeoutException as e:
|
||||
end_time = time.time()
|
||||
time_delta = round(end_time - start_time, 3)
|
||||
headers = {}
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_response is not None:
|
||||
|
@ -214,7 +218,7 @@ class AsyncHTTPHandler:
|
|||
headers["response_headers-{}".format(key)] = value
|
||||
|
||||
raise litellm.Timeout(
|
||||
message=f"Connection timed out after {timeout} seconds.",
|
||||
message=f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds",
|
||||
model="default-model-name",
|
||||
llm_provider="litellm-httpx-handler",
|
||||
headers=headers,
|
||||
|
|
|
@ -37,7 +37,7 @@ class DatabricksConfig(OpenAILikeChatConfig):
|
|||
stop: Optional[Union[List[str], str]] = None,
|
||||
n: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -16,7 +16,7 @@ class DatabricksEmbeddingConfig:
|
|||
)
|
||||
|
||||
def __init__(self, instruction: Optional[str] = None) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@@ -145,7 +145,7 @@ class AlephAlphaConfig:
         contextual_control_threshold: Optional[int] = None,
         control_log_additive: Optional[bool] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -63,7 +63,7 @@ class PalmConfig:
         top_p: Optional[float] = None,
         max_output_tokens: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -57,7 +57,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
         candidate_count: Optional[int] = None,
         stop_sequences: Optional[list] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -77,7 +77,7 @@ class HuggingfaceChatConfig(BaseConfig):
         typical_p: Optional[float] = None,
         watermark: Optional[bool] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -20,6 +20,15 @@ from .common_utils import InfinityError


 class InfinityRerankConfig(CohereRerankConfig):
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError("api_base is required for Infinity rerank")
+        # Remove trailing slashes and ensure clean base URL
+        api_base = api_base.rstrip("/")
+        if not api_base.endswith("/rerank"):
+            api_base = f"{api_base}/rerank"
+        return api_base
+
     def validate_environment(
         self,
         headers: dict,

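The new get_complete_url normalizes whatever api_base was configured so requests always land on a single /rerank endpoint, whether or not the base URL had a trailing slash or already ended in /rerank. A small standalone sketch of the same normalization (the function name is illustrative):

    from typing import Optional


    def build_rerank_url(api_base: Optional[str]) -> str:
        if api_base is None:
            raise ValueError("api_base is required for Infinity rerank")
        api_base = api_base.rstrip("/")          # "http://host:7997/" -> "http://host:7997"
        if not api_base.endswith("/rerank"):
            api_base = f"{api_base}/rerank"      # no-op if /rerank is already present
        return api_base


    assert build_rerank_url("http://localhost:7997/") == "http://localhost:7997/rerank"
    assert build_rerank_url("http://localhost:7997/rerank") == "http://localhost:7997/rerank"
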
@@ -21,7 +21,7 @@ class JinaAIEmbeddingConfig:
     def __init__(
         self,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -18,7 +18,7 @@ class LmStudioEmbeddingConfig:
     def __init__(
         self,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -33,7 +33,7 @@ class MaritalkConfig(OpenAIGPTConfig):
         tools: Optional[List[dict]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -78,7 +78,7 @@ class NLPCloudConfig(BaseConfig):
         num_beams: Optional[int] = None,
         num_return_sequences: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -32,7 +32,7 @@ class NvidiaNimEmbeddingConfig:
         input_type: Optional[str] = None,
         truncate: Optional[str] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
@@ -58,7 +58,7 @@ class NvidiaNimEmbeddingConfig:
     def get_supported_openai_params(
         self,
     ):
-        return ["encoding_format", "user"]
+        return ["encoding_format", "user", "dimensions"]

     def map_openai_params(
         self,
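With "dimensions" added to the supported OpenAI params, an OpenAI-style dimensions argument on an embedding call should now pass validation for this provider. A hypothetical call; the model name and routing are assumptions, not taken from this diff:

    # Hypothetical usage; assumes the nvidia_nim provider and model are configured.
    import litellm

    response = litellm.embedding(
        model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
        input=["hello world"],
        dimensions=512,            # now listed as a supported OpenAI param for this config
        encoding_format="float",
    )
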
@@ -73,6 +73,8 @@ class NvidiaNimEmbeddingConfig:
                 optional_params["extra_body"].update({"input_type": v})
+            elif k == "truncate":
+                optional_params["extra_body"].update({"truncate": v})
             else:
                 optional_params[k] = v

         if kwargs is not None:
             # pass kwargs in extra_body

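The mapping above sends NIM-specific knobs (input_type, truncate) into extra_body and leaves standard OpenAI params at the top level. A simplified, standalone sketch of that dispatch, not the class's full method:

    def map_embedding_params(non_default_params: dict) -> dict:
        """Illustrative: NIM-specific keys go to extra_body, the rest stay top-level."""
        optional_params: dict = {"extra_body": {}}
        for k, v in non_default_params.items():
            if k == "input_type":
                optional_params["extra_body"].update({"input_type": v})
            elif k == "truncate":
                optional_params["extra_body"].update({"truncate": v})
            else:
                optional_params[k] = v
        return optional_params


    print(map_embedding_params({"dimensions": 512, "truncate": "END"}))
    # {'extra_body': {'truncate': 'END'}, 'dimensions': 512}
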
@@ -117,7 +117,7 @@ class OllamaConfig(BaseConfig):
         system: Optional[str] = None,
         template: Optional[str] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -105,7 +105,7 @@ class OllamaChatConfig(OpenAIGPTConfig):
         system: Optional[str] = None,
         template: Optional[str] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             or "https://api.openai.com/v1"
         )

+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
+
     def get_model_response_iterator(
         self,
         streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],

@@ -43,23 +43,6 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
         """
         return messages

-    def should_fake_stream(
-        self,
-        model: Optional[str],
-        stream: Optional[bool],
-        custom_llm_provider: Optional[str] = None,
-    ) -> bool:
-        if stream is not True:
-            return False
-
-        if model is None:
-            return True
-        supported_stream_models = ["o1-mini", "o1-preview"]
-        for supported_model in supported_stream_models:
-            if supported_model in model:
-                return False
-        return True
-
     def get_supported_openai_params(self, model: str) -> list:
         """
         Get the supported OpenAI params for the given model

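The deleted override only allowed native streaming for o1-mini and o1-preview and signalled that other o-series requests should be faked; removing it presumably defers to the parent class's default behavior. The removed decision logic, reproduced as a standalone function for reference:

    from typing import Optional


    def should_fake_stream(model: Optional[str], stream: Optional[bool]) -> bool:
        """Mirror of the deleted override: fake the stream unless the model streams natively."""
        if stream is not True:
            return False
        if model is None:
            return True
        supported_stream_models = ["o1-mini", "o1-preview"]
        return not any(supported in model for supported in supported_stream_models)


    assert should_fake_stream("o1-mini", stream=True) is False   # native streaming allowed
    assert should_fake_stream("o1", stream=True) is True         # would have been faked
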
@@ -27,6 +27,7 @@ from typing_extensions import overload
 import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
+from litellm.constants import DEFAULT_MAX_RETRIES
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    def _set_dynamic_params_on_client(
+        self,
+        client: Union[OpenAI, AsyncOpenAI],
+        organization: Optional[str] = None,
+        max_retries: Optional[int] = None,
+    ):
+        if organization is not None:
+            client.organization = organization
+        if max_retries is not None:
+            client.max_retries = max_retries
+
     def _get_openai_client(
         self,
         is_async: bool,
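The new helper means a caller-supplied OpenAI or AsyncOpenAI client can still pick up per-request organization and max_retries instead of those values being ignored. A small sketch of the same idea against the openai SDK; the wrapper function is illustrative:

    from openai import OpenAI


    def apply_dynamic_params(client: OpenAI, organization=None, max_retries=None) -> OpenAI:
        """Illustrative: mutate an existing client in place, mirroring _set_dynamic_params_on_client."""
        if organization is not None:
            client.organization = organization
        if max_retries is not None:
            client.max_retries = max_retries
        return client


    client = OpenAI(api_key="sk-...")             # reused, pre-built client
    client = apply_dynamic_params(client, organization="org-123", max_retries=5)
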
@@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = 2,
+        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
     ):
-        args = locals()
         if client is None:
             if not isinstance(max_retries, int):
                 raise OpenAIError(
@@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
                     organization=organization,
                 )
             else:
-
                 _new_client = OpenAI(
                     api_key=api_key,
                     base_url=api_base,
@@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
             return _new_client

         else:
+            self._set_dynamic_params_on_client(
+                client=client,
+                organization=organization,
+                max_retries=max_retries,
+            )
             return client

     @track_llm_api_timing()

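In practice this is the path hit when a pre-built client is reused across calls. A hypothetical usage; it assumes the completion call forwards the client to the branch shown above:

    # Hypothetical usage; assumes litellm forwards the pre-built client as shown above.
    import litellm
    from openai import OpenAI

    shared_client = OpenAI(api_key="sk-...", max_retries=0)

    response = litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
        client=shared_client,    # reused client; organization/max_retries now get applied to it
        max_retries=3,
    )
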
@@ -20,3 +20,23 @@ class PerplexityChatConfig(OpenAIGPTConfig):
             or get_secret_str("PERPLEXITY_API_KEY")
         )
         return api_base, dynamic_api_key
+
+    def get_supported_openai_params(self, model: str) -> list:
+        """
+        Perplexity supports a subset of OpenAI params
+
+        Ref: https://docs.perplexity.ai/api-reference/chat-completions
+
+        Eg. Perplexity does not support tools, tool_choice, function_call, functions, etc.
+        """
+        return [
+            "frequency_penalty",
+            "max_tokens",
+            "max_completion_tokens",
+            "presence_penalty",
+            "response_format",
+            "stream",
+            "temperature",
+            "top_p" "max_retries",
+            "extra_headers",
+        ]

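One entry in the new list is worth flagging: "top_p" "max_retries", has no comma between the literals, so Python concatenates them into the single string "top_pmax_retries" rather than two separate params. A quick demonstration:

    params = [
        "temperature",
        "top_p" "max_retries",   # adjacent string literals concatenate
        "extra_headers",
    ]
    print(params)                # ['temperature', 'top_pmax_retries', 'extra_headers']
    print("top_p" in params)     # False
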
@@ -58,7 +58,7 @@ class PetalsConfig(BaseConfig):
         top_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -59,7 +59,7 @@ class PredibaseConfig(BaseConfig):
         typical_p: Optional[float] = None,
         watermark: Optional[bool] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -73,7 +73,7 @@ class ReplicateConfig(BaseConfig):
         seed: Optional[int] = None,
         debug: Optional[bool] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -47,7 +47,7 @@ class SagemakerConfig(BaseConfig):
         temperature: Optional[float] = None,
         return_full_text: Optional[bool] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
         return (
             api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
         )
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model

@@ -179,7 +179,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         presence_penalty: Optional[float] = None,
         seed: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -17,7 +17,7 @@ class VertexAIAi21Config:
         self,
         max_tokens: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

@@ -1,10 +1,10 @@
 import types
 from typing import Optional

-import litellm
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig


-class VertexAILlama3Config:
+class VertexAILlama3Config(OpenAIGPTConfig):
     """
     Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming

@@ -21,7 +21,7 @@ class VertexAILlama3Config:
         self,
         max_tokens: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key == "max_tokens" and value is None:
                 value = self.max_tokens
@@ -46,8 +46,13 @@ class VertexAILlama3Config:
             and v is not None
         }

-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+    def get_supported_openai_params(self, model: str):
+        supported_params = super().get_supported_openai_params(model=model)
+        try:
+            supported_params.remove("max_retries")
+        except KeyError:
+            pass
+        return supported_params

     def map_openai_params(
         self,
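get_supported_openai_params returns a list, and list.remove raises ValueError rather than KeyError when the element is missing, so the except KeyError guard above would not catch a miss. A sketch of the same intent with the matching exception type:

    supported_params = ["temperature", "top_p", "max_retries"]

    try:
        supported_params.remove("max_retries")
    except ValueError:      # list.remove raises ValueError when the item is absent
        pass

    print(supported_params)  # ['temperature', 'top_p']
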
@@ -60,7 +65,7 @@ class VertexAILlama3Config:
             non_default_params["max_tokens"] = non_default_params.pop(
                 "max_completion_tokens"
             )
-        return litellm.OpenAIConfig().map_openai_params(
+        return super().map_openai_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,

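With VertexAILlama3Config now subclassing OpenAIGPTConfig, param mapping is delegated to the parent via super() instead of instantiating litellm.OpenAIConfig(); only the max_completion_tokens -> max_tokens rename stays local. A reduced sketch of that shape; both classes here are illustrative stand-ins:

    class ParentConfig:
        def map_openai_params(self, non_default_params: dict, optional_params: dict, model: str) -> dict:
            optional_params.update(non_default_params)
            return optional_params


    class Llama3LikeConfig(ParentConfig):
        def map_openai_params(self, non_default_params: dict, optional_params: dict, model: str) -> dict:
            if "max_completion_tokens" in non_default_params:
                non_default_params["max_tokens"] = non_default_params.pop("max_completion_tokens")
            # delegate everything else to the shared OpenAI-style mapping
            return super().map_openai_params(non_default_params, optional_params, model)


    print(Llama3LikeConfig().map_openai_params({"max_completion_tokens": 256}, {}, "llama3"))
    # {'max_tokens': 256}
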
@@ -48,7 +48,7 @@ class VertexAITextEmbeddingConfig(BaseModel):
         ] = None,
         title: Optional[str] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

Some files were not shown because too many files have changed in this diff.