Merge branch 'main' into CakeCrusher/developer_mode

Sebastian Sosa, 2025-02-15 14:55:10 -05:00 (committed by GitHub)
commit 779179b4da
No known key found for this signature in database (GPG key ID: B5690EEEBB952194)
317 changed files with 11086 additions and 3134 deletions

View file

@ -72,6 +72,7 @@ jobs:
pip install "jsonschema==4.22.0"
pip install "pytest-xdist==3.6.1"
pip install "websockets==10.4"
pip uninstall posthog -y
- save_cache:
paths:
- ./venv
@ -1517,6 +1518,117 @@ jobs:
- store_test_results:
path: test-results
proxy_multi_instance_tests:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
working_directory: ~/project
steps:
- checkout
- run:
name: Install Docker CLI (In case it's not already installed)
command: |
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io
- run:
name: Install Python 3.9
command: |
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
conda init bash
source ~/.bashrc
conda create -n myenv python=3.9 -y
conda activate myenv
python --version
- run:
name: Install Dependencies
command: |
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
- run:
name: Run Docker container 1
# intentionally give bad redis credentials here
# the OTEL test should pick this up as a trace
command: |
docker run -d \
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e LITELLM_MASTER_KEY="sk-1234" \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e USE_DDTRACE=True \
-e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
my-app:latest \
--config /app/config.yaml \
--port 4000 \
--detailed_debug
- run:
name: Run Docker container 2
command: |
docker run -d \
-p 4001:4001 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e LITELLM_MASTER_KEY="sk-1234" \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e USE_DDTRACE=True \
-e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \
--name my-app-2 \
-v $(pwd)/litellm/proxy/example_config_yaml/multi_instance_simple_config.yaml:/app/config.yaml \
my-app:latest \
--config /app/config.yaml \
--port 4001 \
--detailed_debug
- run:
name: Install curl and dockerize
command: |
sudo apt-get update
sudo apt-get install -y curl
sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
- run:
name: Start outputting logs
command: docker logs -f my-app
background: true
- run:
name: Wait for instance 1 to be ready
command: dockerize -wait http://localhost:4000 -timeout 5m
- run:
name: Wait for instance 2 to be ready
command: dockerize -wait http://localhost:4001 -timeout 5m
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/multi_instance_e2e_tests -x --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
# Clean up first container
# Store test results
- store_test_results:
path: test-results
proxy_store_model_in_db_tests:
machine:
image: ubuntu-2204:2023.10.1
@ -1552,6 +1664,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-mock==3.12.0"
pip install "pytest-asyncio==0.21.1"
pip install "assemblyai==0.37.0"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@ -1904,7 +2017,7 @@ jobs:
circleci step halt
fi
- run:
name: Trigger Github Action for new Docker Container + Trigger Stable Release Testing
name: Trigger Github Action for new Docker Container + Trigger Load Testing
command: |
echo "Install TOML package."
python3 -m pip install toml
@ -1914,9 +2027,9 @@ jobs:
-H "Accept: application/vnd.github.v3+json" \
-H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
echo "triggering stable release server for version ${VERSION} and commit ${CIRCLE_SHA1}"
curl -X POST "https://proxyloadtester-production.up.railway.app/start/load/test?version=${VERSION}&commit_hash=${CIRCLE_SHA1}"
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}-nightly\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
echo "triggering load testing server for version ${VERSION} and commit ${CIRCLE_SHA1}"
curl -X POST "https://proxyloadtester-production.up.railway.app/start/load/test?version=${VERSION}&commit_hash=${CIRCLE_SHA1}&release_type=nightly"
e2e_ui_testing:
machine:
@ -2171,6 +2284,12 @@ workflows:
only:
- main
- /litellm_.*/
- proxy_multi_instance_tests:
filters:
branches:
only:
- main
- /litellm_.*/
- proxy_store_model_in_db_tests:
filters:
branches:
@ -2302,6 +2421,7 @@ workflows:
- installing_litellm_on_python
- installing_litellm_on_python_3_13
- proxy_logging_guardrails_model_info_tests
- proxy_multi_instance_tests
- proxy_store_model_in_db_tests
- proxy_build_from_pip_tests
- proxy_pass_through_endpoint_tests

View file

@ -52,6 +52,39 @@ def interpret_results(csv_file):
return markdown_table
def _get_docker_run_command_stable_release(release_version):
return f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm_stable_release_branch-{release_version}
```
"""
def _get_docker_run_command(release_version):
return f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
def get_docker_run_command(release_version):
if "stable" in release_version:
return _get_docker_run_command_stable_release(release_version)
else:
return _get_docker_run_command(release_version)
if __name__ == "__main__":
csv_file = "load_test_stats.csv" # Change this to the path of your CSV file
markdown_table = interpret_results(csv_file)
@ -79,17 +112,7 @@ if __name__ == "__main__":
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
existing_release_body = latest_release.body[:start_index]
docker_run_command = f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
docker_run_command = get_docker_run_command(release_version)
print("docker run command: ", docker_run_command)
new_release_body = (

View file

@ -0,0 +1,172 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "4FbDOmcj2VkM"
},
"source": [
"## Use LiteLLM with Arize\n",
"https://docs.litellm.ai/docs/observability/arize_integration\n",
"\n",
"This method uses the litellm proxy to send the data to Arize. The callback is set in the litellm config below, instead of using OpenInference tracing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "21W8Woog26Ns"
},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "xrjKLBxhxu2L"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: litellm in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (1.54.1)\n",
"Requirement already satisfied: aiohttp in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (3.11.10)\n",
"Requirement already satisfied: click in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (8.1.7)\n",
"Requirement already satisfied: httpx<0.28.0,>=0.23.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.27.2)\n",
"Requirement already satisfied: importlib-metadata>=6.8.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (8.5.0)\n",
"Requirement already satisfied: jinja2<4.0.0,>=3.1.2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (3.1.4)\n",
"Requirement already satisfied: jsonschema<5.0.0,>=4.22.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (4.23.0)\n",
"Requirement already satisfied: openai>=1.55.3 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (1.57.1)\n",
"Requirement already satisfied: pydantic<3.0.0,>=2.0.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (2.10.3)\n",
"Requirement already satisfied: python-dotenv>=0.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (1.0.1)\n",
"Requirement already satisfied: requests<3.0.0,>=2.31.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (2.32.3)\n",
"Requirement already satisfied: tiktoken>=0.7.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.7.0)\n",
"Requirement already satisfied: tokenizers in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from litellm) (0.21.0)\n",
"Requirement already satisfied: anyio in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (4.7.0)\n",
"Requirement already satisfied: certifi in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (2024.8.30)\n",
"Requirement already satisfied: httpcore==1.* in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (1.0.7)\n",
"Requirement already satisfied: idna in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (3.10)\n",
"Requirement already satisfied: sniffio in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpx<0.28.0,>=0.23.0->litellm) (1.3.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.23.0->litellm) (0.14.0)\n",
"Requirement already satisfied: zipp>=3.20 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from importlib-metadata>=6.8.0->litellm) (3.21.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jinja2<4.0.0,>=3.1.2->litellm) (3.0.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (24.2.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (2024.10.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (0.35.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from jsonschema<5.0.0,>=4.22.0->litellm) (0.22.3)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (1.9.0)\n",
"Requirement already satisfied: jiter<1,>=0.4.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (0.6.1)\n",
"Requirement already satisfied: tqdm>4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (4.67.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.11 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from openai>=1.55.3->litellm) (4.12.2)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from pydantic<3.0.0,>=2.0.0->litellm) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.27.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from pydantic<3.0.0,>=2.0.0->litellm) (2.27.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.31.0->litellm) (3.4.0)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.31.0->litellm) (2.0.7)\n",
"Requirement already satisfied: regex>=2022.1.18 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from tiktoken>=0.7.0->litellm) (2024.11.6)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (2.4.4)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.3.1)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.5.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (6.1.0)\n",
"Requirement already satisfied: propcache>=0.2.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (0.2.1)\n",
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from aiohttp->litellm) (1.18.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from tokenizers->litellm) (0.26.5)\n",
"Requirement already satisfied: filelock in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (3.16.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (2024.10.0)\n",
"Requirement already satisfied: packaging>=20.9 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (24.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Users/ericxiao/Documents/arize/.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers->litellm) (6.0.2)\n"
]
}
],
"source": [
"!pip install litellm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jHEu-TjZ29PJ"
},
"source": [
"## Set Env Variables"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "QWd9rTysxsWO"
},
"outputs": [],
"source": [
"import litellm\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"ARIZE_SPACE_KEY\"] = getpass(\"Enter your Arize space key: \")\n",
"os.environ[\"ARIZE_API_KEY\"] = getpass(\"Enter your Arize API key: \")\n",
"os.environ['OPENAI_API_KEY']= getpass(\"Enter your OpenAI API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run a completion call and see the traces in Arize"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello! Nice to meet you, OpenAI. How can I assist you today?\n"
]
}
],
"source": [
"# set arize as a callback, litellm will send the data to arize\n",
"litellm.callbacks = [\"arize\"]\n",
" \n",
"# openai call\n",
"response = litellm.completion(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Hi 👋 - i'm openai\"}\n",
" ]\n",
")\n",
"print(response.choices[0].message.content)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LLM Ops Stack - LiteLLM Proxy + Langfuse \n",
"\n",
"This notebook demonstrates how to use LiteLLM Proxy with Langfuse \n",
"- Use LiteLLM Proxy for calling 100+ LLMs in OpenAI format\n",
"- Use Langfuse for viewing request / response traces \n",
"\n",
"\n",
"In this notebook we will setup LiteLLM Proxy to make requests to OpenAI, Anthropic, Bedrock and automatically log traces to Langfuse."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup LiteLLM Proxy\n",
"\n",
"### 1.1 Define .env variables \n",
"Define .env variables on the container that litellm proxy is running on.\n",
"```bash\n",
"## LLM API Keys\n",
"OPENAI_API_KEY=sk-proj-1234567890\n",
"ANTHROPIC_API_KEY=sk-ant-api03-1234567890\n",
"AWS_ACCESS_KEY_ID=1234567890\n",
"AWS_SECRET_ACCESS_KEY=1234567890\n",
"\n",
"## Langfuse Logging \n",
"LANGFUSE_PUBLIC_KEY=\"pk-lf-xxxx9\"\n",
"LANGFUSE_SECRET_KEY=\"sk-lf-xxxx9\"\n",
"LANGFUSE_HOST=\"https://us.cloud.langfuse.com\"\n",
"```\n",
"\n",
"\n",
"### 1.1 Setup LiteLLM Proxy Config yaml \n",
"```yaml\n",
"model_list:\n",
" - model_name: gpt-4o\n",
" litellm_params:\n",
" model: openai/gpt-4o\n",
" api_key: os.environ/OPENAI_API_KEY\n",
" - model_name: claude-3-5-sonnet-20241022\n",
" litellm_params:\n",
" model: anthropic/claude-3-5-sonnet-20241022\n",
" api_key: os.environ/ANTHROPIC_API_KEY\n",
" - model_name: us.amazon.nova-micro-v1:0\n",
" litellm_params:\n",
" model: bedrock/us.amazon.nova-micro-v1:0\n",
" aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID\n",
" aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY\n",
"\n",
"litellm_settings:\n",
" callbacks: [\"langfuse\"]\n",
"\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Make LLM Requests to LiteLLM Proxy\n",
"\n",
"Now we will make our first LLM request to LiteLLM Proxy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Setup Client Side Variables to point to LiteLLM Proxy\n",
"Set `LITELLM_PROXY_BASE_URL` to the base url of the LiteLLM Proxy and `LITELLM_VIRTUAL_KEY` to the virtual key you want to use for Authentication to LiteLLM Proxy. (Note: In this initial setup you can)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"\n",
"LITELLM_PROXY_BASE_URL=\"http://0.0.0.0:4000\"\n",
"LITELLM_VIRTUAL_KEY=\"sk-oXXRa1xxxxxxxxxxx\""
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletion(id='chatcmpl-B0sq6QkOKNMJ0dwP3x7OoMqk1jZcI', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Langfuse is a platform designed to monitor, observe, and troubleshoot AI and large language model (LLM) applications. It provides features that help developers gain insights into how their AI systems are performing, make debugging easier, and optimize the deployment of models. Langfuse allows for tracking of model interactions, collecting telemetry, and visualizing data, which is crucial for understanding the behavior of AI models in production environments. This kind of tool is particularly useful for developers working with language models who need to ensure reliability and efficiency in their applications.', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None))], created=1739550502, model='gpt-4o-2024-08-06', object='chat.completion', service_tier='default', system_fingerprint='fp_523b9b6e5f', usage=CompletionUsage(completion_tokens=109, prompt_tokens=13, total_tokens=122, completion_tokens_details=CompletionTokensDetails(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"client = openai.OpenAI(\n",
" api_key=LITELLM_VIRTUAL_KEY,\n",
" base_url=LITELLM_PROXY_BASE_URL\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"what is Langfuse?\"\n",
" }\n",
" ],\n",
")\n",
"\n",
"response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 View Traces on Langfuse\n",
"LiteLLM will send the request / response, model, tokens (input + output), cost to Langfuse.\n",
"\n",
"![image_description](litellm_proxy_langfuse.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.4 Call Anthropic, Bedrock models \n",
"\n",
"Now we can call `us.amazon.nova-micro-v1:0` and `claude-3-5-sonnet-20241022` models defined on your config.yaml both in the OpenAI request / response format."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletion(id='chatcmpl-7756e509-e61f-4f5e-b5ae-b7a41013522a', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=\"Langfuse is an observability tool designed specifically for machine learning models and applications built with natural language processing (NLP) and large language models (LLMs). It focuses on providing detailed insights into how these models perform in real-world scenarios. Here are some key features and purposes of Langfuse:\\n\\n1. **Real-time Monitoring**: Langfuse allows developers to monitor the performance of their NLP and LLM applications in real time. This includes tracking the inputs and outputs of the models, as well as any errors or issues that arise during operation.\\n\\n2. **Error Tracking**: It helps in identifying and tracking errors in the models' outputs. By analyzing incorrect or unexpected responses, developers can pinpoint where and why errors occur, facilitating more effective debugging and improvement.\\n\\n3. **Performance Metrics**: Langfuse provides various performance metrics, such as latency, throughput, and error rates. These metrics help developers understand how well their models are performing under different conditions and workloads.\\n\\n4. **Traceability**: It offers detailed traceability of requests and responses, allowing developers to follow the path of a request through the system and see how it is processed by the model at each step.\\n\\n5. **User Feedback Integration**: Langfuse can integrate user feedback to provide context for model outputs. This helps in understanding how real users are interacting with the model and how its outputs align with user expectations.\\n\\n6. **Customizable Dashboards**: Users can create custom dashboards to visualize the data collected by Langfuse. These dashboards can be tailored to highlight the most important metrics and insights for a specific application or team.\\n\\n7. **Alerting and Notifications**: It can set up alerts for specific conditions or errors, notifying developers when something goes wrong or when performance metrics fall outside of acceptable ranges.\\n\\nBy providing comprehensive observability for NLP and LLM applications, Langfuse helps developers to build more reliable, accurate, and user-friendly models and services.\", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None))], created=1739554005, model='us.amazon.nova-micro-v1:0', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=380, prompt_tokens=5, total_tokens=385, completion_tokens_details=None, prompt_tokens_details=None))"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"client = openai.OpenAI(\n",
" api_key=LITELLM_VIRTUAL_KEY,\n",
" base_url=LITELLM_PROXY_BASE_URL\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"us.amazon.nova-micro-v1:0\",\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"what is Langfuse?\"\n",
" }\n",
" ],\n",
")\n",
"\n",
"response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Advanced - Set Langfuse Trace ID, Tags, Metadata \n",
"\n",
"Here is an example of how you can set Langfuse specific params on your client side request. See full list of supported langfuse params [here](https://docs.litellm.ai/docs/observability/langfuse_integration)\n",
"\n",
"You can view the logged trace of this request [here](https://us.cloud.langfuse.com/project/clvlhdfat0007vwb74m9lvfvi/traces/567890?timestamp=2025-02-14T17%3A30%3A26.709Z)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletion(id='chatcmpl-789babd5-c064-4939-9093-46e4cd2e208a', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=\"Langfuse is an observability platform designed specifically for monitoring and improving the performance of natural language processing (NLP) models and applications. It provides developers with tools to track, analyze, and optimize how their language models interact with users and handle natural language inputs.\\n\\nHere are some key features and benefits of Langfuse:\\n\\n1. **Real-Time Monitoring**: Langfuse allows developers to monitor their NLP applications in real time. This includes tracking user interactions, model responses, and overall performance metrics.\\n\\n2. **Error Tracking**: It helps in identifying and tracking errors in the model's responses. This can include incorrect, irrelevant, or unsafe outputs.\\n\\n3. **User Feedback Integration**: Langfuse enables the collection of user feedback directly within the platform. This feedback can be used to identify areas for improvement in the model's performance.\\n\\n4. **Performance Metrics**: The platform provides detailed metrics and analytics on model performance, including latency, throughput, and accuracy.\\n\\n5. **Alerts and Notifications**: Developers can set up alerts to notify them of any significant issues or anomalies in model performance.\\n\\n6. **Debugging Tools**: Langfuse offers tools to help developers debug and refine their models by providing insights into how the model processes different types of inputs.\\n\\n7. **Integration with Development Workflows**: It integrates seamlessly with various development environments and CI/CD pipelines, making it easier to incorporate observability into the development process.\\n\\n8. **Customizable Dashboards**: Users can create custom dashboards to visualize the data in a way that best suits their needs.\\n\\nLangfuse aims to help developers build more reliable, accurate, and user-friendly NLP applications by providing them with the tools to observe and improve how their models perform in real-world scenarios.\", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None))], created=1739554281, model='us.amazon.nova-micro-v1:0', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=346, prompt_tokens=5, total_tokens=351, completion_tokens_details=None, prompt_tokens_details=None))"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"client = openai.OpenAI(\n",
" api_key=LITELLM_VIRTUAL_KEY,\n",
" base_url=LITELLM_PROXY_BASE_URL\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"us.amazon.nova-micro-v1:0\",\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"what is Langfuse?\"\n",
" }\n",
" ],\n",
" extra_body={\n",
" \"metadata\": {\n",
" \"generation_id\": \"1234567890\",\n",
" \"trace_id\": \"567890\",\n",
" \"trace_user_id\": \"user_1234567890\",\n",
" \"tags\": [\"tag1\", \"tag2\"]\n",
" }\n",
" }\n",
")\n",
"\n",
"response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## "
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown (new image file, 308 KiB)

View file

@ -11,9 +11,7 @@ FROM $LITELLM_BUILD_IMAGE AS builder
WORKDIR /app
# Install build dependencies
RUN apk update && \
apk add --no-cache gcc python3-dev musl-dev && \
rm -rf /var/cache/apk/*
RUN apk add --no-cache gcc python3-dev musl-dev
RUN pip install --upgrade pip && \
pip install build

View file

@ -1,5 +1,5 @@
# Local Debugging
There's 2 ways to do local debugging - `litellm.set_verbose=True` and by passing in a custom function `completion(...logger_fn=<your_local_function>)`. Warning: Make sure to not use `set_verbose` in production. It logs API keys, which might end up in log files.
There are two ways to do local debugging - `litellm._turn_on_debug()` and passing in a custom function via `completion(...logger_fn=<your_local_function>)`. Warning: make sure not to use `_turn_on_debug()` in production, as it logs API keys, which might end up in log files.
## Set Verbose
@ -8,7 +8,7 @@ This is good for getting print statements for everything litellm is doing.
import litellm
from litellm import completion
litellm.set_verbose=True # 👈 this is the 1-line change you need to make
litellm._turn_on_debug() # 👈 this is the 1-line change you need to make
## set ENV variables
os.environ["OPENAI_API_KEY"] = "openai key"
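The second approach, passing a custom `logger_fn`, is not shown in the snippet above. A minimal sketch follows, assuming the callback receives the model call details as a dict (`my_custom_logger` is an illustrative name, not part of the original doc):

```python
import os
import litellm
from litellm import completion

os.environ["OPENAI_API_KEY"] = "openai key"

def my_custom_logger(model_call_dict):
    # Do your own logging here instead of relying on debug output.
    # Be careful not to print secrets in production.
    print("model call details:", model_call_dict)

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    logger_fn=my_custom_logger,  # 👈 the custom function mentioned above
)
```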

View file

@ -19,6 +19,7 @@ Make an account on [Arize AI](https://app.arize.com/auth/login)
## Quick Start
Use just 2 lines of code to instantly log your responses **across all providers** with Arize
You can also use the instrumentor option instead of the callback, which you can find [here](https://docs.arize.com/arize/llm-tracing/tracing-integrations-auto/litellm).
```python
litellm.callbacks = ["arize"]
@ -28,7 +29,7 @@ import litellm
import os
os.environ["ARIZE_SPACE_KEY"] = ""
os.environ["ARIZE_API_KEY"] = "" # defaults to litellm-completion
os.environ["ARIZE_API_KEY"] = ""
# LLM API Keys
os.environ['OPENAI_API_KEY']=""

View file

@ -78,7 +78,7 @@ Following are the allowed fields in metadata, their types, and their description
* `context: Optional[Union[dict, str]]` - This is the context used as information for the prompt. For RAG applications, this is the "retrieved" data. You may log context as a string or as an object (dictionary).
* `expected_response: Optional[str]` - This is the reference response to compare against for evaluation purposes. This is useful for segmenting inference calls by expected response.
* `user_query: Optional[str]` - This is the user's query. For conversational applications, this is the user's last message.
* `custom_attributes: Optional[dict]` - This is a dictionary of custom attributes, useful for attaching any additional information about the inference (see the sketch below).
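These metadata fields are passed via the `metadata` parameter on the completion call. A hedged sketch, assuming the standard `athina` success callback for this integration (all field values are placeholders):

```python
import litellm

# Assumption: Athina logging is enabled through the "athina" success callback.
litellm.success_callback = ["athina"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    metadata={
        "user_query": "What is the capital of France?",        # the user's last message
        "context": "Paris is the capital and largest city of France.",  # retrieved data for RAG
        "expected_response": "Paris",                           # reference response for evals
        "custom_attributes": {"experiment": "rag-v2"},          # any additional info about the inference
    },
)
```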
## Using a self hosted deployment of Athina

View file

@ -0,0 +1,75 @@
import Image from '@theme/IdealImage';
# Phoenix OSS
Open source tracing and evaluation platform
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
## Pre-Requisites
Make an account on [Phoenix OSS](https://phoenix.arize.com)
OR self-host your own instance of [Phoenix](https://docs.arize.com/phoenix/deployment)
## Quick Start
Use just 2 lines of code to instantly log your responses **across all providers** with Phoenix
You can also use the instrumentor option instead of the callback, which you can find [here](https://docs.arize.com/phoenix/tracing/integrations-tracing/litellm).
```python
litellm.callbacks = ["arize_phoenix"]
```
```python
import litellm
import os
os.environ["PHOENIX_API_KEY"] = "" # Necessary only when using Phoenix Cloud
os.environ["PHOENIX_COLLECTOR_HTTP_ENDPOINT"] = "" # The URL of your Phoenix OSS instance
# This defaults to https://app.phoenix.arize.com/v1/traces for Phoenix Cloud
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set arize_phoenix as a callback, litellm will send the data to Phoenix
litellm.callbacks = ["arize_phoenix"]
# openai call
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
]
)
```
### Using with LiteLLM Proxy
```yaml
model_list:
- model_name: gpt-4o
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
callbacks: ["arize_phoenix"]
environment_variables:
PHOENIX_API_KEY: "d0*****"
PHOENIX_COLLECTOR_ENDPOINT: "https://app.phoenix.arize.com/v1/traces" # OPTIONAL, for setting the GRPC endpoint
PHOENIX_COLLECTOR_HTTP_ENDPOINT: "https://app.phoenix.arize.com/v1/traces" # OPTIONAL, for setting the HTTP endpoint
```
## Support & Talk to Founders
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai

View file

@ -12,6 +12,9 @@ Supports **ALL** Assembly AI Endpoints
[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference)
<iframe width="840" height="500" src="https://www.loom.com/embed/aac3f4d74592448992254bfa79b9f62d?sid=267cd0ab-d92b-42fa-b97a-9f385ef8930c" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
## Quick Start
Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts)
@ -35,6 +38,8 @@ litellm
Let's call the Assembly AI `/v2/transcripts` endpoint
```python
import assemblyai as aai
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai
@ -53,3 +58,28 @@ print(transcript)
print(transcript.id)
```
## Calling Assembly AI EU endpoints
If you want to send your request to the Assembly AI EU endpoint, you can do so by setting the `LITELLM_PROXY_BASE_URL` to `<your-proxy-base-url>/eu.assemblyai`
```python
import assemblyai as aai
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/eu.assemblyai" # <your-proxy-base-url>/eu.assemblyai
aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}"
aai.settings.base_url = LITELLM_PROXY_BASE_URL
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
```

View file

@ -987,6 +987,106 @@ curl http://0.0.0.0:4000/v1/chat/completions \
</TabItem>
</Tabs>
## [BETA] Citations API
Pass `citations: {"enabled": true}` to Anthropic, to get citations on your document responses.
Note: This interface is in BETA. If you have feedback on how citations should be returned, please [tell us here](https://github.com/BerriAI/litellm/issues/7970#issuecomment-2644437943)
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
resp = completion(
model="claude-3-5-sonnet-20241022",
messages=[
{
"role": "user",
"content": [
{
"type": "document",
"source": {
"type": "text",
"media_type": "text/plain",
"data": "The grass is green. The sky is blue.",
},
"title": "My Document",
"context": "This is a trustworthy document.",
"citations": {"enabled": True},
},
{
"type": "text",
"text": "What color is the grass and sky?",
},
],
}
],
)
citations = resp.choices[0].message.provider_specific_fields["citations"]
assert citations is not None
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "anthropic-claude",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "document",
          "source": {
            "type": "text",
            "media_type": "text/plain",
            "data": "The grass is green. The sky is blue."
          },
          "title": "My Document",
          "context": "This is a trustworthy document.",
          "citations": {"enabled": true}
        },
        {
          "type": "text",
          "text": "What color is the grass and sky?"
        }
      ]
    }
  ]
}'
```
</TabItem>
</Tabs>
## Usage - passing 'user_id' to Anthropic
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
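A minimal sketch of this mapping (the `user` value and API key are placeholders):

```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder

response = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hi 👋"}],
    user="user_123",  # LiteLLM forwards this to Anthropic as metadata[user_id]
)
```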

View file

@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
| Property | Details |
|-------|-------|
| Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@ -1277,13 +1277,83 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
https://some-api-url/models
```
## Bedrock Imported Models (Deepseek)
## Bedrock Imported Models (Deepseek, Deepseek R1)
### Deepseek R1
This is a separate route, as the chat template is different.
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/deepseek_r1/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/deepseek_r1/{your-model-arn}
messages=[{"role": "user", "content": "Tell me a joke"}],
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: DeepSeek-R1-Distill-Llama-70B
litellm_params:
model: bedrock/deepseek_r1/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!** (the `model` value below is the `model_name` from your config)
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
  "model": "DeepSeek-R1-Distill-Llama-70B",
  "messages": [
    {
      "role": "user",
      "content": "what llm are you"
    }
  ]
}'
```
</TabItem>
</Tabs>
### Deepseek (not R1)
| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
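Mirroring the R1 SDK example above, a hedged sketch for this route (the ARN is a placeholder for your own imported model, and AWS credentials are assumed to be set in the environment):

```python
from litellm import completion

response = completion(
    model="bedrock/llama/arn:aws:bedrock:us-east-1:000000000000:imported-model/your-model-id",  # bedrock/llama/{your-model-arn}
    messages=[{"role": "user", "content": "Tell me a joke"}],
)
```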

View file

@ -688,7 +688,9 @@ response = litellm.completion(
|-----------------------|--------------------------------------------------------|--------------------------------|
| gemini-pro | `completion(model='gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion(model='gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion(model='gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash | `completion(model='gemini/gemini-2.0-flash', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-exp | `completion(model='gemini/gemini-2.0-flash-exp', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-2.0-flash-lite-preview-02-05 | `completion(model='gemini/gemini-2.0-flash-lite-preview-02-05', messages)` | `os.environ['GEMINI_API_KEY']` |

View file

@ -64,71 +64,7 @@ All models listed here https://docs.perplexity.ai/docs/model-cards are supported
## Return citations
Perplexity supports returning citations via `return_citations=True`. [Perplexity Docs](https://docs.perplexity.ai/reference/post_chat_completions). Note: Perplexity has this feature in **closed beta**, so you need them to grant you access to get citations from their API.
If perplexity returns citations, LiteLLM will pass it straight through.
:::info
For passing more provider-specific, [go here](../completion/provider_specific_params.md)
For more information about passing provider-specific parameters, [go here](../completion/provider_specific_params.md)
:::
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ['PERPLEXITYAI_API_KEY'] = ""
response = completion(
model="perplexity/mistral-7b-instruct",
messages=messages,
return_citations=True
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add perplexity to config.yaml
```yaml
model_list:
- model_name: "perplexity-model"
litellm_params:
model: "llama-3.1-sonar-small-128k-online"
api_key: os.environ/PERPLEXITY_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "perplexity-model",
"messages": [
{
"role": "user",
"content": "Who won the world cup in 2022?"
}
],
"return_citations": true
}'
```
[**Call w/ OpenAI SDK, Langchain, Instructor, etc.**](../proxy/user_keys.md#chatcompletions)
</TabItem>
</Tabs>

View file

@ -488,12 +488,12 @@ router_settings:
| SLACK_DAILY_REPORT_FREQUENCY | Frequency of daily Slack reports (e.g., daily, weekly)
| SLACK_WEBHOOK_URL | Webhook URL for Slack integration
| SMTP_HOST | Hostname for the SMTP server
| SMTP_PASSWORD | Password for SMTP authentication
| SMTP_PASSWORD | Password for SMTP authentication (do not set if SMTP does not require auth)
| SMTP_PORT | Port number for SMTP server
| SMTP_SENDER_EMAIL | Email address used as the sender in SMTP transactions
| SMTP_SENDER_LOGO | Logo used in emails sent via SMTP
| SMTP_TLS | Flag to enable or disable TLS for SMTP connections
| SMTP_USERNAME | Username for SMTP authentication
| SMTP_USERNAME | Username for SMTP authentication (do not set if SMTP does not require auth)
| SPEND_LOGS_URL | URL for retrieving spend logs
| SSL_CERTIFICATE | Path to the SSL certificate file
| SSL_VERIFY | Flag to enable or disable SSL certificate verification

View file

@ -37,7 +37,7 @@ guardrails:
- guardrail_name: aim-protected-app
litellm_params:
guardrail: aim
mode: pre_call
mode: pre_call # 'during_call' is also available
api_key: os.environ/AIM_API_KEY
api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
```

View file

@ -166,7 +166,7 @@ response = client.chat.completions.create(
{"role": "user", "content": "what color is red"}
],
logit_bias={12481: 100},
timeout=1
extra_body={"timeout": 1} # 👈 KEY CHANGE
)
print(response)

View file

@ -163,10 +163,12 @@ scope: "litellm-proxy-admin ..."
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
team_ids_jwt_field: "groups"
user_id_upsert: true # add user_id to the db if they don't exist
enforce_team_based_model_access: true # don't allow users to access models unless the team has access
```
This is assuming your token looks like this:
@ -352,11 +354,11 @@ environment_variables:
### Example Token
```
```bash
{
"aud": "api://LiteLLM_Proxy",
"oid": "eec236bd-0135-4b28-9354-8fc4032d543e",
"roles": ["litellm.api.consumer"]
"roles": ["litellm.api.consumer"]
}
```
@ -370,4 +372,68 @@ Supported internal roles:
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
## [BETA] Control Model Access with Scopes
Control which models a JWT can access. Set `enforce_scope_based_access: true` to enforce scope-based access control.
### 1. Setup config.yaml with scope mappings.
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-5-sonnet
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: gpt-3.5-turbo-testing
litellm_params:
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id" # 👈 set the field in the JWT token that contains the team id
team_id_upsert: true # 👈 upsert the team to db, if team id is not found in db
scope_mappings:
- scope: litellm.api.consumer
models: ["anthropic-claude"]
- scope: litellm.api.gpt_3_5_turbo
models: ["gpt-3.5-turbo-testing"]
enforce_scope_based_access: true # 👈 enforce scope-based access control
enforce_rbac: true # 👈 enforces only a Team/User/ProxyAdmin can access the proxy.
```
#### Scope Mapping Spec
- `scope`: The scope to be used for the JWT token.
- `models`: The models that the JWT token can access. Value is the `model_name` in `model_list`. Note: Wildcard routes are not currently supported.
### 2. Create a JWT with the correct scopes.
Expected Token:
```bash
{
"scope": ["litellm.api.consumer", "litellm.api.gpt_3_5_turbo"] # can be a list or a space-separated string
}
```
### 3. Test the flow.
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer eyJhbGci...' \
-d '{
"model": "gpt-3.5-turbo-testing",
"messages": [
{
"role": "user",
"content": "Hey, how'\''s it going 1234?"
}
]
}'
```

View file

@ -52,6 +52,7 @@ from litellm.constants import (
open_ai_embedding_models,
cohere_embedding_models,
bedrock_embedding_models,
known_tokenizer_config,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (
@ -360,7 +361,15 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-2-90b-instruct-v1:0",
]
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
"cohere",
"anthropic",
"mistral",
"amazon",
"meta",
"llama",
"ai21",
"nova",
"deepseek_r1",
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
@ -863,6 +872,9 @@ from .llms.bedrock.common_utils import (
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
AmazonAI21Config,
)
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
AmazonInvokeNovaConfig,
)
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
AmazonAnthropicConfig,
)

View file

@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
)
verbose_logger.debug(
"init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
"init_redis_cluster: startup nodes are being initialized."
)
from redis.cluster import ClusterNode
@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides) -> async_redis.Redis:
def get_redis_async_client(
**env_overrides,
) -> async_redis.Redis:
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)

View file

@ -4,5 +4,6 @@ from .dual_cache import DualCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache

View file

@ -41,6 +41,7 @@ from .dual_cache import DualCache # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
@ -158,14 +159,23 @@ class Cache:
None. Cache is set as a litellm param
"""
if type == LiteLLMCacheType.REDIS:
self.cache: BaseCache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
if redis_startup_nodes:
self.cache: BaseCache = RedisClusterCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
else:
self.cache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
**kwargs,
)
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,

View file

@ -14,7 +14,7 @@ import inspect
import json
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -26,15 +26,20 @@ from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
from redis.asyncio.cluster import ClusterPipeline
pipeline = Pipeline
cluster_pipeline = ClusterPipeline
async_redis_client = Redis
async_redis_cluster_client = RedisCluster
Span = _Span
else:
pipeline = Any
cluster_pipeline = Any
async_redis_client = Any
async_redis_cluster_client = Any
Span = Any
@ -122,7 +127,9 @@ class RedisCache(BaseCache):
else:
super().__init__() # defaults to 60s
def init_async_client(self):
def init_async_client(
self,
) -> Union[async_redis_client, async_redis_cluster_client]:
from .._redis import get_redis_async_client
return get_redis_async_client(
@ -345,8 +352,14 @@ class RedisCache(BaseCache):
)
async def _pipeline_helper(
self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
self,
pipe: Union[pipeline, cluster_pipeline],
cache_list: List[Tuple[Any, Any]],
ttl: Optional[float],
) -> List:
"""
Helper function for executing a pipeline of set operations on Redis
"""
ttl = self.get_ttl(ttl=ttl)
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list:
@ -359,7 +372,11 @@ class RedisCache(BaseCache):
_td: Optional[timedelta] = None
if ttl is not None:
_td = timedelta(seconds=ttl)
pipe.set(cache_key, json_cache_value, ex=_td)
pipe.set( # type: ignore
name=cache_key,
value=json_cache_value,
ex=_td,
)
# Execute the pipeline and return the results.
results = await pipe.execute()
return results
@ -373,9 +390,8 @@ class RedisCache(BaseCache):
# don't waste a network request if there's nothing to set
if len(cache_list) == 0:
return
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
_redis_client = self.init_async_client()
start_time = time.time()
print_verbose(
@ -384,7 +400,7 @@ class RedisCache(BaseCache):
cache_value: Any = None
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
async with redis_client.pipeline(transaction=False) as pipe:
results = await self._pipeline_helper(pipe, cache_list, ttl)
print_verbose(f"pipeline results: {results}")
@ -730,7 +746,8 @@ class RedisCache(BaseCache):
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
_redis_client: Any = self.init_async_client()
key_value_dict = {}
start_time = time.time()
try:
@ -822,7 +839,8 @@ class RedisCache(BaseCache):
raise e
async def ping(self) -> bool:
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
_redis_client: Any = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client:
print_verbose("Pinging Async Redis Cache")
@ -858,7 +876,8 @@ class RedisCache(BaseCache):
raise e
async def delete_cache_keys(self, keys):
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
_redis_client: Any = self.init_async_client()
# keys is a list, unpack it so it gets passed as individual elements to delete
async with _redis_client as redis_client:
await redis_client.delete(*keys)
@ -881,7 +900,8 @@ class RedisCache(BaseCache):
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
async def async_delete_cache(self, key: str):
_redis_client = self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
_redis_client: Any = self.init_async_client()
# keys is str
async with _redis_client as redis_client:
await redis_client.delete(key)
@ -936,7 +956,7 @@ class RedisCache(BaseCache):
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
async with redis_client.pipeline(transaction=False) as pipe:
results = await self._pipeline_increment_helper(
pipe, increment_list
)
@ -991,7 +1011,8 @@ class RedisCache(BaseCache):
Redis ref: https://redis.io/docs/latest/commands/ttl/
"""
try:
_redis_client = await self.init_async_client()
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
_redis_client: Any = self.init_async_client()
async with _redis_client as redis_client:
ttl = await redis_client.ttl(key)
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist

View file

@ -0,0 +1,44 @@
"""
Redis Cluster Cache implementation
Key differences:
- The Redis client NEEDS to be re-used across requests; re-creating it adds ~3000ms of latency
"""
from typing import TYPE_CHECKING, Any, Optional
from litellm.caching.redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
Span = _Span
else:
pipeline = Any
async_redis_client = Any
Span = Any
class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_cluster_client: Optional[RedisCluster] = None
def init_async_client(self):
from redis.asyncio import RedisCluster
from .._redis import get_redis_async_client
if self.redis_cluster_client:
return self.redis_cluster_client
_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_cluster_client = _redis_client
return _redis_client
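For context, a hedged sketch of how a caller might opt into this path: per the `Cache` selection logic earlier in this diff, passing `redis_startup_nodes` selects `RedisClusterCache`. Hostnames and ports below are placeholders, and the import path is assumed from this diff's package layout:

```python
import litellm
from litellm.caching.caching import Cache  # assumed import path, matching this diff's layout

litellm.cache = Cache(
    type="redis",
    redis_startup_nodes=[  # presence of startup nodes routes caching through RedisClusterCache
        {"host": "redis-node-1.example.internal", "port": 6379},
        {"host": "redis-node-2.example.internal", "port": 6379},
    ],
)
```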

View file

@ -335,6 +335,63 @@ bedrock_embedding_models: List = [
"cohere.embed-multilingual-v3",
]
known_tokenizer_config = {
"mistralai/Mistral-7B-Instruct-v0.1": {
"tokenizer": {
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"bos_token": "<s>",
"eos_token": "</s>",
},
"status": "success",
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"tokenizer": {
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
"bos_token": "<|begin_of_text|>",
"eos_token": "",
},
"status": "success",
},
"deepseek-r1/deepseek-r1-7b-instruct": {
"tokenizer": {
"add_bos_token": True,
"add_eos_token": False,
"bos_token": {
"__type": "AddedToken",
"content": "<begin▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"clean_up_tokenization_spaces": False,
"eos_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"legacy": True,
"model_max_length": 16384,
"pad_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"sp_model_kwargs": {},
"unk_token": None,
"tokenizer_class": "LlamaTokenizerFast",
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant><think>\\n'}}{% endif %}",
},
"status": "success",
},
}
OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute

View file

@ -183,6 +183,9 @@ def create_fine_tuning_job(
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get(
"client", None
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
)
# Azure OpenAI
elif custom_llm_provider == "azure":
@ -388,6 +391,7 @@ def cancel_fine_tuning_job(
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
@ -550,6 +554,7 @@ def list_fine_tuning_jobs(
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
@ -701,6 +706,7 @@ def retrieve_fine_tuning_job(
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":

View file

@ -0,0 +1,36 @@
"""
Base class for Additional Logging Utils for CustomLoggers
- Health Check for the logging util
- Get Request / Response Payload for the logging util
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
class AdditionalLoggingUtils(ABC):
def __init__(self):
super().__init__()
@abstractmethod
async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Check if the service is healthy
"""
pass
@abstractmethod
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetime],
end_time_utc: Optional[datetime],
) -> Optional[dict]:
"""
Get the request and response payload for a given `request_id`
"""
return None
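As a hedged illustration (not part of this diff), a concrete integration might implement the two abstract methods like this; the class name, the in-memory store, and the returned values are hypothetical:
from datetime import datetime
from typing import Optional
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus

class InMemoryLoggingUtils(AdditionalLoggingUtils):
    def __init__(self):
        super().__init__()
        self._payloads: dict = {}  # hypothetical in-memory store keyed by request_id

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        # hypothetical: always report healthy
        return IntegrationHealthCheckStatus(status="healthy", error_message=None)

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        # hypothetical lookup; returns None when the request was never logged
        return self._payloads.get(request_id)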

View file

@ -23,6 +23,7 @@ class AthinaLogger:
"context",
"expected_response",
"user_query",
"custom_attributes",
]
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):

View file

@ -1,19 +0,0 @@
"""
Base class for health check integrations
"""
from abc import ABC, abstractmethod
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
class HealthCheckIntegration(ABC):
def __init__(self):
super().__init__()
@abstractmethod
async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Check if the service is healthy
"""
pass

View file

@ -38,14 +38,14 @@ from litellm.types.integrations.datadog import *
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardLoggingPayload
from ..base_health_check import HealthCheckIntegration
from ..additional_logging_utils import AdditionalLoggingUtils
DD_MAX_BATCH_SIZE = 1000 # max number of logs DD API can accept
class DataDogLogger(
CustomBatchLogger,
HealthCheckIntegration,
AdditionalLoggingUtils,
):
# Class variables or attributes
def __init__(
@ -543,3 +543,13 @@ class DataDogLogger(
status="unhealthy",
error_message=str(e),
)
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetimeObj],
end_time_utc: Optional[datetimeObj],
) -> Optional[dict]:
raise NotImplementedError(
"Datadog integration for getting request/response payloads is not implemented yet"
)

View file

@ -1,12 +1,16 @@
import asyncio
import json
import os
import uuid
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload
@ -20,7 +24,7 @@ GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
class GCSBucketLogger(GCSBucketBase):
class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user
@ -39,6 +43,7 @@ class GCSBucketLogger(GCSBucketBase):
batch_size=self.batch_size,
flush_interval=self.flush_interval,
)
AdditionalLoggingUtils.__init__(self)
if premium_user is not True:
raise ValueError(
@ -150,11 +155,16 @@ class GCSBucketLogger(GCSBucketBase):
"""
Get the object name to use for the current payload
"""
current_date = datetime.now().strftime("%Y-%m-%d")
current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
if logging_payload.get("error_str", None) is not None:
object_name = f"{current_date}/failure-{uuid.uuid4().hex}"
object_name = self._generate_failure_object_name(
request_date_str=current_date,
)
else:
object_name = f"{current_date}/{response_obj.get('id', '')}"
object_name = self._generate_success_object_name(
request_date_str=current_date,
response_id=response_obj.get("id", ""),
)
# used for testing
_litellm_params = kwargs.get("litellm_params", None) or {}
@ -163,3 +173,65 @@ class GCSBucketLogger(GCSBucketBase):
object_name = _metadata["gcs_log_id"]
return object_name
async def get_request_response_payload(
self,
request_id: str,
start_time_utc: Optional[datetime],
end_time_utc: Optional[datetime],
) -> Optional[dict]:
"""
Get the request and response payload for a given `request_id`
Tries current day, next day, and previous day until it finds the payload
"""
if start_time_utc is None:
raise ValueError(
"start_time_utc is required for getting a payload from GCS Bucket"
)
# Try current day, next day, and previous day
dates_to_try = [
start_time_utc,
start_time_utc + timedelta(days=1),
start_time_utc - timedelta(days=1),
]
date_str = None
for date in dates_to_try:
try:
date_str = self._get_object_date_from_datetime(datetime_obj=date)
object_name = self._generate_success_object_name(
request_date_str=date_str,
response_id=request_id,
)
encoded_object_name = quote(object_name, safe="")
response = await self.download_gcs_object(encoded_object_name)
if response is not None:
loaded_response = json.loads(response)
return loaded_response
except Exception as e:
verbose_logger.debug(
f"Failed to fetch payload for date {date_str}: {str(e)}"
)
continue
return None
def _generate_success_object_name(
self,
request_date_str: str,
response_id: str,
) -> str:
return f"{request_date_str}/{response_id}"
def _generate_failure_object_name(
self,
request_date_str: str,
) -> str:
return f"{request_date_str}/failure-{uuid.uuid4().hex}"
def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
return datetime_obj.strftime("%Y-%m-%d")
async def async_health_check(self) -> IntegrationHealthCheckStatus:
raise NotImplementedError("GCS Bucket does not support health check")
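A hedged walk-through of how the helpers above compose the object names that get_request_response_payload looks for; the date and response id are placeholders:
# _get_object_date_from_datetime(datetime(2024, 1, 15, tzinfo=timezone.utc)) -> "2024-01-15"
# _generate_success_object_name("2024-01-15", "chatcmpl-abc123")             -> "2024-01-15/chatcmpl-abc123"
# _generate_failure_object_name("2024-01-15")                                -> "2024-01-15/failure-<uuid4 hex>"
# get_request_response_payload URL-encodes the success name and tries the start date,
# the next day, and then the previous day before returning None.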

View file

@ -3,7 +3,8 @@
import copy
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from packaging.version import Version
@ -13,9 +14,16 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.langfuse import *
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
RerankResponse,
StandardLoggingPayload,
StandardLoggingPromptManagementMetadata,
TextCompletionResponse,
TranscriptionResponse,
)
if TYPE_CHECKING:
@ -150,19 +158,29 @@ class LangFuseLogger:
return metadata
def _old_log_event( # noqa: PLR0915
def log_event_on_langfuse(
self,
kwargs,
response_obj,
start_time,
end_time,
user_id,
print_verbose,
level="DEFAULT",
status_message=None,
kwargs: dict,
response_obj: Union[
None,
dict,
EmbeddingResponse,
ModelResponse,
TextCompletionResponse,
ImageResponse,
TranscriptionResponse,
RerankResponse,
HttpxBinaryResponseContent,
],
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
user_id: Optional[str] = None,
level: str = "DEFAULT",
status_message: Optional[str] = None,
) -> dict:
# Method definition
"""
Logs a success or error event on Langfuse
"""
try:
verbose_logger.debug(
f"Langfuse Logging - Enters logging function for model {kwargs}"
@ -198,66 +216,13 @@ class LangFuseLogger:
# if casting value to str fails don't block logging
pass
# end of processing langfuse ########################
if (
level == "ERROR"
and status_message is not None
and isinstance(status_message, str)
):
input = prompt
output = status_message
elif response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.HttpxBinaryResponseContent
):
input = prompt
output = "speech-output"
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]
elif response_obj is not None and isinstance(
response_obj, litellm.TranscriptionResponse
):
input = prompt
output = response_obj["text"]
elif response_obj is not None and isinstance(
response_obj, litellm.RerankResponse
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "_arealtime"
and response_obj is not None
and isinstance(response_obj, list)
):
input = kwargs.get("input")
output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
and response_obj is not None
and isinstance(response_obj, dict)
):
input = prompt
output = response_obj.get("response", "")
input, output = self._get_langfuse_input_output_content(
kwargs=kwargs,
response_obj=response_obj,
prompt=prompt,
level=level,
status_message=status_message,
)
verbose_logger.debug(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
)
@ -265,31 +230,30 @@ class LangFuseLogger:
generation_id = None
if self._is_langfuse_v2():
trace_id, generation_id = self._log_langfuse_v2(
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
kwargs,
optional_params,
input,
response_obj,
level,
print_verbose,
litellm_call_id,
user_id=user_id,
metadata=metadata,
litellm_params=litellm_params,
output=output,
start_time=start_time,
end_time=end_time,
kwargs=kwargs,
optional_params=optional_params,
input=input,
response_obj=response_obj,
level=level,
litellm_call_id=litellm_call_id,
)
elif response_obj is not None:
self._log_langfuse_v1(
user_id,
metadata,
output,
start_time,
end_time,
kwargs,
optional_params,
input,
response_obj,
user_id=user_id,
metadata=metadata,
output=output,
start_time=start_time,
end_time=end_time,
kwargs=kwargs,
optional_params=optional_params,
input=input,
response_obj=response_obj,
)
verbose_logger.debug(
f"Langfuse Layer Logging - final response object: {response_obj}"
@ -303,11 +267,108 @@ class LangFuseLogger:
)
return {"trace_id": None, "generation_id": None}
def _get_langfuse_input_output_content(
self,
kwargs: dict,
response_obj: Union[
None,
dict,
EmbeddingResponse,
ModelResponse,
TextCompletionResponse,
ImageResponse,
TranscriptionResponse,
RerankResponse,
HttpxBinaryResponseContent,
],
prompt: dict,
level: str,
status_message: Optional[str],
) -> Tuple[Optional[dict], Optional[Union[str, dict, list]]]:
"""
Get the input and output content for Langfuse logging
Args:
kwargs: The keyword arguments passed to the function
response_obj: The response object returned by the function
prompt: The prompt used to generate the response
level: The level of the log message
status_message: The status message of the log message
Returns:
input: The input content for Langfuse logging
output: The output content for Langfuse logging
"""
input = None
output: Optional[Union[str, dict, List[Any]]] = None
if (
level == "ERROR"
and status_message is not None
and isinstance(status_message, str)
):
input = prompt
output = status_message
elif response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = self._get_chat_content_for_langfuse(response_obj)
elif response_obj is not None and isinstance(
response_obj, litellm.HttpxBinaryResponseContent
):
input = prompt
output = "speech-output"
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = self._get_text_completion_content_for_langfuse(response_obj)
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj.get("data", None)
elif response_obj is not None and isinstance(
response_obj, litellm.TranscriptionResponse
):
input = prompt
output = response_obj.get("text", None)
elif response_obj is not None and isinstance(
response_obj, litellm.RerankResponse
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "_arealtime"
and response_obj is not None
and isinstance(response_obj, list)
):
input = kwargs.get("input")
output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
and response_obj is not None
and isinstance(response_obj, dict)
):
input = prompt
output = response_obj.get("response", "")
return input, output
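A hedged summary of the mapping this helper performs, taken from the branches above:
# ModelResponse (chat):        input = prompt, output = first choice's message via .json()
# EmbeddingResponse:           input = prompt, output = None
# TextCompletionResponse:      input = prompt, output = first choice's .text
# HttpxBinaryResponseContent:  input = prompt, output = "speech-output"
# pass_through_endpoint dict:  input = prompt, output = response_obj.get("response", "")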
async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
self, kwargs, response_obj, start_time, end_time, user_id
):
"""
TODO: support async calls once the Langfuse SDK is truly async.
The Langfuse SDK logs events on a background thread, so this approach
does not add latency and runs in the background.
"""
def _is_langfuse_v2(self):
@ -361,19 +422,18 @@ class LangFuseLogger:
def _log_langfuse_v2( # noqa: PLR0915
self,
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
kwargs,
optional_params,
input,
user_id: Optional[str],
metadata: dict,
litellm_params: dict,
output: Optional[Union[str, dict, list]],
start_time: Optional[datetime],
end_time: Optional[datetime],
kwargs: dict,
optional_params: dict,
input: Optional[dict],
response_obj,
level,
print_verbose,
litellm_call_id,
level: str,
litellm_call_id: Optional[str],
) -> tuple:
verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
@ -657,6 +717,31 @@ class LangFuseLogger:
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None
@staticmethod
def _get_chat_content_for_langfuse(
response_obj: ModelResponse,
):
"""
Get the chat content for Langfuse logging
"""
if response_obj.choices and len(response_obj.choices) > 0:
output = response_obj["choices"][0]["message"].json()
return output
else:
return None
@staticmethod
def _get_text_completion_content_for_langfuse(
response_obj: TextCompletionResponse,
):
"""
Get the text completion content for Langfuse logging
"""
if response_obj.choices and len(response_obj.choices) > 0:
return response_obj.choices[0].text
else:
return None
@staticmethod
def _get_langfuse_tags(
standard_logging_object: Optional[StandardLoggingPayload],

View file

@ -247,13 +247,12 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
standard_callback_dynamic_params=standard_callback_dynamic_params,
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
)
langfuse_logger_to_use._old_log_event(
langfuse_logger_to_use.log_event_on_langfuse(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=None,
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
@ -271,12 +270,11 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
)
if standard_logging_object is None:
return
langfuse_logger_to_use._old_log_event(
langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time,
end_time=end_time,
response_obj=None,
user_id=kwargs.get("user", None),
print_verbose=None,
status_message=standard_logging_object["error_str"],
level="ERROR",
kwargs=kwargs,

View file

@ -118,6 +118,7 @@ class PagerDutyAlerting(SlackAlerting):
user_api_key_user_id=_meta.get("user_api_key_user_id"),
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
user_api_key_user_email=_meta.get("user_api_key_user_email"),
)
)
@ -195,6 +196,7 @@ class PagerDutyAlerting(SlackAlerting):
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
)

View file

@ -423,6 +423,7 @@ class PrometheusLogger(CustomLogger):
team=user_api_team,
team_alias=user_api_team_alias,
user=user_id,
user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
status_code="200",
model=model,
litellm_model_name=model,
@ -806,6 +807,7 @@ class PrometheusLogger(CustomLogger):
enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
hashed_api_key=user_api_key_dict.api_key,
api_key_alias=user_api_key_dict.key_alias,
team=user_api_key_dict.team_id,
@ -853,6 +855,7 @@ class PrometheusLogger(CustomLogger):
team=user_api_key_dict.team_id,
team_alias=user_api_key_dict.team_alias,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
status_code="200",
)
_labels = prometheus_label_factory(

View file

@ -223,6 +223,7 @@ def exception_type( # type: ignore # noqa: PLR0915
"Request Timeout Error" in error_str
or "Request timed out" in error_str
or "Timed out generating response" in error_str
or "The read operation timed out" in error_str
):
exception_mapping_worked = True

View file

@ -121,21 +121,26 @@ def get_supported_openai_params( # noqa: PLR0915
)
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"):
return litellm.MistralConfig().get_supported_openai_params(model=model)
if model.startswith("codestral"):
elif model.startswith("codestral"):
return (
litellm.CodestralTextCompletionConfig().get_supported_openai_params(
model=model
)
)
if model.startswith("claude"):
elif model.startswith("claude"):
return litellm.VertexAIAnthropicConfig().get_supported_openai_params(
model=model
)
return litellm.VertexGeminiConfig().get_supported_openai_params(model=model)
elif model.startswith("gemini"):
return litellm.VertexGeminiConfig().get_supported_openai_params(
model=model
)
else:
return litellm.VertexAILlama3Config().get_supported_openai_params(
model=model
)
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":

View file

@ -199,6 +199,7 @@ class Logging(LiteLLMLoggingBaseClass):
dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
):
_input: Optional[str] = messages # save original value of messages
@ -271,6 +272,7 @@ class Logging(LiteLLMLoggingBaseClass):
"litellm_call_id": litellm_call_id,
"input": _input,
"litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
}
def process_dynamic_callbacks(self):
@ -1247,13 +1249,12 @@ class Logging(LiteLLMLoggingBaseClass):
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
)
if langfuse_logger_to_use is not None:
_response = langfuse_logger_to_use._old_log_event(
_response = langfuse_logger_to_use.log_event_on_langfuse(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
@ -1957,12 +1958,11 @@ class Logging(LiteLLMLoggingBaseClass):
standard_callback_dynamic_params=self.standard_callback_dynamic_params,
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
)
_response = langfuse_logger_to_use._old_log_event(
_response = langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time,
end_time=end_time,
response_obj=None,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
status_message=str(exception),
level="ERROR",
kwargs=self.model_call_details,
@ -2854,6 +2854,7 @@ class StandardLoggingPayloadSetup:
metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@ -2868,6 +2869,7 @@ class StandardLoggingPayloadSetup:
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
"""
prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata
] = None
@ -2892,11 +2894,13 @@ class StandardLoggingPayloadSetup:
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
user_api_key_user_email=None,
spend_logs_metadata=None,
requester_ip_address=None,
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@ -3195,6 +3199,7 @@ def get_standard_logging_object_payload(
metadata=metadata,
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
)
_request_body = proxy_server_request.get("body", {})
@ -3324,12 +3329,14 @@ def get_standard_logging_metadata(
user_api_key_team_id=None,
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys

View file

@ -1,9 +1,10 @@
import asyncio
import json
import re
import time
import traceback
import uuid
from typing import Dict, Iterable, List, Literal, Optional, Union
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_logger
@ -221,6 +222,27 @@ def _handle_invalid_parallel_tool_calls(
return tool_calls
def _parse_content_for_reasoning(
message_text: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the content for reasoning
Returns:
- reasoning_content: The content of the reasoning
- content: The content of the message
"""
if not message_text:
return None, message_text
reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
if reasoning_match:
return reasoning_match.group(1), reasoning_match.group(2)
return None, message_text
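A hedged example of what this parser returns; the message text is illustrative:
reasoning, content = _parse_content_for_reasoning(
    "<think>The user wants a greeting.</think>Hello!"
)
# reasoning == "The user wants a greeting."
# content == "Hello!"
# Messages without a <think> block pass through unchanged:
assert _parse_content_for_reasoning("Hello!") == (None, "Hello!")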
class LiteLLMResponseObjectHandler:
@staticmethod
@ -432,8 +454,20 @@ def convert_to_model_response_object( # noqa: PLR0915
for field in choice["message"].keys():
if field not in message_keys:
provider_specific_fields[field] = choice["message"][field]
# Handle reasoning models that display `reasoning_content` within `content`
reasoning_content, content = _parse_content_for_reasoning(
choice["message"].get("content")
)
if reasoning_content:
provider_specific_fields["reasoning_content"] = (
reasoning_content
)
message = Message(
content=choice["message"].get("content", None),
content=content,
role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None),
tool_calls=tool_calls,

View file

@ -1,7 +1,8 @@
from typing import Callable, List, Union
from typing import Callable, List, Set, Union
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.custom_logger import CustomLogger
@ -85,6 +86,21 @@ class LoggingCallbackManager:
callback=callback, parent_list=litellm._async_failure_callback
)
def remove_callback_from_list_by_object(
self, callback_list, obj
):
"""
Remove callbacks that are methods of a particular object (e.g., router cleanup)
"""
if not isinstance(callback_list, list): # Not list -> do nothing
return
remove_list = [c for c in callback_list if hasattr(c, "__self__") and c.__self__ == obj]
for c in remove_list:
callback_list.remove(c)
def _add_string_callback_to_list(
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
):
@ -205,3 +221,36 @@ class LoggingCallbackManager:
litellm._async_success_callback = []
litellm._async_failure_callback = []
litellm.callbacks = []
def _get_all_callbacks(self) -> List[Union[CustomLogger, Callable, str]]:
"""
Get all callbacks from litellm.callbacks, litellm.success_callback, litellm.failure_callback, litellm._async_success_callback, litellm._async_failure_callback
"""
return (
litellm.callbacks
+ litellm.success_callback
+ litellm.failure_callback
+ litellm._async_success_callback
+ litellm._async_failure_callback
)
def get_active_additional_logging_utils_from_custom_logger(
self,
) -> Set[AdditionalLoggingUtils]:
"""
Get all registered custom loggers that are also instances of AdditionalLoggingUtils
Returns:
Set[AdditionalLoggingUtils]: Set of custom loggers that implement AdditionalLoggingUtils
"""
all_callbacks = self._get_all_callbacks()
matched_callbacks: Set[AdditionalLoggingUtils] = set()
for callback in all_callbacks:
if isinstance(callback, CustomLogger) and isinstance(
callback, AdditionalLoggingUtils
):
matched_callbacks.add(callback)
return matched_callbacks
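A hedged sketch of using these helpers together; `_FakeRouter` is a hypothetical stand-in for any object whose bound-method callbacks should be detached on cleanup:
import litellm

class _FakeRouter:  # hypothetical stand-in for an object that registers bound-method callbacks
    def log_success(self, *args, **kwargs):
        pass

my_router = _FakeRouter()
litellm.success_callback.append(my_router.log_success)  # bound method registered as a callback

manager = LoggingCallbackManager()

# find every registered logger that also implements AdditionalLoggingUtils
for logging_util in manager.get_active_additional_logging_utils_from_custom_logger():
    print(type(logging_util).__name__)

# detach my_router's bound-method callbacks, e.g. during router cleanup
manager.remove_callback_from_list_by_object(litellm.success_callback, my_router)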

View file

@ -325,26 +325,6 @@ def phind_codellama_pt(messages):
return prompt
known_tokenizer_config = {
"mistralai/Mistral-7B-Instruct-v0.1": {
"tokenizer": {
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"bos_token": "<s>",
"eos_token": "</s>",
},
"status": "success",
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"tokenizer": {
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
"bos_token": "<|begin_of_text|>",
"eos_token": "",
},
"status": "success",
},
}
def hf_chat_template( # noqa: PLR0915
model: str, messages: list, chat_template: Optional[Any] = None
):
@ -378,11 +358,11 @@ def hf_chat_template( # noqa: PLR0915
else:
return {"status": "failure"}
if model in known_tokenizer_config:
tokenizer_config = known_tokenizer_config[model]
if model in litellm.known_tokenizer_config:
tokenizer_config = litellm.known_tokenizer_config[model]
else:
tokenizer_config = _get_tokenizer_config(model)
known_tokenizer_config.update({model: tokenizer_config})
litellm.known_tokenizer_config.update({model: tokenizer_config})
if (
tokenizer_config["status"] == "failure"
@ -475,6 +455,12 @@ def hf_chat_template( # noqa: PLR0915
) # don't use verbose_logger.exception, if exception is raised
def deepseek_r1_pt(messages):
return hf_chat_template(
model="deepseek-r1/deepseek-r1-7b-instruct", messages=messages
)
# Anthropic template
def claude_2_1_pt(
messages: list,
@ -1421,6 +1407,8 @@ def anthropic_messages_pt( # noqa: PLR0915
)
user_content.append(_content_element)
elif m.get("type", "") == "document":
user_content.append(cast(AnthropicMessagesDocumentParam, m))
elif isinstance(user_message_types_block["content"], str):
_anthropic_content_text_element: AnthropicMessagesTextParam = {
"type": "text",

View file

@ -0,0 +1,81 @@
from typing import Any, Dict, Optional, Set
class SensitiveDataMasker:
def __init__(
self,
sensitive_patterns: Optional[Set[str]] = None,
visible_prefix: int = 4,
visible_suffix: int = 4,
mask_char: str = "*",
):
self.sensitive_patterns = sensitive_patterns or {
"password",
"secret",
"key",
"token",
"auth",
"credential",
"access",
"private",
"certificate",
}
self.visible_prefix = visible_prefix
self.visible_suffix = visible_suffix
self.mask_char = mask_char
def _mask_value(self, value: str) -> str:
if not value or len(str(value)) < (self.visible_prefix + self.visible_suffix):
return value
value_str = str(value)
masked_length = len(value_str) - (self.visible_prefix + self.visible_suffix)
return f"{value_str[:self.visible_prefix]}{self.mask_char * masked_length}{value_str[-self.visible_suffix:]}"
def is_sensitive_key(self, key: str) -> bool:
key_lower = str(key).lower()
result = any(pattern in key_lower for pattern in self.sensitive_patterns)
return result
def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> Dict[str, Any]:
if depth >= max_depth:
return data
masked_data: Dict[str, Any] = {}
for k, v in data.items():
try:
if isinstance(v, dict):
masked_data[k] = self.mask_dict(v, depth + 1)
elif hasattr(v, "__dict__") and not isinstance(v, type):
masked_data[k] = self.mask_dict(vars(v), depth + 1)
elif self.is_sensitive_key(k):
str_value = str(v) if v is not None else ""
masked_data[k] = self._mask_value(str_value)
else:
masked_data[k] = (
v if isinstance(v, (int, float, bool, str)) else str(v)
)
except Exception:
masked_data[k] = "<unable to serialize>"
return masked_data
# Usage example:
"""
masker = SensitiveDataMasker()
data = {
"api_key": "sk-1234567890abcdef",
"redis_password": "very_secret_pass",
"port": 6379
}
masked = masker.mask_dict(data)
# Result: {
# "api_key": "sk-1****cdef",
# "redis_password": "very****pass",
# "port": 6379
# }
"""

View file

@ -809,7 +809,10 @@ class CustomStreamWrapper:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
if response_obj.get("provider_specific_fields") is not None:
completion_obj["provider_specific_fields"] = response_obj[
"provider_specific_fields"
]
model_response.choices[0].delta = Delta(**completion_obj)
_index: Optional[int] = completion_obj.get("index")
if _index is not None:

View file

@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
import copy
import json
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import httpx # type: ignore
@ -506,6 +506,29 @@ class ModelResponseIterator:
return usage_block
def _content_block_delta_helper(self, chunk: dict):
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
provider_specific_fields = {}
content_block = ContentBlockDelta(**chunk) # type: ignore
self.content_blocks.append(content_block)
if "text" in content_block["delta"]:
text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": content_block["delta"]["partial_json"],
},
"index": self.tool_index,
}
elif "citation" in content_block["delta"]:
provider_specific_fields["citation"] = content_block["delta"]["citation"]
return text, tool_use, provider_specific_fields
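A hedged example of the three values this helper returns for a citation delta; the chunk shape follows the keys checked above and is otherwise illustrative:
# chunk = {"type": "content_block_delta", "index": 0,
#          "delta": {"type": "citations_delta", "citation": {"cited_text": "example"}}}
# -> text == "", tool_use is None,
#    provider_specific_fields == {"citation": {"cited_text": "example"}}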
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
type_chunk = chunk.get("type", "") or ""
@ -515,6 +538,7 @@ class ModelResponseIterator:
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields: Dict[str, Any] = {}
index = int(chunk.get("index", 0))
if type_chunk == "content_block_delta":
@ -522,20 +546,9 @@ class ModelResponseIterator:
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
content_block = ContentBlockDelta(**chunk) # type: ignore
self.content_blocks.append(content_block)
if "text" in content_block["delta"]:
text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": content_block["delta"]["partial_json"],
},
"index": self.tool_index,
}
text, tool_use, provider_specific_fields = (
self._content_block_delta_helper(chunk=chunk)
)
elif type_chunk == "content_block_start":
"""
event: content_block_start
@ -628,6 +641,9 @@ class ModelResponseIterator:
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=(
provider_specific_fields if provider_specific_fields else None
),
)
return returned_chunk

View file

@ -70,7 +70,7 @@ class AnthropicConfig(BaseConfig):
metadata: Optional[dict] = None,
system: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -628,6 +628,7 @@ class AnthropicConfig(BaseConfig):
)
else:
text_content = ""
citations: List[Any] = []
tool_calls: List[ChatCompletionToolCallChunk] = []
for idx, content in enumerate(completion_response["content"]):
if content["type"] == "text":
@ -645,10 +646,14 @@ class AnthropicConfig(BaseConfig):
index=idx,
)
)
## CITATIONS
if content.get("citations", None) is not None:
citations.append(content["citations"])
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
provider_specific_fields={"citations": citations},
)
## HANDLE JSON MODE - anthropic returns single function call

View file

@ -72,7 +72,7 @@ class AnthropicTextConfig(BaseConfig):
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -5,10 +5,11 @@ import time
from typing import Any, Callable, Dict, List, Literal, Optional, Union
import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@ -98,14 +99,6 @@ class AzureOpenAIAssistantsAPIConfig:
def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = {
# "api_version": api_version,
# "azure_endpoint": api_base,
# "azure_deployment": model,
# "http_client": litellm.client_session,
# "max_retries": max_retries,
# "timeout": timeout,
# }
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
@ -312,6 +305,7 @@ class AzureChatCompletion(BaseLLM):
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
start_time = time.time()
try:
raw_response = await azure_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
@ -320,6 +314,11 @@ class AzureChatCompletion(BaseLLM):
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except APITimeoutError as e:
end_time = time.time()
time_delta = round(end_time - start_time, 2)
e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
raise e
except Exception as e:
raise e
@ -353,7 +352,9 @@ class AzureChatCompletion(BaseLLM):
status_code=422, message="Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2)
max_retries = optional_params.pop("max_retries", None)
if max_retries is None:
max_retries = DEFAULT_MAX_RETRIES
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
### CHECK IF CLOUDFLARE AI GATEWAY ###
@ -415,6 +416,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else:
return self.acompletion(
@ -430,6 +432,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
logging_obj=logging_obj,
max_retries=max_retries,
convert_tool_call_to_json_mode=json_mode,
)
elif "stream" in optional_params and optional_params["stream"] is True:
@ -445,6 +448,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else:
## LOGGING
@ -553,6 +557,7 @@ class AzureChatCompletion(BaseLLM):
dynamic_params: bool,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
max_retries: int,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
@ -560,12 +565,6 @@ class AzureChatCompletion(BaseLLM):
):
response = None
try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@ -649,6 +648,7 @@ class AzureChatCompletion(BaseLLM):
)
raise AzureOpenAIError(status_code=500, message=str(e))
except Exception as e:
message = getattr(e, "message", str(e))
## LOGGING
logging_obj.post_call(
input=data["messages"],
@ -659,7 +659,7 @@ class AzureChatCompletion(BaseLLM):
if hasattr(e, "status_code"):
raise e
else:
raise AzureOpenAIError(status_code=500, message=str(e))
raise AzureOpenAIError(status_code=500, message=message)
def streaming(
self,
@ -671,15 +671,11 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model: str,
timeout: Any,
max_retries: int,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@ -742,6 +738,7 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model: str,
timeout: Any,
max_retries: int,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
@ -753,7 +750,7 @@ class AzureChatCompletion(BaseLLM):
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.aclient_session,
"max_retries": data.pop("max_retries", 2),
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
@ -807,10 +804,11 @@ class AzureChatCompletion(BaseLLM):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
message = getattr(e, "message", str(e))
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise AzureOpenAIError(
status_code=status_code, message=str(e), headers=error_headers
status_code=status_code, message=message, headers=error_headers
)
async def aembedding(

View file

@ -98,6 +98,7 @@ class AzureOpenAIConfig(BaseConfig):
"seed",
"extra_headers",
"parallel_tool_calls",
"prediction",
]
def _is_response_format_supported_model(self, model: str) -> bool:
@ -113,6 +114,17 @@ class AzureOpenAIConfig(BaseConfig):
return False
def _is_response_format_supported_api_version(
self, api_version_year: str, api_version_month: str
) -> bool:
"""
- check if api_version is supported for response_format
"""
is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8
return is_supported
def map_openai_params(
self,
non_default_params: dict,
@ -171,13 +183,20 @@ class AzureOpenAIConfig(BaseConfig):
_is_response_format_supported_model = (
self._is_response_format_supported_model(model)
)
should_convert_response_format_to_tool = (
api_version_year <= "2024" and api_version_month < "08"
) or not _is_response_format_supported_model
is_response_format_supported_api_version = (
self._is_response_format_supported_api_version(
api_version_year, api_version_month
)
)
is_response_format_supported = (
is_response_format_supported_api_version
and _is_response_format_supported_model
)
optional_params = self._add_response_format_to_tools(
optional_params=optional_params,
value=value,
should_convert_response_format_to_tool=should_convert_response_format_to_tool,
is_response_format_supported=is_response_format_supported,
)
elif param == "tools" and isinstance(value, list):
optional_params.setdefault("tools", [])

View file

@ -131,6 +131,7 @@ class AzureTextCompletion(BaseLLM):
timeout=timeout,
client=client,
logging_obj=logging_obj,
max_retries=max_retries,
)
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
@ -236,17 +237,12 @@ class AzureTextCompletion(BaseLLM):
timeout: Any,
model_response: ModelResponse,
logging_obj: Any,
max_retries: int,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
):
response = None
try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,

View file

@ -34,6 +34,17 @@ class BaseLLMModelInfo(ABC):
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
pass
@staticmethod
@abstractmethod
def get_base_model(model: str) -> Optional[str]:
"""
Returns the base model name from the given model name.
Some providers, like Bedrock, can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`.
This function returns `anthropic.claude-3-opus-20240229-v1:0`.
"""
pass
def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None

View file

@ -20,6 +20,7 @@ from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionToolChoiceFunctionParam,
@ -27,9 +28,6 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
@ -163,7 +161,7 @@ class BaseConfig(ABC):
self,
optional_params: dict,
value: dict,
should_convert_response_format_to_tool: bool,
is_response_format_supported: bool,
) -> dict:
"""
Follow similar approach to anthropic - translate to a single tool call.
@ -183,7 +181,8 @@ class BaseConfig(ABC):
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
if json_schema and should_convert_response_format_to_tool:
if json_schema and not is_response_format_supported:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(

View file

@ -52,6 +52,7 @@ class BaseAWSLLM:
"aws_role_name",
"aws_web_identity_token",
"aws_sts_endpoint",
"aws_bedrock_runtime_endpoint",
]
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:

View file

@ -33,14 +33,7 @@ from litellm.types.llms.openai import (
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import add_dummy_tool, has_tool_call_blocks
from ..common_utils import (
AmazonBedrockGlobalConfig,
BedrockError,
get_bedrock_tool_name,
)
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
class AmazonConverseConfig(BaseConfig):
@ -63,7 +56,7 @@ class AmazonConverseConfig(BaseConfig):
topP: Optional[int] = None,
topK: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -104,7 +97,7 @@ class AmazonConverseConfig(BaseConfig):
]
## Filter out 'cross-region' from model name
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)
if (
base_model.startswith("anthropic")
@ -112,6 +105,7 @@ class AmazonConverseConfig(BaseConfig):
or base_model.startswith("cohere")
or base_model.startswith("meta.llama3-1")
or base_model.startswith("meta.llama3-2")
or base_model.startswith("meta.llama3-3")
or base_model.startswith("amazon.nova")
):
supported_params.append("tools")
@ -341,9 +335,9 @@ class AmazonConverseConfig(BaseConfig):
if "top_k" in inference_params:
inference_params["topK"] = inference_params.pop("top_k")
return InferenceConfig(**inference_params)
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)
val_top_k = None
if "topK" in inference_params:
@ -352,11 +346,11 @@ class AmazonConverseConfig(BaseConfig):
val_top_k = inference_params.pop("top_k")
if val_top_k:
if (base_model.startswith("anthropic")):
if base_model.startswith("anthropic"):
return {"top_k": val_top_k}
if base_model.startswith("amazon.nova"):
return {'inferenceConfig': {"topK": val_top_k}}
return {"inferenceConfig": {"topK": val_top_k}}
return {}
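A hedged illustration of the branches above; the model names are truncated placeholders:
# _handle_top_k_value("anthropic.claude-3-...", {"top_k": 5})  -> {"top_k": 5}
# _handle_top_k_value("amazon.nova-...", {"topK": 5})          -> {"inferenceConfig": {"topK": 5}}
# any other base model                                         -> {}
# note: the matched top_k/topK key is popped from inference_params as a side effect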
def _transform_request_helper(
@ -393,15 +387,25 @@ class AmazonConverseConfig(BaseConfig):
) + ["top_k"]
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
total_supported_params = (
supported_converse_params
+ supported_tool_call_params
+ supported_guardrail_params
)
inference_params.pop("json_mode", None) # used for handling json_schema
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
additional_request_params = {
k: v for k, v in inference_params.items() if k not in total_supported_params
}
inference_params = {
k: v for k, v in inference_params.items() if k in total_supported_params
}
# Only set the topK value in for models that support it
additional_request_params.update(self._handle_top_k_value(model, inference_params))
additional_request_params.update(
self._handle_top_k_value(model, inference_params)
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
@ -679,41 +683,6 @@ class AmazonConverseConfig(BaseConfig):
return model_response
def _supported_cross_region_inference_region(self) -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
def _get_base_model(self, model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]
elif (
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:

View file

@ -1,5 +1,5 @@
"""
Manages calling Bedrock's `/converse` API + `/invoke` API
TODO: DELETE FILE. Bedrock LLM is no longer used. Go to `litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py`
"""
import copy
@ -40,6 +40,9 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
parse_xml_params,
prompt_factory,
)
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -103,7 +106,7 @@ class AmazonCohereChatConfig:
stop_sequences: Optional[str] = None,
raw_prompting: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -177,6 +180,7 @@ async def make_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
@ -214,6 +218,14 @@ async def make_call(
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=False,
)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(
@ -248,6 +260,7 @@ def make_sync_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
@ -283,6 +296,12 @@ def make_sync_call(
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=True,
)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@ -1323,7 +1342,7 @@ class AWSEventStreamDecoder:
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
######## /bedrock/converse mappings ###############
elif (
"contentBlockIndex" in chunk_data
or "stopReason" in chunk_data
@ -1331,6 +1350,11 @@ class AWSEventStreamDecoder:
or "trace" in chunk_data
):
return self.converse_chunk_parser(chunk_data=chunk_data)
######### /bedrock/invoke nova mappings ###############
elif "contentBlockDelta" in chunk_data:
# when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
_chunk_data = chunk_data.get("contentBlockDelta", None)
return self.converse_chunk_parser(chunk_data=_chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
@ -1429,6 +1453,27 @@ class AWSEventStreamDecoder:
return chunk.decode() # type: ignore[no-any-return]
class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
def __init__(
self,
model: str,
sync_stream: bool,
) -> None:
"""
Child class of AWSEventStreamDecoder that handles streaming responses from the Anthropic family of models.
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method.
"""
super().__init__(model=model)
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=sync_stream,
)
def _chunk_parser(self, chunk_data: dict) -> GChunk:
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
class MockResponseIterator: # for returning ai21 streaming responses
def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response

View file

@ -46,7 +46,7 @@ class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -28,7 +28,7 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -28,7 +28,7 @@ class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -33,7 +33,7 @@ class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
top_k: Optional[float] = None,
stop: Optional[List[str]] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -0,0 +1,70 @@
"""
Handles transforming requests for `bedrock/invoke/{nova}` models.
Inherits from `AmazonConverseConfig`.
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""
from typing import List
import litellm
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
"""
Config for sending `nova` requests to `/bedrock/invoke/`
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_transformed_nova_request = super().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
_bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
**_transformed_nova_request
)
self._remove_empty_system_messages(_bedrock_invoke_nova_request)
bedrock_invoke_nova_request = self._filter_allowed_fields(
_bedrock_invoke_nova_request
)
return bedrock_invoke_nova_request
def _filter_allowed_fields(
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
) -> dict:
"""
Filter out fields that are not allowed in `BedrockInvokeNovaRequest`.
"""
allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
return {
k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
}
def _remove_empty_system_messages(
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
) -> None:
"""
In-place remove empty `system` messages from the request.
/bedrock/invoke/ does not allow empty `system` messages.
"""
_system_message = bedrock_invoke_nova_request.get("system", None)
if isinstance(_system_message, list) and len(_system_message) == 0:
bedrock_invoke_nova_request.pop("system", None)
return
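A minimal standalone sketch of the field filtering described above. The TypedDict is a stand-in for `BedrockInvokeNovaRequest`; its field names are assumptions for illustration only.

```python
from typing import TypedDict


class _NovaRequestSketch(TypedDict, total=False):
    # stand-in for BedrockInvokeNovaRequest; field names are illustrative
    messages: list
    system: list
    inferenceConfig: dict


def _filter_allowed_fields_sketch(request: dict) -> dict:
    allowed = set(_NovaRequestSketch.__annotations__.keys())
    return {k: v for k, v in request.items() if k in allowed}


print(_filter_allowed_fields_sketch({"messages": [], "system": [], "anthropic_version": "x"}))
# -> {'messages': [], 'system': []}   (keys outside the annotated set are dropped)
```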

View file

@ -33,7 +33,7 @@ class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -34,7 +34,7 @@ class AmazonAnthropicConfig:
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -1,61 +1,34 @@
import types
from typing import List, Optional
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
import litellm
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaude3Config:
class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
"""
Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` Required (integer) max tokens. Default is 4096
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
"""
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
anthropic_version: Optional[str] = "bedrock-2023-05-31"
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
anthropic_version: str = "bedrock-2023-05-31"
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@ -68,7 +41,13 @@ class AmazonAnthropicClaude3Config:
"extra_headers",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
@ -83,3 +62,53 @@ class AmazonAnthropicClaude3Config:
if param == "top_p":
optional_params["top_p"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_anthropic_request = litellm.AnthropicConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
_anthropic_request.pop("model", None)
if "anthropic_version" not in _anthropic_request:
_anthropic_request["anthropic_version"] = self.anthropic_version
return _anthropic_request
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return litellm.AnthropicConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
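A rough sketch of the request shape this delegation produces: `AnthropicConfig().transform_request` builds the Anthropic body, then the code above drops `model` and defaults `anthropic_version`. The values below are illustrative, not actual library output.

```python
request = {
    "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",  # illustrative model id
    "messages": [{"role": "user", "content": "hi"}],
    "max_tokens": 256,
}
request.pop("model", None)
request.setdefault("anthropic_version", "bedrock-2023-05-31")
# request -> {"messages": [...], "max_tokens": 256, "anthropic_version": "bedrock-2023-05-31"}
```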

View file

@ -2,22 +2,19 @@ import copy
import json
import time
import urllib.parse
import uuid
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
deepseek_r1_pt,
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
@ -91,7 +88,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
optional_params=optional_params,
)
### SET RUNTIME ENDPOINT ###
aws_bedrock_runtime_endpoint = optional_params.pop(
aws_bedrock_runtime_endpoint = optional_params.get(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
@ -129,15 +126,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
## CREDENTIALS ##
# read aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from optional_params (no longer popped, so they stay available to later calls)
extra_headers = optional_params.pop("extra_headers", None)
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
extra_headers = optional_params.get("extra_headers", None)
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
aws_access_key_id = optional_params.get("aws_access_key_id", None)
aws_session_token = optional_params.get("aws_session_token", None)
aws_role_name = optional_params.get("aws_role_name", None)
aws_session_name = optional_params.get("aws_session_name", None)
aws_profile_name = optional_params.get("aws_profile_name", None)
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
aws_region_name = self._get_aws_region_name(optional_params)
credentials: Credentials = self.get_credentials(
@ -171,7 +168,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
return dict(request.headers)
def transform_request( # noqa: PLR0915
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
@ -182,11 +179,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
hf_model_name = litellm_params.get("hf_model_name", None)
provider = self.get_bedrock_invoke_provider(model)
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
model=hf_model_name or model,
messages=messages,
provider=provider,
custom_prompt_dict=custom_prompt_dict,
)
inference_params = copy.deepcopy(optional_params)
inference_params = {
@ -194,7 +195,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
for k, v in inference_params.items()
if k not in self.aws_authentication_params
}
json_schemas: dict = {}
request_data: dict = {}
if provider == "cohere":
if model.startswith("cohere.command-r"):
@ -223,57 +223,21 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
)
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
# Separate system prompt from rest of message
system_prompt_idx: list[int] = []
system_messages: list[str] = []
for idx, message in enumerate(messages):
if message["role"] == "system" and isinstance(
message["content"], str
):
system_messages.append(message["content"])
system_prompt_idx.append(idx)
if len(system_prompt_idx) > 0:
inference_params["system"] = "\n".join(system_messages)
messages = [
i for j, i in enumerate(messages) if j not in system_prompt_idx
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
) # type: ignore
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## Handle Tool Calling
if "tools" in inference_params:
_is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
inference_params["system"] = (
inference_params.get("system", "\n")
+ tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
inference_params.pop("tools")
request_data = {"messages": messages, **inference_params}
else:
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
request_data = {"prompt": prompt, **inference_params}
return litellm.AmazonAnthropicClaude3Config().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
elif provider == "nova":
return litellm.AmazonInvokeNovaConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
@ -307,7 +271,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
"inputText": prompt,
"textGenerationConfig": inference_params,
}
elif provider == "meta" or provider == "llama":
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
@ -347,6 +311,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
raise BedrockError(
message=raw_response.text, status_code=raw_response.status_code
)
verbose_logger.debug(
"bedrock invoke response % s",
json.dumps(completion_response, indent=4, default=str),
)
provider = self.get_bedrock_invoke_provider(model)
outputText: Optional[str] = None
try:
@ -359,71 +327,36 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
completion_response["generations"][0]["finish_reason"]
)
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
json_schemas: dict = {}
_is_function_call = False
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool[
"function"
].get("parameters", None)
outputText = completion_response.get("content")[0].get("text", None)
if outputText is not None and contains_tag(
"invoke", outputText
): # OUTPUT PARSE FUNCTION CALL
function_name = extract_between_tags("tool_name", outputText)[0]
function_arguments_str = extract_between_tags(
"invoke", outputText
)[0].strip()
function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>"
)
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name)
)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
model_response.choices[0].finish_reason = map_finish_reason(
completion_response.get("stop_reason", "")
)
_usage = litellm.Usage(
prompt_tokens=completion_response["usage"]["input_tokens"],
completion_tokens=completion_response["usage"]["output_tokens"],
total_tokens=completion_response["usage"]["input_tokens"]
+ completion_response["usage"]["output_tokens"],
)
setattr(model_response, "usage", _usage)
else:
outputText = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response[
"stop_reason"
]
return litellm.AmazonAnthropicClaude3Config().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
elif provider == "nova":
return litellm.AmazonInvokeNovaConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
elif provider == "ai21":
outputText = (
completion_response.get("completions")[0].get("data").get("text")
)
elif provider == "meta" or provider == "llama":
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
outputText = completion_response["generation"]
elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"]
@ -536,6 +469,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
),
model=model,
custom_llm_provider="bedrock",
@ -569,6 +503,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
),
model=model,
custom_llm_provider="bedrock",
@ -594,10 +529,15 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
"""
Helper function to get the bedrock provider from the model
handles 2 scenarios:
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
handles 4 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
@ -606,6 +546,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
if provider is not None:
return provider
# check if provider == "nova"
if "nova" in model:
return "nova"
return None
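Illustrative expectations for the four routing scenarios listed in the docstring above. This is a sketch; it assumes the method is callable on a default-constructed `AmazonInvokeConfig`, as it is elsewhere in this diff.

```python
cfg = AmazonInvokeConfig()
assert cfg.get_bedrock_invoke_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert cfg.get_bedrock_invoke_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert cfg.get_bedrock_invoke_provider(
    "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
) == "llama"
assert cfg.get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"
```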
@staticmethod
@ -640,16 +584,16 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
else:
modelId = model
modelId = modelId.replace("invoke/", "", 1)
if provider == "llama" and "llama/" in modelId:
modelId = self._get_model_id_for_llama_like_model(modelId)
return modelId
def _get_aws_region_name(self, optional_params: dict) -> str:
"""
Get the AWS region name from the environment variables
"""
aws_region_name = optional_params.pop("aws_region_name", None)
aws_region_name = optional_params.get("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
@ -725,6 +669,8 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
elif provider == "deepseek_r1":
prompt = deepseek_r1_pt(messages=messages)
else:
prompt = ""
for message in messages:

View file

@ -3,11 +3,12 @@ Common utilities used across bedrock chat/embedding/image generation
"""
import os
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret
@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
response_tool_name
]
return response_tool_name
class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
@staticmethod
def get_base_model(model: str) -> str:
"""
Get the base model from the given model name.
Handles model names like "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1:0"
(the "us."/"eu."/"apac." cross-region prefix and any "bedrock/", "converse/" or "invoke/" route prefix are stripped; the version suffix is kept)
"""
if model.startswith("bedrock/"):
model = model.split("/", 1)[1]
if model.startswith("converse/"):
model = model.split("/", 1)[1]
if model.startswith("invoke/"):
model = model.split("/", 1)[1]
potential_region = model.split(".", 1)[0]
alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`
if (
potential_region
in BedrockModelInfo._supported_cross_region_inference_region()
):
return model.split(".", 1)[1]
elif (
alt_potential_region in BedrockModelInfo.all_global_regions
and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]
return model
@staticmethod
def _supported_cross_region_inference_region() -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]
@staticmethod
def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
"""
Get the bedrock route for the given model.
"""
base_model = BedrockModelInfo.get_base_model(model)
if "invoke/" in model:
return "invoke"
elif "converse_like" in model:
return "converse_like"
elif "converse/" in model:
return "converse"
elif base_model in litellm.bedrock_converse_models:
return "converse"
return "invoke"

View file

@ -27,7 +27,7 @@ class AmazonTitanG1Config:
def __init__(
self,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -33,7 +33,7 @@ class AmazonTitanV2Config:
def __init__(
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -49,7 +49,7 @@ class AmazonStabilityConfig:
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -45,7 +45,7 @@ class ClarifaiConfig(BaseConfig):
temperature: Optional[int] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -44,7 +44,7 @@ class CloudflareChatConfig(BaseConfig):
max_tokens: Optional[int] = None,
stream: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -104,7 +104,7 @@ class CohereChatConfig(BaseConfig):
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -86,7 +86,7 @@ class CohereTextConfig(BaseConfig):
return_likelihoods: Optional[str] = None,
logit_bias: Optional[dict] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -1,5 +1,6 @@
import asyncio
import os
import time
from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union
import httpx
@ -179,6 +180,7 @@ class AsyncHTTPHandler:
stream: bool = False,
logging_obj: Optional[LiteLLMLoggingObject] = None,
):
start_time = time.time()
try:
if timeout is None:
timeout = self.timeout
@ -207,6 +209,8 @@ class AsyncHTTPHandler:
finally:
await new_client.aclose()
except httpx.TimeoutException as e:
end_time = time.time()
time_delta = round(end_time - start_time, 3)
headers = {}
error_response = getattr(e, "response", None)
if error_response is not None:
@ -214,7 +218,7 @@ class AsyncHTTPHandler:
headers["response_headers-{}".format(key)] = value
raise litellm.Timeout(
message=f"Connection timed out after {timeout} seconds.",
message=f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds",
model="default-model-name",
llm_provider="litellm-httpx-handler",
headers=headers,

View file

@ -37,7 +37,7 @@ class DatabricksConfig(OpenAILikeChatConfig):
stop: Optional[Union[List[str], str]] = None,
n: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -16,7 +16,7 @@ class DatabricksEmbeddingConfig:
)
def __init__(self, instruction: Optional[str] = None) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -145,7 +145,7 @@ class AlephAlphaConfig:
contextual_control_threshold: Optional[int] = None,
control_log_additive: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -63,7 +63,7 @@ class PalmConfig:
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -57,7 +57,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -77,7 +77,7 @@ class HuggingfaceChatConfig(BaseConfig):
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -20,6 +20,15 @@ from .common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError("api_base is required for Infinity rerank")
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/rerank"):
api_base = f"{api_base}/rerank"
return api_base
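A standalone sketch of the URL normalization above (trailing slashes removed, `/rerank` appended when missing); the host name is a placeholder.

```python
def _complete_rerank_url_sketch(api_base: str) -> str:
    api_base = api_base.rstrip("/")
    if not api_base.endswith("/rerank"):
        api_base = f"{api_base}/rerank"
    return api_base


assert _complete_rerank_url_sketch("https://infinity.example.com/") == "https://infinity.example.com/rerank"
assert _complete_rerank_url_sketch("https://infinity.example.com/rerank") == "https://infinity.example.com/rerank"
```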
def validate_environment(
self,
headers: dict,

View file

@ -21,7 +21,7 @@ class JinaAIEmbeddingConfig:
def __init__(
self,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -18,7 +18,7 @@ class LmStudioEmbeddingConfig:
def __init__(
self,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -33,7 +33,7 @@ class MaritalkConfig(OpenAIGPTConfig):
tools: Optional[List[dict]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -78,7 +78,7 @@ class NLPCloudConfig(BaseConfig):
num_beams: Optional[int] = None,
num_return_sequences: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -32,7 +32,7 @@ class NvidiaNimEmbeddingConfig:
input_type: Optional[str] = None,
truncate: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -58,7 +58,7 @@ class NvidiaNimEmbeddingConfig:
def get_supported_openai_params(
self,
):
return ["encoding_format", "user"]
return ["encoding_format", "user", "dimensions"]
def map_openai_params(
self,
@ -73,6 +73,8 @@ class NvidiaNimEmbeddingConfig:
optional_params["extra_body"].update({"input_type": v})
elif k == "truncate":
optional_params["extra_body"].update({"truncate": v})
else:
optional_params[k] = v
if kwargs is not None:
# pass kwargs in extra_body
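Illustrative effect of the mapping above: `dimensions` now flows through directly, while `input_type`/`truncate` still land in `extra_body`. This mirrors the loop shown; the output comment is the expected result, not captured library output.

```python
optional_params: dict = {"extra_body": {}}
non_default_params = {"dimensions": 256, "input_type": "query"}

for k, v in non_default_params.items():
    if k == "input_type":
        optional_params["extra_body"].update({"input_type": v})
    elif k == "truncate":
        optional_params["extra_body"].update({"truncate": v})
    else:
        # new fall-through: any other supported param (e.g. "dimensions") is kept as-is
        optional_params[k] = v

print(optional_params)  # {'extra_body': {'input_type': 'query'}, 'dimensions': 256}
```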

View file

@ -117,7 +117,7 @@ class OllamaConfig(BaseConfig):
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -105,7 +105,7 @@ class OllamaChatConfig(OpenAIGPTConfig):
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -344,6 +344,10 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
or "https://api.openai.com/v1"
)
@staticmethod
def get_base_model(model: str) -> str:
return model
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],

View file

@ -43,23 +43,6 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
"""
return messages
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
if stream is not True:
return False
if model is None:
return True
supported_stream_models = ["o1-mini", "o1-preview"]
for supported_model in supported_stream_models:
if supported_model in model:
return False
return True
def get_supported_openai_params(self, model: str) -> list:
"""
Get the supported OpenAI params for the given model

View file

@ -27,6 +27,7 @@ from typing_extensions import overload
import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@ -320,6 +321,17 @@ class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _set_dynamic_params_on_client(
self,
client: Union[OpenAI, AsyncOpenAI],
organization: Optional[str] = None,
max_retries: Optional[int] = None,
):
if organization is not None:
client.organization = organization
if max_retries is not None:
client.max_retries = max_retries
def _get_openai_client(
self,
is_async: bool,
@ -327,11 +339,10 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
organization: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
):
args = locals()
if client is None:
if not isinstance(max_retries, int):
raise OpenAIError(
@ -364,7 +375,6 @@ class OpenAIChatCompletion(BaseLLM):
organization=organization,
)
else:
_new_client = OpenAI(
api_key=api_key,
base_url=api_base,
@ -383,6 +393,11 @@ class OpenAIChatCompletion(BaseLLM):
return _new_client
else:
self._set_dynamic_params_on_client(
client=client,
organization=organization,
max_retries=max_retries,
)
return client
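Sketch of the new reuse path: when a pre-built client is supplied, `organization` and `max_retries` are now applied to it instead of being ignored. The API key and org id below are placeholders.

```python
from openai import OpenAI

client = OpenAI(api_key="sk-placeholder", max_retries=2)

# _set_dynamic_params_on_client effectively does this on the reused client:
client.organization = "org-placeholder"
client.max_retries = 5
```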
@track_llm_api_timing()

View file

@ -20,3 +20,23 @@ class PerplexityChatConfig(OpenAIGPTConfig):
or get_secret_str("PERPLEXITY_API_KEY")
)
return api_base, dynamic_api_key
def get_supported_openai_params(self, model: str) -> list:
"""
Perplexity supports a subset of OpenAI params
Ref: https://docs.perplexity.ai/api-reference/chat-completions
E.g. Perplexity does not support tools, tool_choice, function_call, functions, etc.
"""
return [
"frequency_penalty",
"max_tokens",
"max_completion_tokens",
"presence_penalty",
"response_format",
"stream",
"temperature",
"top_p" "max_retries",
"extra_headers",
]
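Quick illustrative check against the allow-list above (the model name is a placeholder, and the check assumes the comma fix in the list):

```python
params = PerplexityChatConfig().get_supported_openai_params(model="sonar-pro")
assert "tools" not in params                     # Perplexity does not accept OpenAI tool params
assert {"top_p", "max_retries"} <= set(params)
```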

View file

@ -58,7 +58,7 @@ class PetalsConfig(BaseConfig):
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -59,7 +59,7 @@ class PredibaseConfig(BaseConfig):
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -73,7 +73,7 @@ class ReplicateConfig(BaseConfig):
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -47,7 +47,7 @@ class SagemakerConfig(BaseConfig):
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -29,3 +29,7 @@ class TopazModelInfo(BaseLLMModelInfo):
return (
api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
)
@staticmethod
def get_base_model(model: str) -> str:
return model

View file

@ -179,7 +179,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -17,7 +17,7 @@ class VertexAIAi21Config:
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -1,10 +1,10 @@
import types
from typing import Optional
import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class VertexAILlama3Config:
class VertexAILlama3Config(OpenAIGPTConfig):
"""
Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
@ -21,7 +21,7 @@ class VertexAILlama3Config:
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
@ -46,8 +46,13 @@ class VertexAILlama3Config:
and v is not None
}
def get_supported_openai_params(self):
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
def get_supported_openai_params(self, model: str):
supported_params = super().get_supported_openai_params(model=model)
try:
supported_params.remove("max_retries")
except ValueError:  # list.remove raises ValueError when the item is missing
pass
return supported_params
def map_openai_params(
self,
@ -60,7 +65,7 @@ class VertexAILlama3Config:
non_default_params["max_tokens"] = non_default_params.pop(
"max_completion_tokens"
)
return litellm.OpenAIConfig().map_openai_params(
return super().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,

View file

@ -48,7 +48,7 @@ class VertexAITextEmbeddingConfig(BaseModel):
] = None,
title: Optional[str] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

Some files were not shown because too many files have changed in this diff.