Merge branch 'BerriAI:main' into fix-anthropic-messages-api

This commit is contained in:
Emir Ayar 2024-04-27 11:50:04 +02:00 committed by GitHub
commit 38b5f34c77
366 changed files with 73092 additions and 56717 deletions

View file

@ -8,6 +8,11 @@ jobs:
steps:
- checkout
- run:
name: Show git commit hash
command: |
echo "Git commit hash: $CIRCLE_SHA1"
- run:
name: Check if litellm dir was updated or if pyproject.toml was modified
command: |
@ -31,16 +36,17 @@ jobs:
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
pip install pyarrow
pip install "boto3>=1.28.57"
pip install "aioboto3>=12.3.0"
pip install "boto3==1.34.34"
pip install "aioboto3==12.3.0"
pip install langchain
pip install lunary==0.2.5
pip install "langfuse>=2.0.0"
pip install "langfuse==2.7.3"
pip install numpydoc
pip install traceloop-sdk==0.0.69
pip install openai
pip install prisma
pip install "httpx==0.24.1"
pip install fastapi
pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1"
pip install "aiodynamo==23.10.1"
@ -51,6 +57,7 @@ jobs:
pip install "pytest-mock==3.12.0"
pip install python-multipart
pip install google-cloud-aiplatform
pip install prometheus-client==0.20.0
- save_cache:
paths:
- ./venv
@ -73,7 +80,7 @@ jobs:
name: Linting Testing
command: |
cd litellm
python -m pip install types-requests types-setuptools types-redis
python -m pip install types-requests types-setuptools types-redis types-PyYAML
if ! python -m mypy . --ignore-missing-imports; then
echo "mypy detected errors"
exit 1
@ -123,6 +130,7 @@ jobs:
build_and_test:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
working_directory: ~/project
steps:
- checkout
@ -182,12 +190,19 @@ jobs:
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
-e AZURE_API_KEY=$AZURE_API_KEY \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
-e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \
-e LANGFUSE_PROJECT2_SECRET=$LANGFUSE_PROJECT2_SECRET \
--name my-app \
-v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
my-app:latest \
@ -292,7 +307,7 @@ jobs:
-H "Accept: application/vnd.github.v3+json" \
-H "Authorization: Bearer $GITHUB_TOKEN" \
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\"}}"
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
workflows:
version: 2
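For reference, the workflow-dispatch `curl` above now forwards the CircleCI commit SHA as a `commit_hash` input. A rough Python sketch of the same call (token, tag, and SHA values are placeholders, not part of the CI config):

```python
import os
import requests

# Placeholder values - in CI these come from $GITHUB_TOKEN, $VERSION and $CIRCLE_SHA1
token = os.environ["GITHUB_TOKEN"]
payload = {
    "ref": "main",
    "inputs": {"tag": "v1.35.0", "commit_hash": os.environ["CIRCLE_SHA1"]},
}
resp = requests.post(
    "https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches",
    headers={
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"Bearer {token}",
    },
    json=payload,
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```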

View file

@ -3,11 +3,10 @@ openai
python-dotenv
tiktoken
importlib_metadata
baseten
cohere
redis
anthropic
orjson
pydantic
pydantic==1.10.14
google-cloud-aiplatform==1.43.0
redisvl==0.0.7 # semantic caching

View file

@ -1,5 +1,5 @@
/docs
/cookbook
/.circleci
/.github
/tests
docs
cookbook
.circleci
.github
tests

View file

@ -5,6 +5,13 @@ on:
inputs:
tag:
description: "The tag version you want to build"
release_type:
description: "The release type you want to build. Can be 'latest', 'stable', 'dev'"
type: string
default: "latest"
commit_hash:
description: "Commit hash"
required: true
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
@ -85,9 +92,9 @@ jobs:
- name: Build and push Docker image
uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8
with:
context: .
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
push: true
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-latest # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
labels: ${{ steps.meta.outputs.labels }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -121,10 +128,10 @@ jobs:
- name: Build and push Database Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with:
context: .
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
file: Dockerfile.database
push: true
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-latest
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
labels: ${{ steps.meta-database.outputs.labels }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -158,11 +165,10 @@ jobs:
- name: Build and push Database Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with:
context: .
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
file: ./litellm-js/spend-logs/Dockerfile
push: true
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-latest
labels: ${{ steps.meta-spend-logs.outputs.labels }}
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
build-and-push-helm-chart:
@ -236,10 +242,13 @@ jobs:
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
script: |
const commitHash = "${{ github.event.inputs.commit_hash}}";
console.log("Commit Hash:", commitHash); // Add this line for debugging
try {
const response = await github.rest.repos.createRelease({
draft: false,
generate_release_notes: true,
target_commitish: commitHash,
name: process.env.RELEASE_TAG,
owner: context.repo.owner,
prerelease: false,
@ -288,4 +297,3 @@ jobs:
}
]
}' $WEBHOOK_URL

View file

@ -77,6 +77,9 @@ if __name__ == "__main__":
new_release_body = (
existing_release_body
+ "\n\n"
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
+ "\n\n"
+ "## Load Test LiteLLM Proxy Results"
+ "\n\n"
+ markdown_table

View file

@ -10,7 +10,7 @@ class MyUser(HttpUser):
def chat_completion(self):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer sk-gUvTeN9g0sgHBMf9HeCaqA",
"Authorization": f"Bearer sk-S2-EZTUUDY0EmM6-Fy0Fyw",
# Include any additional headers you may need for authentication, etc.
}

6
.gitignore vendored
View file

@ -45,3 +45,9 @@ deploy/charts/litellm/charts/*
deploy/charts/*.tgz
litellm/proxy/vertex_key.json
**/.vim/
/node_modules
kub.yaml
loadtest_kub.yaml
litellm/proxy/_new_secret_config.yaml
litellm/proxy/_new_secret_config.yaml
litellm/proxy/_super_secret_config.yaml

View file

@ -70,5 +70,4 @@ EXPOSE 4000/tcp
ENTRYPOINT ["litellm"]
# Append "--detailed_debug" to the end of CMD to view detailed debug logs
# CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
CMD ["--port", "4000"]

View file

@ -5,7 +5,7 @@
<p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, etc.]
<br>
</p>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/hosted" target="_blank">Hosted Proxy (Preview)</a> | <a href="https://docs.litellm.ai/docs/enterprise" target="_blank">Enterprise Tier</a></h4>
<h4 align="center">
<a href="https://pypi.org/project/litellm/" target="_blank">
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
@ -32,9 +32,9 @@ LiteLLM manages:
- Set Budgets & Rate limits per project, api key, model [OpenAI Proxy Server](https://docs.litellm.ai/docs/simple_proxy)
[**Jump to OpenAI Proxy Docs**](https://github.com/BerriAI/litellm?tab=readme-ov-file#openai-proxy---docs) <br>
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-provider-docs)
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs)
🚨 **Stable Release:** v1.34.1
🚨 **Stable Release:** Use docker images with the `main-stable` tag. These run through 12-hour load tests (1k req./min).
Support for more providers. Missing a provider or LLM Platform, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
@ -128,7 +128,9 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
# OpenAI Proxy - ([Docs](https://docs.litellm.ai/docs/simple_proxy))
Set Budgets & Rate limits across multiple projects
Track spend + Load Balance across multiple projects
[Hosted Proxy (Preview)](https://docs.litellm.ai/docs/hosted)
The proxy provides:
@ -205,7 +207,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [aws - bedrock](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [google - vertex_ai [Gemini]](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
| [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | | ✅ | | |
| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | | ✅ | | |
| [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [cloudflare AI Workers](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | ✅ |
| [cohere](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | ✅ |
@ -220,7 +222,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | ✅ |
| [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | ✅ |
| [petals](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | ✅ |
| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ |
| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |

204
cookbook/Proxy_Batch_Users.ipynb vendored Normal file
View file

@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "680oRk1af-xJ"
},
"source": [
"# Environment Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X7TgJFn8f88p"
},
"outputs": [],
"source": [
"import csv\n",
"from typing import Optional\n",
"import httpx, json\n",
"import asyncio\n",
"\n",
"proxy_base_url = \"http://0.0.0.0:4000\" # 👈 SET TO PROXY URL\n",
"master_key = \"sk-1234\" # 👈 SET TO PROXY MASTER KEY"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rauw8EOhgBz5"
},
"outputs": [],
"source": [
"## GLOBAL HTTP CLIENT ## - faster http calls\n",
"class HTTPHandler:\n",
" def __init__(self, concurrent_limit=1000):\n",
" # Create a client with a connection pool\n",
" self.client = httpx.AsyncClient(\n",
" limits=httpx.Limits(\n",
" max_connections=concurrent_limit,\n",
" max_keepalive_connections=concurrent_limit,\n",
" )\n",
" )\n",
"\n",
" async def close(self):\n",
" # Close the client when you're done with it\n",
" await self.client.aclose()\n",
"\n",
" async def get(\n",
" self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None\n",
" ):\n",
" response = await self.client.get(url, params=params, headers=headers)\n",
" return response\n",
"\n",
" async def post(\n",
" self,\n",
" url: str,\n",
" data: Optional[dict] = None,\n",
" params: Optional[dict] = None,\n",
" headers: Optional[dict] = None,\n",
" ):\n",
" try:\n",
" response = await self.client.post(\n",
" url, data=data, params=params, headers=headers\n",
" )\n",
" return response\n",
" except Exception as e:\n",
" raise e\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7LXN8zaLgOie"
},
"source": [
"# Import Sheet\n",
"\n",
"\n",
"Format: | ID | Name | Max Budget |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oiED0usegPGf"
},
"outputs": [],
"source": [
"async def import_sheet():\n",
" tasks = []\n",
" http_client = HTTPHandler()\n",
" with open('my-batch-sheet.csv', 'r') as file:\n",
" csv_reader = csv.DictReader(file)\n",
" for row in csv_reader:\n",
" task = create_user(client=http_client, user_id=row['ID'], max_budget=row['Max Budget'], user_name=row['Name'])\n",
" tasks.append(task)\n",
" # print(f\"ID: {row['ID']}, Name: {row['Name']}, Max Budget: {row['Max Budget']}\")\n",
"\n",
" keys = await asyncio.gather(*tasks)\n",
"\n",
" with open('my-batch-sheet_new.csv', 'w', newline='') as new_file:\n",
" fieldnames = ['ID', 'Name', 'Max Budget', 'keys']\n",
" csv_writer = csv.DictWriter(new_file, fieldnames=fieldnames)\n",
" csv_writer.writeheader()\n",
"\n",
" with open('my-batch-sheet.csv', 'r') as file:\n",
" csv_reader = csv.DictReader(file)\n",
" for i, row in enumerate(csv_reader):\n",
" row['keys'] = keys[i] # Add the 'keys' value from the corresponding task result\n",
" csv_writer.writerow(row)\n",
"\n",
" await http_client.close()\n",
"\n",
"asyncio.run(import_sheet())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E7M0Li_UgJeZ"
},
"source": [
"# Create Users + Keys\n",
"\n",
"- Creates a user\n",
"- Creates a key with max budget"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NZudRFujf7j-"
},
"outputs": [],
"source": [
"\n",
"async def create_key_with_alias(client: HTTPHandler, user_id: str, max_budget: float):\n",
" global proxy_base_url\n",
" if not proxy_base_url.endswith(\"/\"):\n",
" proxy_base_url += \"/\"\n",
" url = proxy_base_url + \"key/generate\"\n",
"\n",
" # call /key/generate\n",
" print(\"CALLING /KEY/GENERATE\")\n",
" response = await client.post(\n",
" url=url,\n",
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
" data=json.dumps({\n",
" \"user_id\": user_id,\n",
" \"key_alias\": f\"{user_id}-key\",\n",
" \"max_budget\": max_budget # 👈 KEY CHANGE: SETS MAX BUDGET PER KEY\n",
" })\n",
" )\n",
" print(f\"response: {response.text}\")\n",
" return response.json()[\"key\"]\n",
"\n",
"async def create_user(client: HTTPHandler, user_id: str, max_budget: float, user_name: str):\n",
" \"\"\"\n",
" - call /user/new\n",
" - create key for user\n",
" \"\"\"\n",
" global proxy_base_url\n",
" if not proxy_base_url.endswith(\"/\"):\n",
" proxy_base_url += \"/\"\n",
" url = proxy_base_url + \"user/new\"\n",
"\n",
" # call /user/new\n",
" await client.post(\n",
" url=url,\n",
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
" data=json.dumps({\n",
" \"user_id\": user_id,\n",
" \"user_alias\": user_name,\n",
" \"auto_create_key\": False,\n",
" # \"max_budget\": max_budget # 👈 [OPTIONAL] Sets max budget per user (if you want to set a max budget across keys)\n",
" })\n",
" )\n",
"\n",
" # create key for user\n",
" return await create_key_with_alias(client=client, user_id=user_id, max_budget=max_budget)\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -87,6 +87,7 @@
| command-light | cohere | 0.00003 |
| command-medium-beta | cohere | 0.00003 |
| command-xlarge-beta | cohere | 0.00003 |
| command-r-plus | cohere | 0.000018 |
| j2-ultra | ai21 | 0.00003 |
| ai21.j2-ultra-v1 | bedrock | 0.0000376 |
| gpt-4-1106-preview | openai | 0.00004 |
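These per-1k-token prices are also available programmatically through litellm's cost helpers. A small sketch, assuming `command-r-plus` is present in the installed version's model cost map:

```python
import litellm

# Look up the cost of 1k prompt tokens + 1k completion tokens for command-r-plus
prompt_cost, completion_cost = litellm.cost_per_token(
    model="command-r-plus", prompt_tokens=1000, completion_tokens=1000
)
print(f"prompt: ${prompt_cost:.6f}, completion: ${completion_cost:.6f}")
```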

73
cookbook/misc/config.yaml Normal file
View file

@ -0,0 +1,73 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
- model_name: gpt-3.5-turbo-large
litellm_params:
model: "gpt-3.5-turbo-1106"
api_key: os.environ/OPENAI_API_KEY
rpm: 480
timeout: 300
stream_timeout: 60
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
rpm: 480
timeout: 300
stream_timeout: 60
- model_name: sagemaker-completion-model
litellm_params:
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
input_cost_per_second: 0.000420
- model_name: text-embedding-ada-002
litellm_params:
model: azure/azure-embedding-model
api_key: os.environ/AZURE_API_KEY
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
model_info:
mode: embedding
base_model: text-embedding-ada-002
- model_name: dall-e-2
litellm_params:
model: azure/
api_version: 2023-06-01-preview
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY
- model_name: openai-dall-e-3
litellm_params:
model: dall-e-3
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
drop_params: True
# max_budget: 100
# budget_duration: 30d
num_retries: 5
request_timeout: 600
telemetry: False
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
store_model_in_db: True
proxy_budget_rescheduler_min_time: 60
proxy_budget_rescheduler_max_time: 64
proxy_batch_write_at: 1
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
# environment_variables:
# settings for using redis caching
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
# REDIS_PORT: "16337"
# REDIS_PASSWORD:
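Once the proxy is started with this config (it listens on port 4000 by default), any OpenAI client can call the `model_name` entries above. A minimal sketch, assuming the proxy runs locally and the `master_key` from `general_settings` is used as the API key:

```python
import openai

client = openai.OpenAI(
    api_key="sk-1234",              # master_key from general_settings above
    base_url="http://0.0.0.0:4000"  # litellm proxy base url
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # a model_name from the model_list above
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```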

View file

@ -0,0 +1,92 @@
"""
LiteLLM Migration Script!
Takes a config.yaml and calls /model/new
Inputs:
- File path to config.yaml
- Proxy base url to your hosted proxy
Step 1: Reads your config.yaml
Step 2: reads `model_list` and loops through all models
Step 3: calls `<proxy-base-url>/model/new` for each model
"""
import yaml
import requests
_in_memory_os_variables = {}
def migrate_models(config_file, proxy_base_url):
# Step 1: Read the config.yaml file
with open(config_file, "r") as f:
config = yaml.safe_load(f)
# Step 2: Read the model_list and loop through all models
model_list = config.get("model_list", [])
print("model_list: ", model_list)
for model in model_list:
model_name = model.get("model_name")
print("\nAdding model: ", model_name)
litellm_params = model.get("litellm_params", {})
api_base = litellm_params.get("api_base", "")
print("api_base on config.yaml: ", api_base)
litellm_model_name = litellm_params.get("model", "") or ""
if "vertex_ai/" in litellm_model_name:
print(f"\033[91m\nSkipping Vertex AI model\033[0m", model)
continue
for param, value in litellm_params.items():
if isinstance(value, str) and value.startswith("os.environ/"):
# check if value is in _in_memory_os_variables
if value in _in_memory_os_variables:
new_value = _in_memory_os_variables[value]
print(
"\033[92mAlready entered value for \033[0m",
value,
"\033[92musing \033[0m",
new_value,
)
else:
new_value = input(f"Enter value for {value}: ")
_in_memory_os_variables[value] = new_value
litellm_params[param] = new_value
print("\nlitellm_params: ", litellm_params)
# Confirm before sending POST request
confirm = input(
"\033[92mDo you want to send the POST request with the above parameters? (y/n): \033[0m"
)
if confirm.lower() != "y":
print("Aborting POST request.")
exit()
# Step 3: Call <proxy-base-url>/model/new for each model
url = f"{proxy_base_url}/model/new"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {master_key}",
}
data = {"model_name": model_name, "litellm_params": litellm_params}
print("POSTING data to proxy url", url)
response = requests.post(url, headers=headers, json=data)
if response.status_code != 200:
print(f"Error: {response.status_code} - {response.text}")
raise Exception(f"Error: {response.status_code} - {response.text}")
# Print the response for each model
print(
f"Response for model '{model_name}': Status Code:{response.status_code} - {response.text}"
)
# Usage
config_file = "config.yaml"
proxy_base_url = "http://0.0.0.0:4000"
master_key = "sk-1234"
print(f"config_file: {config_file}")
print(f"proxy_base_url: {proxy_base_url}")
migrate_models(config_file, proxy_base_url)

View file

@ -15,15 +15,16 @@ spec:
containers:
- name: litellm-container
image: ghcr.io/berriai/litellm:main-latest
imagePullPolicy: Always
env:
- name: AZURE_API_KEY
value: "d6f****"
- name: AZURE_API_BASE
value: "https://openai
value: "https://openai"
- name: LITELLM_MASTER_KEY
value: "sk-1234"
- name: DATABASE_URL
value: "postgresql://ishaan:*********""
value: "postgresql://ishaan*********"
args:
- "--config"
- "/app/proxy_config.yaml" # Update the path to mount the config file

View file

@ -1,10 +1,16 @@
version: "3.9"
services:
litellm:
build:
context: .
args:
target: runtime
image: ghcr.io/berriai/litellm:main-latest
volumes:
- ./proxy_server_config.yaml:/app/proxy_server_config.yaml # mount your litellm config.yaml
ports:
- "4000:4000"
environment:
- AZURE_API_KEY=sk-123
- "4000:4000" # Map the container port to the host, change the host port if necessary
volumes:
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
# You can change the port or number of workers as per your requirements or pass any new supported CLI argument. Make sure the port passed here matches the container port defined above in the `ports` value
command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
# ...rest of your docker-compose config if any

View file

@ -72,7 +72,7 @@ Here's the code for how we format all providers. Let us know how we can improve
| Anthropic | `claude-instant-1`, `claude-instant-1.2`, `claude-2` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/anthropic.py#L84)
| OpenAI Text Completion | `text-davinci-003`, `text-curie-001`, `text-babbage-001`, `text-ada-001`, `babbage-002`, `davinci-002`, | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L442)
| Replicate | all model names starting with `replicate/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/replicate.py#L180)
| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta`, `command-r-plus` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
| Huggingface | all model names starting with `huggingface/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/huggingface_restapi.py#L186)
| OpenRouter | all model names starting with `openrouter/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L611)
| AI21 | `j2-mid`, `j2-light`, `j2-ultra` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/ai21.py#L107)

View file

@ -0,0 +1,45 @@
# Using Vision Models
## Quick Start
Example passing images to a model
```python
import os
from litellm import completion
os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call
response = completion(
model = "gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
)
```
## Checking if a model supports `vision`
Use `litellm.supports_vision(model="")` -> returns `True` if model supports `vision` and `False` if not
```python
assert litellm.supports_vision(model="gpt-4-vision-preview") == True
assert litellm.supports_vision(model="gemini-1.0-pro-vision") == True
assert litellm.supports_vision(model="gpt-3.5-turbo") == False
```

View file

@ -339,6 +339,8 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
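A short sketch of calling one of the newly listed preview embedding models (assumes Vertex AI credentials plus `litellm.vertex_project` / `litellm.vertex_location` are already configured):

```python
from litellm import embedding

response = embedding(
    model="vertex_ai/text-embedding-preview-0409",
    input=["good morning from litellm"],
)
print(response)
```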
## Voyage AI Embedding Models

View file

@ -1,5 +1,5 @@
# Enterprise
For companies that need better security, user management and professional support
For companies that need SSO, user management and professional support for LiteLLM Proxy
:::info
@ -8,12 +8,13 @@ For companies that need better security, user management and professional suppor
:::
This covers:
- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
- ✅ **Feature Prioritization**
- ✅ **Custom Integrations**
- ✅ **Professional Support - Dedicated discord + slack**
- ✅ **Custom SLAs**
- ✅ **Secure access with Single Sign-On**
- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
## Frequently Asked Questions

View file

@ -0,0 +1,49 @@
import Image from '@theme/IdealImage';
# Hosted LiteLLM Proxy
LiteLLM maintains the proxy, so you can focus on your core products.
## [**Get Onboarded**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
This is in alpha. Schedule a call with us, and we'll give you a hosted proxy within 30 minutes.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
### **Status**: Alpha
Our proxy is already used in production by customers.
See our status page for [**live reliability**](https://status.litellm.ai/)
### **Benefits**
- **No Maintenance, No Infra**: We'll maintain the proxy, and spin up any additional infrastructure (e.g.: separate server for spend logs) to make sure you can load balance + track spend across multiple LLM projects.
- **Reliable**: Our hosted proxy is tested on 1k requests per second, making it reliable for high load.
- **Secure**: LiteLLM is currently undergoing SOC-2 compliance, to make sure your data is as secure as possible.
### Pricing
Pricing is based on usage. We can figure out a price that works for your team, on the call.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## **Screenshots**
### 1. Create keys
<Image img={require('../img/litellm_hosted_ui_create_key.png')} />
### 2. Add Models
<Image img={require('../img/litellm_hosted_ui_add_models.png')}/>
### 3. Track spend
<Image img={require('../img/litellm_hosted_usage_dashboard.png')} />
### 4. Configure load balancing
<Image img={require('../img/litellm_hosted_ui_router.png')} />
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

View file

@ -8,6 +8,7 @@ liteLLM supports:
- [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
- [Lunary](https://lunary.ai/docs)
- [Langfuse](https://langfuse.com/docs)
- [Helicone](https://docs.helicone.ai/introduction)
- [Traceloop](https://traceloop.com/docs)
- [Athina](https://docs.athina.ai/)
@ -22,8 +23,8 @@ from litellm import completion
# set callbacks
litellm.input_callback=["sentry"] # for sentry breadcrumbing - logs the input being sent to the api
litellm.success_callback=["posthog", "helicone", "lunary", "athina"]
litellm.failure_callback=["sentry", "lunary"]
litellm.success_callback=["posthog", "helicone", "langfuse", "lunary", "athina"]
litellm.failure_callback=["sentry", "lunary", "langfuse"]
## set env variables
os.environ['SENTRY_DSN'], os.environ['SENTRY_API_TRACE_RATE']= ""
@ -32,6 +33,9 @@ os.environ["HELICONE_API_KEY"] = ""
os.environ["TRACELOOP_API_KEY"] = ""
os.environ["LUNARY_PUBLIC_KEY"] = ""
os.environ["ATHINA_API_KEY"] = ""
os.environ["LANGFUSE_PUBLIC_KEY"] = ""
os.environ["LANGFUSE_SECRET_KEY"] = ""
os.environ["LANGFUSE_HOST"] = ""
response = completion(model="gpt-3.5-turbo", messages=messages)
```

View file

@ -0,0 +1,68 @@
# Greenscale Tutorial
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).
## Getting Started
Use Greenscale to log requests across all LLM Providers
liteLLM provides `callbacks`, making it easy for you to log data depending on the status of your responses.
## Using Callbacks
First, email `hello@greenscale.ai` to get an API_KEY.
Use just 1 line of code, to instantly log your responses **across all providers** with Greenscale:
```python
litellm.success_callback = ["greenscale"]
```
### Complete code
```python
from litellm import completion
## set env variables
os.environ['GREENSCALE_API_KEY'] = 'your-greenscale-api-key'
os.environ['GREENSCALE_ENDPOINT'] = 'greenscale-endpoint'
os.environ["OPENAI_API_KEY"]= ""
# set callback
litellm.success_callback = ["greenscale"]
#openai call
response = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
metadata={
"greenscale_project": "acme-project",
"greenscale_application": "acme-application"
}
)
```
## Additional information in metadata
You can send any additional information to Greenscale by using the `metadata` field in `completion` with the `greenscale_` prefix. This is useful for sending metadata about the request, such as the project and application name, customer_id, environment, or any other information you want to track. `greenscale_project` and `greenscale_application` are required fields.
```python
#openai call with additional metadata
response = completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
],
metadata={
"greenscale_project": "acme-project",
"greenscale_application": "acme-application",
"greenscale_customer_id": "customer-123"
}
)
```
## Support & Talk with Greenscale Team
- [Schedule Demo 👋](https://calendly.com/nandesh/greenscale)
- [Website 💻](https://greenscale.ai)
- Our email ✉️ `hello@greenscale.ai`

View file

@ -57,7 +57,7 @@ os.environ["LANGSMITH_API_KEY"] = ""
os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse
litellm.success_callback = ["langfuse"]
litellm.success_callback = ["langsmith"]
response = litellm.completion(
model="gpt-3.5-turbo",

View file

@ -177,11 +177,7 @@ print(response)
:::info
Claude returns its output as an XML tree. [Here is how we translate it](https://github.com/BerriAI/litellm/blob/49642a5b00a53b1babc1a753426a8afcac85dbbe/litellm/llms/prompt_templates/factory.py#L734).
You can see the raw response via `response._hidden_params["original_response"]`.
Claude hallucinates, e.g. returning the list param `value` as `<value>\n<item>apple</item>\n<item>banana</item>\n</value>` or `<value>\n<list>\n<item>apple</item>\n<item>banana</item>\n</list>\n</value>`.
LiteLLM now uses Anthropic's 'tool' param 🎉 (v1.34.29+)
:::
```python
@ -228,6 +224,91 @@ assert isinstance(
```
### Parallel Function Calling
Here's how to pass the result of a function call back to an anthropic model:
```python
import litellm
from litellm import completion
import os
os.environ["ANTHROPIC_API_KEY"] = "sk-ant.."
litellm.set_verbose = True
### 1ST FUNCTION CALL ###
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
### 2ND FUNCTION CALL ###
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(second_response)
except Exception as e:
print(f"An error occurred - {str(e)}")
```
s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
## Usage - Vision
```python

View file

@ -1,55 +1,215 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Azure AI Studio
## Using Mistral models deployed on Azure AI Studio
### Sample Usage - setting env vars
Set `MISTRAL_AZURE_API_KEY` and `MISTRAL_AZURE_API_BASE` in your env
```shell
MISTRAL_AZURE_API_KEY = "zE************"
MISTRAL_AZURE_API_BASE = "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1"
```
**Ensure the following:**
1. The API Base passed ends in the `/v1/` prefix
example:
```python
api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
```
2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
response = completion(
model="mistral/Mistral-large-dfgfj",
messages=[
{"role": "user", "content": "hello from litellm"}
],
import litellm
response = litellm.completion(
model="azure/command-r-plus",
api_base="<your-deployment-base>/v1/",
api_key="eskk******",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
```
### Sample Usage - passing `api_base` and `api_key` to `litellm.completion`
```python
from litellm import completion
import os
</TabItem>
<TabItem value="proxy" label="PROXY">
response = completion(
model="mistral/Mistral-large-dfgfj",
api_base="https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com",
api_key = "JGbKodRcTp****"
messages=[
{"role": "user", "content": "hello from litellm"}
],
)
print(response)
```
## Sample Usage - LiteLLM Proxy
### [LiteLLM Proxy] Using Mistral Models
1. Add models to your config.yaml
Set this on your litellm proxy config.yaml
```yaml
model_list:
- model_name: mistral
litellm_params:
model: mistral/Mistral-large-dfgfj
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com
model: azure/mistral-large-latest
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
api_key: JGbKodRcTp****
- model_name: command-r-plus
litellm_params:
model: azure/command-r-plus
api_key: os.environ/AZURE_COHERE_API_KEY
api_base: os.environ/AZURE_COHERE_API_BASE
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="mistral",
messages = [
{
"role": "user",
"content": "what llm are you"
}
],
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Function Calling
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
# set env
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="azure/mistral-large-latest",
api_base=os.getenv("AZURE_MISTRAL_API_BASE"),
api_key=os.getenv("AZURE_MISTRAL_API_KEY"),
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $YOUR_API_KEY" \
-d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What'\''s the weather like in Boston today?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"tool_choice": "auto"
}'
```
</TabItem>
</Tabs>
## Supported Models
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |

View file

@ -47,6 +47,7 @@ for chunk in response:
|------------|----------------|
| command-r | `completion('command-r', messages)` |
| command-light | `completion('command-light', messages)` |
| command-r-plus | `completion('command-r-plus', messages)` |
| command-medium | `completion('command-medium', messages)` |
| command-medium-beta | `completion('command-medium-beta', messages)` |
| command-xlarge-nightly | `completion('command-xlarge-nightly', messages)` |
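A minimal sketch of calling the newly added `command-r-plus` entry (the API key value is a placeholder):

```python
import os
from litellm import completion

os.environ["COHERE_API_KEY"] = "your-cohere-api-key"  # placeholder

response = completion(
    model="command-r-plus",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
```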

View file

@ -23,7 +23,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
```python
response = completion(
model="gemini/gemini-pro",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
safety_settings=[
{
"category": "HARM_CATEGORY_HARASSMENT",
@ -96,8 +96,7 @@ print(content)
## Chat Models
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------|-------------------------|
|-----------------------|--------------------------------------------------------|--------------------------------|
| gemini-pro | `completion('gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro | `completion('gemini/gemini-1.5-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-latest | `completion('gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-pro-vision | `completion('gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
| gemini-1.5-pro-vision | `completion('gemini/gemini-1.5-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |

View file

@ -48,6 +48,109 @@ We support ALL Groq models, just set `groq/` as a prefix when sending completion
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| llama3-8b-8192 | `completion(model="groq/llama3-8b-8192", messages)` |
| llama3-70b-8192 | `completion(model="groq/llama3-70b-8192", messages)` |
| llama2-70b-4096 | `completion(model="groq/llama2-70b-4096", messages)` |
| mixtral-8x7b-32768 | `completion(model="groq/mixtral-8x7b-32768", messages)` |
| gemma-7b-it | `completion(model="groq/gemma-7b-it", messages)` |
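A quick sketch of calling one of the newly listed Llama 3 models (the API key value is a placeholder):

```python
import os
from litellm import completion

os.environ["GROQ_API_KEY"] = "your-groq-api-key"  # placeholder

response = completion(
    model="groq/llama3-70b-8192",
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response)
```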
## Groq - Tool / Function Calling Example
```python
# Example dummy function hard coded to return the current weather
import json
import litellm
def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
return json.dumps(
{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
)
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
else:
return json.dumps({"location": location, "temperature": "unknown"})
# Step 1: send the conversation and available functions to the model
messages = [
{
"role": "system",
"content": "You are a function calling LLM that uses the data extracted from get_current_weather to answer questions about the weather in San Francisco.",
},
{
"role": "user",
"content": "What's the weather like in San Francisco?",
},
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
response = litellm.completion(
model="groq/llama2-70b-4096",
messages=messages,
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)
print("Response\n", response)
response_message = response.choices[0].message
tool_calls = response_message.tool_calls
# Step 2: check if the model wanted to call a function
if tool_calls:
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
available_functions = {
"get_current_weather": get_current_weather,
}
messages.append(
response_message
) # extend conversation with assistant's reply
print("Response message\n", response_message)
# Step 4: send the info for each function call and function response to the model
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = available_functions[function_name]
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(
location=function_args.get("location"),
unit=function_args.get("unit"),
)
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
print(f"messages: {messages}")
second_response = litellm.completion(
model="groq/llama2-70b-4096", messages=messages
) # get a new response from the model where it can see the function response
print("second response\n", second_response)
```

View file

@ -50,8 +50,53 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
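A short sketch of calling the newly listed `open-mixtral-8x22b` model (the API key value is a placeholder):

```python
import os
from litellm import completion

os.environ["MISTRAL_API_KEY"] = "your-api-key"  # placeholder

response = completion(
    model="mistral/open-mixtral-8x22b",
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response)
```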
## Function Calling
```python
from litellm import completion
import os
# set env
os.environ["MISTRAL_API_KEY"] = "your-api-key"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
## Sample Usage - Embedding
```python
from litellm import embedding

View file

@ -1,5 +1,5 @@
# Ollama
LiteLLM supports all models from [Ollama](https://github.com/jmorganca/ollama)
LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama)
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_Ollama.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@ -97,7 +97,7 @@ response = completion(
print(response)
```
## Ollama Models
Ollama supported models: https://github.com/jmorganca/ollama
Ollama supported models: https://github.com/ollama/ollama
| Model Name | Function Call |
|----------------------|-----------------------------------------------------------------------------------

View file

@ -1,5 +1,8 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenAI
LiteLLM supports OpenAI Chat + Text completion and embedding calls.
LiteLLM supports OpenAI Chat + Embedding calls.
### Required API Keys
@ -22,6 +25,132 @@ response = completion(
)
```
### Usage - LiteLLM Proxy Server
Here's how to call OpenAI models with the LiteLLM Proxy Server
### 1. Save key in your environment
```bash
export OPENAI_API_KEY=""
```
### 2. Start the proxy
<Tabs>
<TabItem value="config" label="config.yaml">
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-3.5-turbo-instruct
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
Use this to add all openai models with one API Key. **WARNING: This will not do any load balancing**
This means requests to `gpt-4`, `gpt-3.5-turbo`, `gpt-4-turbo-preview` will all go through this route
```yaml
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
litellm_params:
model: openai/* # set `openai/` to use the openai route
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="cli" label="CLI">
```bash
$ litellm --model gpt-3.5-turbo
# Server running on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
### Optional Keys - OpenAI Organization, OpenAI API Base
```python
@ -34,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-1106-preview | `response = completion(model="gpt-4-1106-preview", messages=messages)` |
| gpt-3.5-turbo-1106 | `response = completion(model="gpt-3.5-turbo-1106", messages=messages)` |
@ -55,6 +186,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
## OpenAI Vision Models
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
#### Usage
@ -88,19 +220,6 @@ response = completion(
```
## OpenAI Text Completion Models / Instruct Models
| Model Name | Function Call |
|---------------------|----------------------------------------------------|
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-0914", messages=messages)` |
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |
## Advanced
### Parallel Function calling

View file

@ -5,7 +5,9 @@ import TabItem from '@theme/TabItem';
To call models hosted behind an openai proxy, make 2 changes:
1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint.
1. For `/chat/completions`: Put `openai/` in front of your model name, so litellm knows you're trying to call an openai `/chat/completions` endpoint.
2. For `/completions`: Put `text-completion-openai/` in front of your model name, so litellm knows you're trying to call an openai `/completions` endpoint.
2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.
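Putting both rules together, a minimal sketch (the model name and base url are hypothetical; note the base url carries no extra path):

```python
from litellm import completion

response = completion(
    model="openai/my-chat-model",      # `openai/` prefix -> routed to /chat/completions
    api_base="http://0.0.0.0:8000",    # base url only - no /v1/chat/completions suffix
    api_key="anything",
    messages=[{"role": "user", "content": "hello"}],
)
print(response)
```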

View file

@ -1,7 +1,16 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Replicate
LiteLLM supports all models on Replicate
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### API KEYS
```python
import os
@ -16,14 +25,175 @@ import os
## set ENV variables
os.environ["REPLICATE_API_KEY"] = "replicate key"
# replicate llama-2 call
# replicate llama-3 call
response = completion(
model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
model="replicate/meta/meta-llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
### Example - Calling Replicate Deployments
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="llama-3",
messages = [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3",
"messages": [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
],
}'
```
</TabItem>
</Tabs>
### Expected Replicate Call
This is the call litellm will make to replicate, from the above example:
```bash
POST Request Sent from LiteLLM:
curl -X POST \
https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct \
-H 'Authorization: Token your-api-key' -H 'Content-Type: application/json' \
-d '{'version': 'meta/meta-llama-3-8b-instruct', 'input': {'prompt': '<|start_header_id|>system<|end_header_id|>\n\nBe a good human!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat do you know about earth?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'}}'
```
</TabItem>
</Tabs>
## Advanced Usage - Prompt Formatting
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
To apply a custom prompt template:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import completion
import os
os.environ["REPLICATE_API_KEY"] = ""
# Create your own custom prompt template
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",
initial_prompt_value="You are a good assistant" # [OPTIONAL]
roles={
"system": {
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
},
"user": {
"pre_message": "[INST] ", # [OPTIONAL]
"post_message": " [/INST]" # [OPTIONAL]
},
"assistant": {
"pre_message": "\n" # [OPTIONAL]
"post_message": "\n" # [OPTIONAL]
}
}
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)
def test_replicate_custom_model():
model = "replicate/togethercomputer/LLaMA-2-7B-32K"
response = completion(model=model, messages=messages)
print(response['choices'][0]['message']['content'])
return response
test_replicate_custom_model()
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
# Model-specific parameters
model_list:
- model_name: mistral-7b # model alias
litellm_params: # actual params for litellm.completion()
model: "replicate/mistralai/Mistral-7B-Instruct-v0.1"
api_key: os.environ/REPLICATE_API_KEY
initial_prompt_value: "\n"
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
final_prompt_value: "\n"
bos_token: "<s>"
eos_token: "</s>"
max_tokens: 4096
```
</TabItem>
</Tabs>
## Advanced Usage - Calling Replicate Deployments
Calling a [deployed replicate LLM](https://replicate.com/deployments)
Add the `replicate/deployments/` prefix to your model, so litellm will call the `deployments` endpoint. The example below calls the `ishaan-jaff/ishaan-mistral` deployment on Replicate.
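A minimal sketch (replace the deployment owner/name with your own):
```python
import os
from litellm import completion

os.environ["REPLICATE_API_KEY"] = "replicate key"

# `replicate/deployments/` prefix -> litellm calls the Replicate deployments endpoint
response = completion(
    model="replicate/deployments/ishaan-jaff/ishaan-mistral",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
)
print(response)
```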
@ -40,7 +210,7 @@ Replicate responses can take 3-5 mins due to replicate cold boots, if you're try
:::
### Replicate Models
## Replicate Models
LiteLLM supports all Replicate LLMs
For Replicate models, make sure to add a `replicate/` prefix to the `model` arg. LiteLLM detects the Replicate provider using this prefix.
@ -49,15 +219,15 @@ Below are examples on how to call replicate LLMs using liteLLM
Model Name | Function Call | Required OS Variables |
-----------------------------|----------------------------------------------------------------|--------------------------------------|
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
replicate deployment | `completion(model='replicate/deployments/ishaan-jaff/ishaan-mistral', messages)` | `os.environ['REPLICATE_API_KEY']` |
### Passing additional params - max_tokens, temperature
## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
```python
@ -73,11 +243,22 @@ response = completion(
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5
)
```
### Passings Replicate specific params
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
max_tokens: 20
temperature: 0.5
```
## Passing Replicate-specific params
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Replicate, by passing them to `litellm.completion`.
For example, `seed` and `min_tokens` are Replicate-specific params.
@ -98,3 +279,15 @@ response = completion(
top_k=20,
)
```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
min_tokens: 2
top_k: 20
```

View file

@ -0,0 +1,163 @@
# OpenAI (Text Completion)
LiteLLM supports OpenAI text completion models
### Required API Keys
```python
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
```
### Usage
```python
import os
from litellm import completion
os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call
response = completion(
model = "gpt-3.5-turbo-instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}]
)
```
### Usage - LiteLLM Proxy Server
Here's how to call OpenAI models with the LiteLLM Proxy Server
### 1. Save key in your environment
```bash
export OPENAI_API_KEY=""
```
### 2. Start the proxy
<Tabs>
<TabItem value="config" label="config.yaml">
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-3.5-turbo-instruct
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
Use this to add all OpenAI models with one API key. **WARNING: This will not do any load balancing.**
This means requests to `gpt-4`, `gpt-3.5-turbo`, and `gpt-4-turbo-preview` will all go through this route.
```yaml
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
litellm_params:
model: openai/* # set `openai/` to use the openai route
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="cli" label="CLI">
```bash
$ litellm --model gpt-3.5-turbo-instruct
# Server running on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo-instruct",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo-instruct", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo-instruct",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
## OpenAI Text Completion Models / Instruct Models
| Model Name | Function Call |
|---------------------|----------------------------------------------------|
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-0914", messages=messages)` |
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |

View file

@ -1,18 +1,25 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# VertexAI - Google [Gemini, Model Garden]
# VertexAI [Anthropic, Gemini, Model Garden]
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
## Pre-requisites
* `pip install google-cloud-aiplatform`
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
* Authentication:
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
* Alternatively you can set `application_default_credentials.json`
* Alternatively you can set `GOOGLE_APPLICATION_CREDENTIALS`
Here's how: [**Jump to Code**](#extra)
- Create a service account on GCP
- Export the credentials as a JSON file
- Load the JSON and `json.dump` it as a string
- Store the JSON string in your environment as `GOOGLE_APPLICATION_CREDENTIALS`
## Sample Usage
```python
@ -123,6 +130,100 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
</Tabs>
## Specifying Safety Settings
In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = completion(
model="gemini/gemini-pro",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
safety_settings=[
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
]
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**Option 1: Set in config**
```yaml
model_list:
- model_name: gemini-experimental
litellm_params:
model: vertex_ai/gemini-experimental
vertex_project: litellm-epic
vertex_location: us-central1
safety_settings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_NONE
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_NONE
```
**Option 2: Set on call**
```python
response = client.chat.completions.create(
model="gemini-experimental",
messages=[
{
"role": "user",
"content": "Can you write exploits?",
}
],
max_tokens=8192,
stream=False,
temperature=0.0,
extra_body={
"safety_settings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
],
}
)
```
</TabItem>
</Tabs>
## Set Vertex Project & Vertex Location
All calls using Vertex AI require the following parameters:
* Your Project ID
@ -149,6 +250,85 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
# set directly on module
litellm.vertex_location = "us-central1" # Your Location
```
## Anthropic
| Model Name | Function Call |
|------------------|--------------------------------------|
| claude-3-opus@20240229 | `completion('vertex_ai/claude-3-opus@20240229', messages)` |
| claude-3-sonnet@20240229 | `completion('vertex_ai/claude-3-sonnet@20240229', messages)` |
| claude-3-haiku@20240307 | `completion('vertex_ai/claude-3-haiku@20240307', messages)` |
### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
model = "claude-3-sonnet@20240229"
vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]
response = completion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: anthropic-vertex
litellm_params:
model: vertex_ai/claude-3-sonnet@20240229
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-east-1"
- model_name: anthropic-vertex
litellm_params:
model: vertex_ai/claude-3-sonnet@20240229
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-west-1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "anthropic-vertex", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
## Model Garden
| Model Name | Function Call |
|------------------|--------------------------------------|
@ -175,18 +355,15 @@ response = completion(
|------------------|--------------------------------------|
| gemini-pro | `completion('gemini-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
## Gemini Pro Vision
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`|
## Gemini 1.5 Pro (and Vision)
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-1.5-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
@ -298,3 +475,75 @@ print(response)
| code-bison@001 | `completion('code-bison@001', messages)` |
| code-gecko@001 | `completion('code-gecko@001', messages)` |
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS`
Here's the code for storing your service account credentials as `GOOGLE_APPLICATION_CREDENTIALS` environment variable:
```python
import json
import os
import tempfile

def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
```
### Using GCP Service Account
1. Figure out the Service Account bound to the Google Cloud Run service
<Image img={require('../../img/gcp_acc_1.png')} />
2. Get the FULL EMAIL address of the corresponding Service Account
3. Next, go to IAM & Admin > Manage Resources , select your top-level project that houses your Google Cloud Run Service
Click `Add Principal`
<Image img={require('../../img/gcp_acc_2.png')}/>
4. Specify the Service Account as the principal and Vertex AI User as the role
<Image img={require('../../img/gcp_acc_3.png')}/>
Once that's done, when you deploy the new container in the Google Cloud Run service, LiteLLM will have automatic access to all Vertex AI endpoints.
s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial

View file

@ -25,8 +25,11 @@ All models listed here https://docs.voyageai.com/embeddings/#models-and-specific
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| voyage-2 | `embedding(model="voyage/voyage-2", input)` |
| voyage-large-2 | `embedding(model="voyage/voyage-large-2", input)` |
| voyage-law-2 | `embedding(model="voyage/voyage-law-2", input)` |
| voyage-code-2 | `embedding(model="voyage/voyage-code-2", input)` |
| voyage-lite-02-instruct | `embedding(model="voyage/voyage-lite-02-instruct", input)` |
| voyage-01 | `embedding(model="voyage/voyage-01", input)` |
| voyage-lite-01 | `embedding(model="voyage/voyage-lite-01", input)` |
| voyage-lite-01-instruct | `embedding(model="voyage/voyage-lite-01-instruct", input)` |
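For example, a minimal embedding call (this assumes `VOYAGE_API_KEY` is the env variable LiteLLM reads for Voyage AI):
```python
import os
from litellm import embedding

os.environ["VOYAGE_API_KEY"] = "your-api-key"

# the `voyage/` prefix routes the request to Voyage AI
response = embedding(
    model="voyage/voyage-law-2",
    input=["what is litellm"],
)
print(response)
```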

View file

@ -61,6 +61,22 @@ litellm_settings:
ttl: 600 # will be cached on redis for 600s
```
## SSL
Just set `REDIS_SSL="True"` in your .env, and LiteLLM will pick this up.
```env
REDIS_SSL="True"
```
For quick testing, you can also use `REDIS_URL`, e.g.:
```
REDIS_URL="rediss://.."
```
but we **don't** recommend using `REDIS_URL` in prod. We've noticed a performance difference between using it vs. `redis_host`, `redis_port`, etc.
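A minimal sketch of the recommended setup with individual params instead of `REDIS_URL` (hypothetical values):
```env
REDIS_HOST="my-redis-endpoint"
REDIS_PORT="6379"
REDIS_PASSWORD="my-redis-password"
REDIS_SSL="True"
```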
#### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
@ -265,32 +281,6 @@ litellm_settings:
supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
```
### Turn on `batch_redis_requests`
**What it does?**
When a request is made:
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory, if no - get the last 100 cached requests for this key and store it
- New requests are stored with this `litellm:..` as the namespace
**Why?**
Reduce number of redis GET requests. This improved latency by 46% in prod load tests.
**Usage**
```yaml
litellm_settings:
cache: true
cache_params:
type: redis
... # remaining redis args (host, port, etc.)
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
```
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
### Turn on / off caching per request.
The proxy supports 3 cache-controls:
@ -384,6 +374,87 @@ chat_completion = client.chat.completions.create(
)
```
### Deleting Cache Keys - `/cache/delete`
In order to delete a cache key, send a request to `/cache/delete` with the `keys` you want to delete
Example
```shell
curl -X POST "http://0.0.0.0:4000/cache/delete" \
-H "Authorization: Bearer sk-1234" \
-d '{"keys": ["586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d", "key2"]}'
```
```shell
# {"status":"success"}
```
#### Viewing Cache Keys from responses
You can view the cache_key in the response headers; on cache hits, the cache key is sent in the `x-litellm-cache-key` response header.
```shell
curl -i --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "ishan",
"messages": [
{
"role": "user",
"content": "what is litellm"
}
],
}'
```
Response from litellm proxy
```json
date: Thu, 04 Apr 2024 17:37:21 GMT
content-type: application/json
x-litellm-cache-key: 586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d
{
"id": "chatcmpl-9ALJTzsBlXR9zTxPvzfFFtFbFtG6T",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "I'm sorr.."
"role": "assistant"
}
}
],
"created": 1712252235,
}
```
### Turn on `batch_redis_requests`
**What does it do?**
When a request is made:
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory; if not, get the last 100 cached requests for this key and store them
- New requests are stored with this `litellm:..` as the namespace
**Why?**
Reduces the number of Redis GET requests. This improved latency by 46% in prod load tests.
**Usage**
```yaml
litellm_settings:
cache: true
cache_params:
type: redis
... # remaining redis args (host, port, etc.)
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
```
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
## Supported `cache_params` on proxy config.yaml
```yaml

View file

@ -600,6 +600,7 @@ general_settings:
"general_settings": {
"completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param

View file

@ -0,0 +1,9 @@
# 🎉 Demo App
Here is a demo of the proxy. To log in pass in:
- Username: admin
- Password: sk-1234
[Demo UI](https://demo.litellm.ai/ui)

View file

@ -231,13 +231,16 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
| Docs | When to Use |
| --- | --- |
| [Quick Start](#quick-start) | call 100+ LLMs + Load Balancing |
| [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend |
| [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend (Note: when deploying with a database, `DATABASE_URL` and `LITELLM_MASTER_KEY` are required in your env) |
| [LiteLLM container + Redis](#litellm-container--redis) | + load balance across multiple litellm containers |
| [LiteLLM Database container + PostgresDB + Redis](#litellm-database-container--postgresdb--redis) | + use Virtual Keys + Track Spend + load balance across multiple litellm containers |
## Deploy with Database
### Docker, Kubernetes, Helm Chart
Requirements:
- Need a postgres database (e.g. [Supabase](https://supabase.com/), [Neon](https://neon.tech/), etc). Set `DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname>` in your env
- Set a `LITELLM_MASTER_KEY`, this is your Proxy Admin key - you can use this to create other keys (🚨 must start with `sk-`)
<Tabs>
@ -246,12 +249,14 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
We maintain a [separate Dockerfile](https://github.com/BerriAI/litellm/pkgs/container/litellm-database) for reducing build time when running LiteLLM proxy with a connected Postgres Database
```shell
docker pull docker pull ghcr.io/berriai/litellm-database:main-latest
docker pull ghcr.io/berriai/litellm-database:main-latest
```
```shell
docker run \
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
-e LITELLM_MASTER_KEY=sk-1234 \
-e DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname> \
-e AZURE_API_KEY=d6*********** \
-e AZURE_API_BASE=https://openai-***********/ \
-p 4000:4000 \

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# ✨ Enterprise Features - Content Mod
# ✨ Enterprise Features - Content Mod, SSO
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
@ -12,16 +12,18 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
:::
Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ Content Moderation with LLM Guard
- ✅ Content Moderation with LlamaGuard
- ✅ Content Moderation with Google Text Moderations
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- ✅ Don't log/store specific requests (eg confidential LLM requests)
- ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (e.g. confidential LLM requests)
- ✅ Tracking Spend for Custom Tags
## Content Moderation
### Content Moderation with LLM Guard
@ -74,7 +76,7 @@ curl --location 'http://localhost:4000/key/generate' \
# Returns {..'key': 'my-new-key'}
```
**2. Test it!**
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
@ -87,6 +89,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
}'
```
#### Turn on/off per request
**1. Update config**
```yaml
litellm_settings:
callbacks: ["llmguard_moderations"]
llm_guard_mode: "request-specific"
```
**2. Create new key**
```bash
curl --location 'http://localhost:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"models": ["fake-openai-endpoint"],
}'
# Returns {..'key': 'my-new-key'}
```
**3. Test it!**
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
"metadata": {
"permissions": {
"enable_llm_guard_check": True # 👈 KEY CHANGE
},
}
}
)
print(response)
```
</TabItem>
<TabItem value="curl" label="Curl Request">
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
--data '{"model": "fake-openai-endpoint", "messages": [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"}
]
}'
```
</TabItem>
</Tabs>
### Content Moderation with LlamaGuard

View file

@ -1,4 +1,4 @@
# Load Balancing - Config Setup
# Multiple Instances
Load balance multiple instances of the same model
The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want to maximize throughput**
@ -10,75 +10,6 @@ For more details on routing strategies / params, see [Routing](../routing.md)
:::
## Quick Start - Load Balancing
### Step 1 - Set deployments on config
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 6
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
```
### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
### Step 3: Use proxy - Call a model group [Load Balancing]
Curl Command
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
### Usage - Call a specific model deployment
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "azure/gpt-turbo-small-ca",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
## Load Balancing using multiple litellm instances (Kubernetes, Auto Scaling)
LiteLLM Proxy supports sharing rpm/tpm across multiple litellm instances; pass `redis_host`, `redis_password` and `redis_port` to enable this. (LiteLLM will use Redis to track rpm/tpm usage.) A minimal sketch of the router settings is shown below.
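These Redis values are read from your environment (matching the env var names used elsewhere in these docs):
```yaml
router_settings:
  routing_strategy: usage-based-routing-v2
  redis_host: os.environ/REDIS_HOST
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
```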

View file

@ -9,9 +9,9 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Async Custom Callbacks](#custom-callback-class-async)
- [Async Custom Callback APIs](#custom-callback-apis-async)
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
- [Logging to Langfuse](#logging-proxy-inputoutput---langfuse)
- [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
- [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb)
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
@ -539,6 +539,36 @@ print(response)
</Tabs>
### Team based Logging to Langfuse
**Example:**
This config would send langfuse logs to 2 different langfuse projects, based on the team id
```yaml
litellm_settings:
default_team_settings:
- team_id: my-secret-project
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
- team_id: ishaans-secret-project
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_2 # Project 2
langfuse_secret: os.environ/LANGFUSE_SECRET_2 # Project 2
```
Now, when you [generate keys](./virtual_keys.md) for this team-id
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"team_id": "ishaans-secret-project"}'
```
All requests made with these keys will log data to their team-specific logging.
## Logging Proxy Input/Output - DataDog
We will use the `--config` to set `litellm.success_callback = ["datadog"]`. This will log all successful LLM calls to DataDog

View file

@ -16,7 +16,7 @@ Expected Performance in Production
| `/chat/completions` Requests/hour | `126K` |
## 1. Switch of Debug Logging
## 1. Switch off Debug Logging
Remove `set_verbose: True` from your config.yaml
```yaml
@ -40,7 +40,7 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
```
## 2. Batch write spend updates every 60s
## 3. Batch write spend updates every 60s
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
@ -49,11 +49,35 @@ In production, we recommend using a longer interval period of 60s. This reduces
```yaml
general_settings:
master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
```
## 4. Use Redis 'port', 'host', 'password'. NOT 'redis_url'
## 3. Move spend logs to separate server
When connecting to Redis, use the redis port, host, and password params. Not 'redis_url'. We've seen an 80 RPS difference between these 2 approaches when using the async redis client.
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
Recommended to do this for prod:
```yaml
router_settings:
routing_strategy: usage-based-routing-v2
# redis_url: "os.environ/REDIS_URL"
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
redis_password: os.environ/REDIS_PASSWORD
```
## 5. Switch off resetting budgets
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
```yaml
general_settings:
disable_reset_budget: true
```
## 6. Move spend logs to separate server (BETA)
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time by moving spend log writes to a separate server.
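On the main proxy, the relevant setting is sketched below (assuming your separate service then handles spend log writes):
```yaml
general_settings:
  master_key: sk-1234
  disable_spend_logs: true  # stop writing each transaction to the main proxy DB
```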
@ -141,24 +165,6 @@ A t2.micro should be sufficient to handle 1k logs / minute on this server.
This consumes at max 120MB, and <0.1 vCPU.
## 4. Switch off resetting budgets
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
```yaml
general_settings:
disable_spend_logs: true
disable_reset_budget: true
```
## 5. Switch of `litellm.telemetry`
Switch of all telemetry tracking done by litellm
```yaml
litellm_settings:
telemetry: False
```
## Machine Specifications to Deploy LiteLLM
| Service | Spec | CPUs | Memory | Architecture | Version|

View file

@ -14,6 +14,7 @@ model_list:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
```
Start the proxy
@ -48,6 +49,26 @@ http://localhost:4000/metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model"` |
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model"` |
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model"` |
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model", "team", "end-user"` |
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model", "team", "end-user"` |
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do:
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
service_callback: ["prometheus_system"]
```
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_redis_latency` | histogram latency for redis calls |
| `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful litellm api call |

View file

@ -348,6 +348,29 @@ query_result = embeddings.embed_query(text)
print(f"TITAN EMBEDDINGS")
print(query_result[:5])
```
</TabItem>
<TabItem value="litellm" label="LiteLLM SDK">
This is **not recommended**. There is duplicate logic as the proxy also uses the sdk, which might lead to unexpected errors.
```python
from litellm import completion
response = completion(
model="openai/gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
api_key="anything",
base_url="http://0.0.0.0:4000"
)
print(response)
```
</TabItem>
</Tabs>
@ -438,7 +461,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
),
```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem>
<TabItem value="aider" label="Aider">
@ -551,4 +574,3 @@ No Logs
```shell
export LITELLM_LOG=None
```

View file

@ -2,7 +2,9 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Fallbacks, Retries, Timeouts, Cooldowns
# 🔥 Fallbacks, Retries, Timeouts, Load Balancing
Retry call with multiple instances of the same model.
If a call fails after num_retries, fall back to another model group.
@ -10,6 +12,77 @@ If the error is a context window exceeded error, fall back to a larger model gro
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
## Quick Start - Load Balancing
### Step 1 - Set deployments on config
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 6
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
```
### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
### Step 3: Use proxy - Call a model group [Load Balancing]
Curl Command
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
### Usage - Call a specific model deployment
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "azure/gpt-turbo-small-ca",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
## Fallbacks + Retries + Timeouts + Cooldowns
**Set via config**
```yaml
model_list:
@ -63,7 +136,158 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
'
```
## Custom Timeouts, Stream Timeouts - Per Model
### Test it!
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
"model": "zephyr-beta", # 👈 MODEL NAME to fallback from
"messages": [
{"role": "user", "content": "what color is red"}
],
"mock_testing_fallbacks": true
}'
```
## Advanced - Context Window Fallbacks
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
[**See Code**](https://github.com/BerriAI/litellm/blob/c9e6b05cfb20dfb17272218e2555d6b496c47f6f/litellm/router.py#L2163)
**1. Setup config**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with azure/.
<Tabs>
<TabItem value="same-group" label="Same Group">
Filter older instances of a model (e.g. gpt-3.5-turbo) with smaller context windows
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
Fallback to larger models if current model is too small.
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo-small
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo-large
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-opus
litellm_params:
model: claude-3-opus-20240229
api_key: os.environ/ANTHROPIC_API_KEY
litellm_settings:
context_window_fallbacks: [{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}]
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
</TabItem>
</Tabs>
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
@ -92,7 +316,7 @@ $ litellm --config /path/to/config.yaml
```
## Setting Dynamic Timeouts - Per Request
## Advanced - Setting Dynamic Timeouts - Per Request
LiteLLM Proxy supports setting a `timeout` per request
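A minimal sketch (assuming the proxy accepts a `timeout` field, in seconds, in the request body):
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "what llm are you"}],
    "timeout": 300
}'
```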

View file

@ -99,7 +99,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-D '{"team_id": "ishaans-secret-project"}'
-d '{"team_id": "ishaans-secret-project"}'
```
All requests made with these keys will log data to their team-specific logging.

View file

@ -9,6 +9,7 @@ Use JWT's to auth admins / projects into the proxy.
This is a new feature, and subject to changes based on feedback.
*UPDATE*: This will be moving to the [enterprise tier](./enterprise.md), once it's out of beta (~by end of April).
:::
## Usage
@ -107,6 +108,34 @@ general_settings:
litellm_jwtauth:
admin_jwt_scope: "litellm-proxy-admin"
```
## Advanced - Spend Tracking (User / Team / Org)
Set the field in the jwt token, which corresponds to a litellm user / team / org.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
admin_jwt_scope: "litellm-proxy-admin"
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
```
Expected JWT:
```
{
"client_id": "my-unique-team",
"sub": "my-unique-user",
"org_id": "my-unique-org"
}
```
Now litellm will automatically update the spend for the user/team/org in the db for each call.
### JWT Scopes
Here's what scopes on JWT-Auth tokens look like
@ -149,7 +178,7 @@ general_settings:
enable_jwt_auth: True
litellm_jwtauth:
...
team_jwt_scope: "litellm-team" # 👈 Set JWT Scope string
team_id_jwt_field: "litellm-team" # 👈 Set field in the JWT token that stores the team ID
team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes
```

View file

@ -56,6 +56,9 @@ On accessing the LiteLLM UI, you will be prompted to enter your username, passwo
## ✨ Enterprise Features
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
### Setup SSO/Auth for UI
#### Step 1: Set upperbounds for keys

View file

@ -121,6 +121,9 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",

View file

@ -435,7 +435,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
),
```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem>
<TabItem value="aider" label="Aider">
@ -815,4 +815,3 @@ Thread Stats Avg Stdev Max +/- Stdev
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai

View file

@ -95,12 +95,129 @@ print(response)
- `router.image_generation()` - image generation calls in OpenAI `/v1/images/generations` endpoint format
- `router.aimage_generation()` - async image generation calls
### Advanced
## Advanced - Routing Strategies
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
Router provides 4 strategies for routing your calls across multiple deployments:
<Tabs>
<TabItem value="usage-based-v2" label="Rate-Limit Aware v2 (ASYNC)">
**🎉 NEW** This is an async implementation of usage-based-routing.
**Filters out deployment if tpm/rpm limit exceeded** - If you pass in the deployment's tpm/rpm limits.
Routes to **deployment with lowest TPM usage** for that minute.
In production, we use Redis to track usage (TPM/RPM) across multiple deployments. This implementation uses **async redis calls** (redis.incr and redis.mget).
For Azure, your RPM = TPM/6 (e.g. a deployment provisioned with 60,000 TPM has an effective limit of 10,000 RPM).
<Tabs>
<TabItem value="sdk" label="sdk">
```python
import os
from litellm import Router
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 100000,
"rpm": 10000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 100000,
"rpm": 1000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 1000,
}]
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}]
print(response)
```
</TabItem>
<TabItem value="proxy" label="proxy">
**1. Set strategy in config**
```yaml
model_list:
- model_name: gpt-3.5-turbo # model alias
litellm_params: # params for litellm completion/embedding call
model: azure/chatgpt-v-2 # actual model name
api_key: os.environ/AZURE_API_KEY
api_version: os.environ/AZURE_API_VERSION
api_base: os.environ/AZURE_API_BASE
tpm: 100000
rpm: 10000
- model_name: gpt-3.5-turbo
litellm_params: # params for litellm completion/embedding call
model: gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
tpm: 100000
rpm: 1000
router_settings:
routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
redis_host: <your-redis-host>
redis_password: <your-redis-password>
redis_port: <your-redis-port>
enable_pre_call_checks: true
general_settings:
master_key: sk-1234
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
```
**3. Test it!**
```bash
curl --location 'http://localhost:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="latency-based" label="Latency-Based">
@ -117,7 +234,10 @@ import asyncio
model_list = [{ ... }]
# init router
router = Router(model_list=model_list, routing_strategy="latency-based-routing") # 👈 set routing strategy
router = Router(model_list=model_list,
routing_strategy="latency-based-routing",# 👈 set routing strategy
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
## CALL 1+2
tasks = []
@ -159,7 +279,7 @@ router_settings:
```
</TabItem>
<TabItem value="simple-shuffle" label="(Default) Weighted Pick">
<TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">
**Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)**
@ -257,8 +377,9 @@ router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
routing_strategy="usage-based-routing")
routing_strategy="usage-based-routing"
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}]
@ -555,7 +676,11 @@ router = Router(model_list: Optional[list] = None,
## Pre-Call Checks (Context Window)
Enable pre-call checks to filter out deployments with context window limit < messages for a call.
Enable pre-call checks to filter out:
1. deployments with context window limit < messages for a call.
2. deployments that have exceeded rate limits when making concurrent calls. (e.g. `asyncio.gather(*[
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
])`)
<Tabs>
<TabItem value="sdk" label="SDK">
@ -567,10 +692,14 @@ from litellm import Router
router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set to True
```
**2. (Azure-only) Set base model**
**2. Set Model List**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
<Tabs>
<TabItem value="same-group" label="Same Group">
```python
model_list = [
{
@ -582,7 +711,7 @@ model_list = [
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 SET BASE MODEL
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
@ -593,8 +722,51 @@ model_list = [
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
```python
model_list = [
{
"model_name": "gpt-3.5-turbo-small", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
"model_name": "gpt-3.5-turbo-large", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-opus",
"litellm_params": { call
"model": "claude-3-opus-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
```
</TabItem>
</Tabs>
**3. Test it!**
```python
@ -646,60 +818,9 @@ print(f"response: {response}")
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Setup config**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with azure/.
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
:::info
Go [here](./proxy/reliability.md#advanced---context-window-fallbacks) for how to do this on the proxy
:::
</TabItem>
</Tabs>

View file

@ -310,7 +310,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
),
```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem>
<TabItem value="aider" label="Aider">
@ -1351,5 +1351,3 @@ LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw Open
```shell
litellm --telemetry False
```

View file

@ -105,6 +105,12 @@ const config = {
label: 'Enterprise',
to: "docs/enterprise"
},
{
sidebarId: 'tutorialSidebar',
position: 'left',
label: '🚀 Hosted',
to: "docs/hosted"
},
{
href: 'https://github.com/BerriAI/litellm',
label: 'GitHub',

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 298 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 398 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 496 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 348 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 460 KiB

View file

@ -31,24 +31,26 @@ const sidebars = {
"proxy/quick_start",
"proxy/deploy",
"proxy/prod",
"proxy/configs",
{
type: "link",
label: "📖 All Endpoints",
label: "📖 All Endpoints (Swagger)",
href: "https://litellm-api.up.railway.app/",
},
"proxy/enterprise",
"proxy/user_keys",
"proxy/virtual_keys",
"proxy/demo",
"proxy/configs",
"proxy/reliability",
"proxy/users",
"proxy/user_keys",
"proxy/enterprise",
"proxy/virtual_keys",
"proxy/team_based_routing",
"proxy/ui",
"proxy/cost_tracking",
"proxy/token_auth",
{
type: "category",
label: "🔥 Load Balancing",
items: ["proxy/load_balancing", "proxy/reliability"],
label: "Extra Load Balancing",
items: ["proxy/load_balancing"],
},
"proxy/model_management",
"proxy/health",
@ -61,7 +63,7 @@ const sidebars = {
label: "Logging, Alerting",
items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"],
},
"proxy/grafana_metrics",
"proxy/prometheus",
"proxy/call_hooks",
"proxy/rules",
"proxy/cli",
@ -84,6 +86,7 @@ const sidebars = {
"completion/stream",
"completion/message_trimming",
"completion/function_call",
"completion/vision",
"completion/model_alias",
"completion/batching",
"completion/mock_requests",
@ -113,6 +116,7 @@ const sidebars = {
},
items: [
"providers/openai",
"providers/text_completion_openai",
"providers/openai_compatible",
"providers/azure",
"providers/azure_ai",
@ -162,7 +166,6 @@ const sidebars = {
"debugging/local_debugging",
"observability/callbacks",
"observability/custom_callback",
"observability/lunary_integration",
"observability/langfuse_integration",
"observability/sentry",
"observability/promptlayer_integration",
@ -171,6 +174,8 @@ const sidebars = {
"observability/slack_integration",
"observability/traceloop_integration",
"observability/athina_integration",
"observability/lunary_integration",
"observability/athina_integration",
"observability/helicone_integration",
"observability/supabase_integration",
`observability/telemetry`,

View file

@ -16,7 +16,7 @@ However, we also expose 3 public helper functions to calculate token usage acros
```python
from litellm import token_counter
messages = [{"user": "role", "content": "Hey, how's it going"}]
messages = [{"role": "user", "content": "Hey, how's it going"}]
print(token_counter(model="gpt-3.5-turbo", messages=messages))
```
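The other two helpers follow the same pattern. A minimal sketch, assuming the `cost_per_token` and `completion_cost` signatures exposed from the top-level package (illustrative calls, not part of the diff above):

```python
from litellm import completion, completion_cost, cost_per_token

# cost for an explicit token count - returns (prompt_cost_usd, completion_cost_usd)
prompt_cost, completion_cost_usd = cost_per_token(
    model="gpt-3.5-turbo", prompt_tokens=10, completion_tokens=20
)

# cost computed directly from a completion response
response = completion(
    model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going"}]
)
print(prompt_cost, completion_cost_usd, completion_cost(completion_response=response))
```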

View file

@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
traceback.print_exc()
raise e
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
if self.llm_guard_mode == "key-specific":
# check if llm guard enabled for specific keys only
self.print_verbose(
@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
return True
elif self.llm_guard_mode == "all":
return True
elif self.llm_guard_mode == "request-specific":
self.print_verbose(f"received metadata: {data.get('metadata', {})}")
metadata = data.get("metadata", {})
permissions = metadata.get("permissions", {})
if (
"enable_llm_guard_check" in permissions
and permissions["enable_llm_guard_check"] == True
):
return True
return False
async def async_moderation_hook(
@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
)
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
if _proceed == False:
return
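For the new `request-specific` mode a caller opts in per request through `metadata.permissions`. A minimal sketch of such a request against the proxy — the key names come from `should_proceed()` above; passing metadata via the OpenAI client's `extra_body` is an assumption about the transport:

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={
        # read by _ENTERPRISE_LLMGuard.should_proceed() when llm_guard_mode == "request-specific"
        "metadata": {"permissions": {"enable_llm_guard_check": True}}
    },
)
```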

View file

@ -1,5 +1,6 @@
# Enterprise Proxy Util Endpoints
from litellm._logging import verbose_logger
import collections
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
@ -17,6 +18,48 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
return response
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
response = await prisma_client.db.query_raw(
"""
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s
WHERE s."startTime" >= current_date - interval '30 days'
GROUP BY individual_request_tag, spend_date
ORDER BY spend_date;
"""
)
# print("tags - spend")
# print(response)
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
total_spend_per_tag = collections.defaultdict(float)
total_requests_per_tag = collections.defaultdict(int)
for row in response:
tag_name = row["individual_request_tag"]
tag_spend = row["total_spend"]
total_spend_per_tag[tag_name] += tag_spend
total_requests_per_tag[tag_name] += row["log_count"]
sorted_tags = sorted(total_spend_per_tag.items(), key=lambda x: x[1], reverse=True)
# convert to ui format
ui_tags = []
for tag in sorted_tags:
ui_tags.append(
{
"name": tag[0],
"value": tag[1],
"log_count": total_requests_per_tag[tag[0]],
}
)
return {"top_10_tags": ui_tags}
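To make the aggregation concrete, a small worked example of what `ui_get_spend_by_tags` produces for a handful of rows (values are illustrative):

```python
# rows shaped like the SQL result above
rows = [
    {"individual_request_tag": "prod", "spend_date": "2024-04-01", "log_count": 3, "total_spend": 1.25},
    {"individual_request_tag": "prod", "spend_date": "2024-04-02", "log_count": 1, "total_spend": 0.25},
    {"individual_request_tag": "dev", "spend_date": "2024-04-01", "log_count": 5, "total_spend": 0.40},
]

# after summing per tag and sorting by spend, the endpoint responds with:
# {"top_10_tags": [
#     {"name": "prod", "value": 1.50, "log_count": 4},
#     {"name": "dev", "value": 0.40, "log_count": 5},
# ]}
```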
async def view_spend_logs_from_clickhouse(
api_key=None, user_id=None, request_id=None, start_date=None, end_date=None
):

View file

@ -6,7 +6,7 @@
"": {
"dependencies": {
"@hono/node-server": "^1.9.0",
"hono": "^4.1.5"
"hono": "^4.2.7"
},
"devDependencies": {
"@types/node": "^20.11.17",
@ -463,9 +463,9 @@
}
},
"node_modules/hono": {
"version": "4.1.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.1.5.tgz",
"integrity": "sha512-3ChJiIoeCxvkt6vnkxJagplrt1YZg3NyNob7ssVeK2PUqEINp4q1F94HzFnvY9QE8asVmbW5kkTDlyWylfg2vg==",
"version": "4.2.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
"engines": {
"node": ">=16.0.0"
}

View file

@ -4,7 +4,7 @@
},
"dependencies": {
"@hono/node-server": "^1.9.0",
"hono": "^4.1.5"
"hono": "^4.2.7"
},
"devDependencies": {
"@types/node": "^20.11.17",

View file

@ -3,7 +3,11 @@ import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
)
import httpx
import dotenv
@ -12,10 +16,24 @@ dotenv.load_dotenv()
if set_verbose == True:
_turn_on_debug()
#############################################
### Callbacks /Logging / Success / Failure Handlers ###
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_langfuse_default_tags: Optional[
List[
Literal[
"user_api_key_alias",
"user_api_key_user_id",
"user_api_key_user_email",
"user_api_key_team_alias",
"semantic-similarity",
"proxy_base_url",
]
]
] = None
_async_input_callback: List[Callable] = (
[]
) # internal variable - async custom callbacks are routed here.
@ -27,6 +45,8 @@ _async_failure_callback: List[Callable] = (
) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
## end of callbacks #############
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
@ -46,6 +66,7 @@ replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None
@ -56,6 +77,7 @@ baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
use_client: bool = False
ssl_verify: bool = True
disable_streaming_logging: bool = False
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
@ -64,7 +86,7 @@ google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific"] = "all"
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
##################
logging: bool = True
caching: bool = (
@ -76,6 +98,8 @@ caching_with_models: bool = (
cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
@ -170,7 +194,7 @@ dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
generic_logger_headers: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
default_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
@ -258,6 +282,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
mistral_chat_models: List = []
anthropic_models: List = []
openrouter_models: List = []
vertex_language_models: List = []
@ -267,6 +292,7 @@ vertex_code_chat_models: List = []
vertex_text_models: List = []
vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@ -282,6 +308,8 @@ for key, value in model_cost.items():
cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key)
elif value.get("litellm_provider") == "mistral":
mistral_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key)
elif value.get("litellm_provider") == "openrouter":
@ -300,6 +328,9 @@ for key, value in model_cost.items():
vertex_code_chat_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-embedding-models":
vertex_embedding_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
key = key.replace("vertex_ai/", "")
vertex_anthropic_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":
@ -346,7 +377,7 @@ replicate_models: List = [
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
# Flan T-5
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f"
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
# Others
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
@ -447,6 +478,7 @@ model_list = (
+ deepinfra_models
+ perplexity_models
+ maritalk_models
+ vertex_language_models
)
provider_list: List = [
@ -568,6 +600,7 @@ from .utils import (
completion_cost,
supports_function_calling,
supports_parallel_function_calling,
supports_vision,
get_litellm_params,
Logging,
acreate,
@ -585,6 +618,7 @@ from .utils import (
_should_retry,
get_secret,
get_supported_openai_params,
get_api_base,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
@ -600,6 +634,7 @@ from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
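The new module-level settings introduced in this hunk can be set directly on the package. A short sketch using only names added here (values are illustrative):

```python
import litellm

litellm.ssl_verify = False                    # new flag; disabling presumably skips TLS certificate verification on outbound calls
litellm.default_in_memory_ttl = 60            # fallback TTL (seconds) for the in-memory cache layer
litellm.default_redis_ttl = 600               # fallback TTL (seconds) for the Redis cache layer
litellm.llm_guard_mode = "request-specific"   # now accepts "all", "key-specific", or "request-specific"
litellm.ollama_key = "sk-ollama-placeholder"  # new provider key slot added above
```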

View file

@ -32,6 +32,25 @@ def _get_redis_kwargs():
return available_args
def _get_redis_url_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@ -91,27 +110,39 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None)
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop(
"connection_pool", None
) # redis.from_url doesn't support setting your own connection pool
return redis.Redis.from_url(**redis_kwargs)
args = _get_redis_url_kwargs()
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop(
"connection_pool", None
) # redis.from_url doesn't support setting your own connection pool
return async_redis.Redis.from_url(**redis_kwargs)
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
litellm.print_verbose(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
)
return async_redis.Redis.from_url(**url_kwargs)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
@ -124,4 +155,9 @@ def get_redis_connection_pool(**env_overrides):
return async_redis.BlockingConnectionPool.from_url(
timeout=5, url=redis_kwargs["url"]
)
connection_class = async_redis.Connection
if "ssl" in redis_kwargs and redis_kwargs["ssl"] is not None:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
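A minimal usage sketch for the helpers above — either a single `REDIS_URL` or individual connection kwargs (also mappable from `REDIS_*` env vars) can be supplied, and arguments not accepted by `redis.Redis.from_url` are filtered out. Values are placeholders:

```python
import os
from litellm._redis import get_redis_client, get_redis_async_client

# option 1: a single connection URL picked up from the environment
os.environ["REDIS_URL"] = "redis://:mypassword@localhost:6379/0"
sync_client = get_redis_client()
async_client = get_redis_async_client()

# option 2 (alternative): drop REDIS_URL and pass host/port/password explicitly,
# e.g. get_redis_client(host="localhost", port=6379, password="mypassword")
```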

130
litellm/_service_logger.py Normal file
View file

@ -0,0 +1,130 @@
import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
def __init__(self, mock_testing: bool = False) -> None:
self.mock_testing = mock_testing
self.mock_testing_sync_success_hook = 0
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
def service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_success_hook += 1
def service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
- For counting if the redis, postgres call is successful
"""
if self.mock_testing:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.prometheusServicesLogger.async_service_success_hook(
payload=payload
)
async def async_service_failure_hook(
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
):
"""
- For counting if the redis, postgres call is unsuccessful
"""
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
if self.prometheusServicesLogger is None:
self.prometheusServicesLogger = self.prometheusServicesLogger()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
)
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Hook to track failed litellm-service calls
"""
return await super().async_post_call_failure_hook(
original_exception, user_api_key_dict
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
)
except Exception as e:
raise e
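A small sketch of how a caller uses `ServiceLogging` to time a Redis (or Postgres) operation — this mirrors the pattern `RedisCache` adopts later in this diff:

```python
import time
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceTypes

service_logger = ServiceLogging()

async def timed_redis_call():
    start_time = time.time()
    try:
        ...  # the actual redis call goes here
        await service_logger.async_service_success_hook(
            service=ServiceTypes.REDIS,
            duration=time.time() - start_time,
            call_type="example_call",
        )
    except Exception as e:
        await service_logger.async_service_failure_hook(
            service=ServiceTypes.REDIS,
            duration=time.time() - start_time,
            error=e,
            call_type="example_call",
        )
        raise
```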

View file

@ -13,6 +13,7 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback
@ -81,9 +82,37 @@ class InMemoryCache(BaseCache):
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
@ -109,6 +138,8 @@ class RedisCache(BaseCache):
**kwargs,
):
from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis
redis_kwargs = {}
if host is not None:
@ -118,10 +149,19 @@ class RedisCache(BaseCache):
if password is not None:
redis_kwargs["password"] = password
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
):
self.service_logger_obj = kwargs.pop("service_logger_obj")
else:
self.service_logger_obj = ServiceLogging()
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
# redis namespaces
self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis
@ -133,6 +173,16 @@ class RedisCache(BaseCache):
except Exception as e:
pass
### ASYNC HEALTH PING ###
try:
# asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping())
except Exception:
pass
### SYNC HEALTH PING ###
self.redis_client.ping()
def init_async_client(self):
from ._redis import get_redis_async_client
@ -163,18 +213,101 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
def increment_cache(self, key, value: int, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="increment_cache",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="increment_cache",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
start_time = time.time()
try:
keys = []
_redis_client = self.init_async_client()
async with _redis_client as redis_client:
async for key in redis_client.scan_iter(match=pattern + "*", count=count):
async for key in redis_client.scan_iter(
match=pattern + "*", count=count
):
keys.append(key)
if len(keys) >= count:
break
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
)
) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
)
)
raise e
async def async_set_cache(self, key, value, **kwargs):
start_time = time.time()
try:
_redis_client = self.init_async_client()
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None)
@ -186,7 +319,26 @@ class RedisCache(BaseCache):
print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache",
)
)
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
@ -200,6 +352,11 @@ class RedisCache(BaseCache):
Use Redis Pipelines for bulk write operations
"""
_redis_client = self.init_async_client()
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@ -219,8 +376,30 @@ class RedisCache(BaseCache):
print_verbose(f"pipeline results: {results}")
# Optionally, you could process 'results' to make sure that all set operations were successful.
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_pipeline",
)
)
return results
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_pipeline",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
str(e),
@ -235,7 +414,44 @@ class RedisCache(BaseCache):
key = self.check_and_fix_namespace(key=key)
self.redis_batch_writing_buffer.append((key, value))
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
await self.flush_cache_buffer() # logging done in here
async def async_increment(self, key, value: int, **kwargs) -> int:
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
result = await redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_increment",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def flush_cache_buffer(self):
print_verbose(
@ -274,40 +490,17 @@ class RedisCache(BaseCache):
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
try:
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_get_cache_pipeline(self, key_list) -> dict:
def batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
# Queue the get operations in the pipeline for all keys.
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
pipe.get(cache_key) # Queue GET command in pipeline
# Execute the pipeline and await the results.
results = await pipe.execute()
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -323,21 +516,185 @@ class RedisCache(BaseCache):
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def ping(self):
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
start_time = time.time()
async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache")
try:
response = await redis_client.ping()
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_get_cache",
)
)
return response
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_get_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()
try:
async with _redis_client as redis_client:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = await redis_client.mget(keys=_keys)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
)
)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {}
for k, v in key_value_dict.items():
if isinstance(k, bytes):
k = k.decode("utf-8")
v = self._get_cache_logic(v)
decoded_results[k] = v
return decoded_results
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
def sync_ping(self) -> bool:
"""
Tests if the sync redis client is correctly setup.
"""
print_verbose(f"Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="sync_ping",
)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="sync_ping",
)
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def ping(self) -> bool:
_redis_client = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache")
try:
response = await redis_client.ping()
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_ping",
)
)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_ping",
)
)
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def delete_cache_keys(self, keys):
_redis_client = self.init_async_client()
# keys is a list, unpack it so it gets passed as individual elements to delete
async with _redis_client as redis_client:
await redis_client.delete(*keys)
def client_list(self):
client_list = self.redis_client.client_list()
return client_list
def info(self):
info = self.redis_client.info()
return info
def flush_cache(self):
self.redis_client.flushall()
@ -828,8 +1185,10 @@ class DualCache(BaseCache):
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
self.default_in_memory_ttl = default_in_memory_ttl
self.default_redis_ttl = default_redis_ttl
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache
@ -846,6 +1205,30 @@ class DualCache(BaseCache):
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@ -872,6 +1255,39 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@ -905,7 +1321,50 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
if value is not None:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
@ -916,6 +1375,32 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only == False:
result = await self.redis_cache.async_increment(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
@ -939,6 +1424,8 @@ class Cache:
password: Optional[str] = None,
namespace: Optional[str] = None,
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[
List[
@ -1038,6 +1525,14 @@ class Cache:
self.redis_flush_size = redis_flush_size
self.ttl = ttl
if self.type == "local" and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == "redis" or self.type == "redis-semantic"
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
@ -1379,6 +1874,11 @@ class Cache:
return await self.cache.ping()
return None
async def delete_cache_keys(self, keys):
if hasattr(self.cache, "delete_cache_keys"):
return await self.cache.delete_cache_keys(keys)
return None
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
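A short sketch of the new `DualCache` batch and increment helpers (in-memory first, Redis as fallback, misses back-filled into memory). Constructor arguments and key names are illustrative:

```python
import asyncio
from litellm.caching import DualCache, InMemoryCache, RedisCache

cache = DualCache(
    in_memory_cache=InMemoryCache(),
    redis_cache=RedisCache(host="localhost", port=6379, password="mypassword"),
)

async def demo():
    # counter kept in both layers
    rpm = await cache.async_increment_cache(key="rpm:my-deployment", value=1)

    # read-through batch get: in-memory hits first, Redis fills the gaps
    values = await cache.async_batch_get_cache(keys=["rpm:my-deployment", "tpm:my-deployment"])
    print(rpm, values)

asyncio.run(demo())
```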

View file

@ -82,14 +82,18 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 408
self.message = message
self.model = model
self.llm_provider = llm_provider
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(
request=request
) # Call the base class constructor with the parameters it needs
self.status_code = 408
self.message = message
self.model = model
self.llm_provider = llm_provider
# custom function to convert to str
def __str__(self):
return str(self.message)
class PermissionDeniedError(PermissionDeniedError): # type:ignore

View file

@ -6,7 +6,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### PRE-CALL CHECKS - router/proxy only ####
"""
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
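A minimal sketch of a `CustomLogger` subclass using the new pre-call check hook; only the method signature comes from the class above — the rejection condition is hypothetical:

```python
from typing import Optional
from litellm.integrations.custom_logger import CustomLogger

class BlockDrainedDeployments(CustomLogger):
    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
        # hypothetical guard: refuse deployments explicitly marked as drained
        if deployment.get("model_info", {}).get("drained", False):
            raise ValueError("deployment is drained, pick another one")
        return deployment
```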
#### CALL HOOKS - proxy only ####
"""
Control / modify the incoming and outgoing data before calling the model

View file

@ -0,0 +1,51 @@
import requests
import json
import traceback
from datetime import datetime, timezone
class GreenscaleLogger:
def __init__(self):
import os
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
self.headers = {
"api-key": self.greenscale_api_key,
"Content-Type": "application/json"
}
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try:
response_json = response_obj.model_dump() if response_obj else {}
data = {
"modelId": kwargs.get("model"),
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
}
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
if type(end_time) == datetime and type(start_time) == datetime:
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
# Add additional metadata keys to tags
tags = []
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
for key, value in metadata.items():
if key.startswith("greenscale"):
if key == "greenscale_project":
data["project"] = value
elif key == "greenscale_application":
data["application"] = value
else:
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
data["tags"] = tags
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
if response.status_code != 200:
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
else:
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
except Exception as e:
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
pass
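A quick sketch of exercising the new logger directly. The environment variable names come from the class above; the module path, endpoint value, and metadata keys are assumptions/placeholders:

```python
import os
from datetime import datetime, timezone
from litellm.integrations.greenscale import GreenscaleLogger  # assumed module path

os.environ["GREENSCALE_API_KEY"] = "gs-placeholder"
os.environ["GREENSCALE_ENDPOINT"] = "https://api.greenscale.example/ingest"

logger = GreenscaleLogger()
logger.log_event(
    kwargs={
        "model": "gpt-3.5-turbo",
        "litellm_params": {"metadata": {"greenscale_project": "demo", "greenscale_application": "chatbot"}},
    },
    response_obj=None,  # normally a litellm ModelResponse; None keeps the sketch self-contained
    start_time=datetime.now(timezone.utc),
    end_time=datetime.now(timezone.utc),
    print_verbose=print,
)
```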

View file

@ -17,7 +17,7 @@ class LangFuseLogger:
from langfuse import Langfuse
except Exception as e:
raise Exception(
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
)
# Instance variables
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
@ -34,6 +34,14 @@ class LangFuseLogger:
flush_interval=1, # flush interval in seconds
)
# set the current langfuse project id in the environ
# this is used by Alerting to link to the correct project
try:
project_id = self.Langfuse.client.projects.get().data[0].id
os.environ["LANGFUSE_PROJECT_ID"] = project_id
except:
project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
self.upstream_langfuse_secret_key = os.getenv(
"UPSTREAM_LANGFUSE_SECRET_KEY"
@ -76,6 +84,7 @@ class LangFuseLogger:
print_verbose(
f"Langfuse Logging - Enters logging function for model {kwargs}"
)
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
@ -118,6 +127,11 @@ class LangFuseLogger:
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
@ -128,6 +142,7 @@ class LangFuseLogger:
self._log_langfuse_v2(
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
@ -156,7 +171,7 @@ class LangFuseLogger:
verbose_logger.info(f"Langfuse Layer Logging - logging success")
except:
traceback.print_exc()
print(f"Langfuse Layer Error - {traceback.format_exc()}")
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
pass
async def _async_log_event(
@ -185,7 +200,7 @@ class LangFuseLogger:
):
from langfuse.model import CreateTrace, CreateGeneration
print(
verbose_logger.warning(
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
)
@ -219,6 +234,7 @@ class LangFuseLogger:
self,
user_id,
metadata,
litellm_params,
output,
start_time,
end_time,
@ -273,13 +289,13 @@ class LangFuseLogger:
clean_metadata = {}
if isinstance(metadata, dict):
for key, value in metadata.items():
# generate langfuse tags
if key in [
"user_api_key",
"user_api_key_user_id",
"user_api_key_team_id",
"semantic-similarity",
]:
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
and key in litellm._langfuse_default_tags
):
tags.append(f"{key}:{value}")
# clean litellm metadata before logging
@ -293,13 +309,55 @@ class LangFuseLogger:
else:
clean_metadata[key] = value
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
and "proxy_base_url" in litellm._langfuse_default_tags
):
proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
if proxy_base_url is not None:
tags.append(f"proxy_base_url:{proxy_base_url}")
api_base = litellm_params.get("api_base", None)
if api_base:
clean_metadata["api_base"] = api_base
vertex_location = kwargs.get("vertex_location", None)
if vertex_location:
clean_metadata["vertex_location"] = vertex_location
aws_region_name = kwargs.get("aws_region_name", None)
if aws_region_name:
clean_metadata["aws_region_name"] = aws_region_name
if supports_tags:
if "cache_hit" in kwargs:
if kwargs["cache_hit"] is None:
kwargs["cache_hit"] = False
tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"]
trace_params.update({"tags": tags})
proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request:
method = proxy_server_request.get("method", None)
url = proxy_server_request.get("url", None)
headers = proxy_server_request.get("headers", None)
clean_headers = {}
if headers:
for key, value in headers.items():
# these headers can leak our API keys and/or JWT tokens
if key.lower() not in ["authorization", "cookie", "referer"]:
clean_headers[key] = value
clean_metadata["request"] = {
"method": method,
"url": url,
"headers": clean_headers,
}
print_verbose(f"trace_params: {trace_params}")
trace = self.Langfuse.trace(**trace_params)
generation_id = None
@ -316,13 +374,21 @@ class LangFuseLogger:
# just log `litellm-{call_type}` as the generation name
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if response_obj is not None and "system_fingerprint" in response_obj:
system_fingerprint = response_obj.get("system_fingerprint", None)
else:
system_fingerprint = None
if system_fingerprint is not None:
optional_params["system_fingerprint"] = system_fingerprint
generation_params = {
"name": generation_name,
"id": metadata.get("generation_id", generation_id),
"startTime": start_time,
"endTime": end_time,
"start_time": start_time,
"end_time": end_time,
"model": kwargs["model"],
"modelParameters": optional_params,
"model_parameters": optional_params,
"input": input,
"output": output,
"usage": usage,
@ -334,13 +400,15 @@ class LangFuseLogger:
generation_params["prompt"] = metadata.get("prompt", None)
if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["statusMessage"] = output
generation_params["status_message"] = output
if supports_completion_start_time:
generation_params["completion_start_time"] = kwargs.get(
"completion_start_time", None
)
print_verbose(f"generation_params: {generation_params}")
trace.generation(**generation_params)
except Exception as e:
print(f"Langfuse Layer Error - {traceback.format_exc()}")
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
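The tag behaviour above is now driven by `litellm._langfuse_default_tags`. A short opt-in sketch — the list values are the literals allowed by the new type in `litellm/__init__.py`:

```python
import litellm

# only metadata keys named here are turned into Langfuse tags;
# "proxy_base_url" additionally emits a proxy_base_url:<PROXY_BASE_URL> tag when that env var is set
litellm._langfuse_default_tags = [
    "user_api_key_alias",
    "user_api_key_user_id",
    "user_api_key_team_alias",
    "proxy_base_url",
]
litellm.success_callback = ["langfuse"]
```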

View file

@ -7,6 +7,19 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import asyncio
import types
from pydantic import BaseModel
def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
types.FunctionType,
types.GeneratorType,
BaseModel,
)
return not isinstance(value, non_serializable_types)
class LangsmithLogger:
@ -21,7 +34,9 @@ class LangsmithLogger:
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
metadata = kwargs.get('litellm_params', {}).get("metadata", {}) or {} # if metadata is None
metadata = (
kwargs.get("litellm_params", {}).get("metadata", {}) or {}
) # if metadata is None
# set project name and run_name for langsmith logging
# users can pass project_name and run name to litellm.completion()
@ -51,24 +66,46 @@ class LangsmithLogger:
new_kwargs = {}
for key in kwargs:
value = kwargs[key]
if key == "start_time" or key == "end_time":
if key == "start_time" or key == "end_time" or value is None:
pass
elif type(value) != dict:
elif isinstance(value, datetime):
new_kwargs[key] = value.isoformat()
elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value
requests.post(
"https://api.smith.langchain.com/runs",
json={
print(f"type of response: {type(response_obj)}")
for k, v in new_kwargs.items():
print(f"key={k}, type of arg: {type(v)}, value={v}")
if isinstance(response_obj, BaseModel):
try:
response_obj = response_obj.model_dump()
except:
response_obj = response_obj.dict() # type: ignore
print(f"response_obj: {response_obj}")
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {**new_kwargs},
"outputs": response_obj.json(),
"inputs": new_kwargs,
"outputs": response_obj,
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
},
}
print(f"data: {data}")
response = requests.post(
"https://api.smith.langchain.com/runs",
json=data,
headers={"x-api-key": self.langsmith_api_key},
)
if response.status_code >= 300:
print_verbose(f"Error: {response.status_code}")
else:
print_verbose("Run successfully created")
print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}"
)

View file

@ -4,11 +4,13 @@ from datetime import datetime, timezone
import traceback
import dotenv
import importlib
from pkg_resources import parse_version
import sys
import packaging
dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx}
def parse_usage(usage):
return {
@ -16,6 +18,7 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
}
def parse_messages(input):
if input is None:
return None
@ -28,7 +31,6 @@ def parse_messages(input):
if "message" in message:
return clean_message(message["message"])
serialized = {
"role": message.get("role"),
"content": message.get("content"),
@ -56,10 +58,13 @@ class LunaryLogger:
def __init__(self):
try:
import lunary
version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError
if parse_version(version) < parse_version("0.1.43"):
print("Lunary version outdated. Required: > 0.1.43. Upgrade via 'pip install lunary --upgrade'")
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print(
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
raise ImportError
self.lunary_client = lunary
@ -88,9 +93,7 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
)
metadata = litellm_params.get("metadata", {}) or {}
tags = litellm_params.pop("tags", None) or []
@ -148,7 +151,7 @@ class LunaryLogger:
runtime="litellm",
error=error_obj,
output=parse_messages(output),
token_usage=usage
token_usage=usage,
)
except:

View file

@ -1,6 +1,6 @@
# used for /metrics endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Supabase
# On success, log events to Prometheus
import dotenv, os
import requests
@ -19,27 +19,33 @@ class PrometheusLogger:
**kwargs,
):
try:
verbose_logger.debug(f"in init prometheus metrics")
print(f"in init prometheus metrics")
from prometheus_client import Counter
self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
self.litellm_requests_metric = Counter(
name="litellm_requests_metric",
documentation="Total number of LLM calls to litellm",
labelnames=["user", "key", "model"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
# Counter for spend
self.litellm_spend_metric = Counter(
"litellm_spend_metric",
"Total spend on LLM requests",
labelnames=["user", "key", "model"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
# Counter for total_output_tokens
self.litellm_tokens_metric = Counter(
"litellm_total_tokens",
"Total number of input + output tokens from LLM requests",
labelnames=["user", "key", "model"],
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
@ -61,24 +67,50 @@ class PrometheusLogger:
# unpack kwargs
model = kwargs.get("model", "")
response_cost = kwargs.get("response_cost", 0.0)
response_cost = kwargs.get("response_cost", 0.0) or 0
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params.get("metadata", {}).get(
"user_api_key_user_id", None
)
user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
user_api_team = litellm_params.get("metadata", {}).get(
"user_api_key_team_id", None
)
if response_obj is not None:
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
else:
tokens_used = 0
print_verbose(
f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
)
self.litellm_requests_metric.labels(end_user_id, user_api_key, model).inc()
self.litellm_spend_metric.labels(end_user_id, user_api_key, model).inc(
response_cost
)
self.litellm_tokens_metric.labels(end_user_id, user_api_key, model).inc(
tokens_used
)
if (
user_api_key is not None
and isinstance(user_api_key, str)
and user_api_key.startswith("sk-")
):
from litellm.proxy.utils import hash_token
user_api_key = hash_token(user_api_key)
self.litellm_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
self.litellm_spend_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc(response_cost)
self.litellm_tokens_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc(tokens_used)
### FAILURE INCREMENT ###
if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
except Exception as e:
traceback.print_exc()
verbose_logger.debug(

View file

@ -0,0 +1,198 @@
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
class PrometheusServicesLogger:
# Class variables or attributes
litellm_service_latency = None # Class-level attribute to store the Histogram
def __init__(
self,
mock_testing: bool = False,
**kwargs,
):
try:
try:
from prometheus_client import Counter, Histogram, REGISTRY
except ImportError:
raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`"
)
self.Histogram = Histogram
self.Counter = Counter
self.REGISTRY = REGISTRY
verbose_logger.debug(f"in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes]
self.payload_to_prometheus_map = (
{}
) # store the prometheus histogram/counter we need to call for each field in payload
for service in self.services:
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service, type_of_request="failed_requests"
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
)
self.payload_to_prometheus_map[service] = [
histogram,
counter_failed_request,
counter_total_requests,
]
self.prometheus_to_amount_map: dict = (
{}
) # the field / value in ServiceLoggerPayload the object needs to be incremented by
### MOCK TESTING ###
self.mock_testing = mock_testing
self.mock_testing_success_calls = 0
self.mock_testing_failure_calls = 0
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
def is_metric_registered(self, metric_name) -> bool:
for metric in self.REGISTRY.collect():
if metric_name == metric.name:
return True
return False
def get_metric(self, metric_name):
for metric in self.REGISTRY.collect():
for sample in metric.samples:
if metric_name == sample.name:
return metric
return None
def create_histogram(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Histogram(
metric_name,
"Latency for {} service".format(service),
labelnames=[service],
)
def create_counter(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Counter(
metric_name,
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
)
def observe_histogram(
self,
histogram,
labels: str,
amount: float,
):
assert isinstance(histogram, self.Histogram)
histogram.labels(labels).observe(amount)
def increment_counter(
self,
counter,
labels: str,
amount: float,
):
assert isinstance(counter, self.Counter)
counter.labels(labels).inc(amount)
def service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_success_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Histogram):
self.observe_histogram(
histogram=obj,
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
def service_failure_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_failure_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_success_hook(self, payload: ServiceLoggerPayload):
"""
Log successful call to prometheus
"""
if self.mock_testing:
self.mock_testing_success_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Histogram):
self.observe_histogram(
histogram=obj,
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")
if self.mock_testing:
self.mock_testing_failure_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG ERROR COUNT TO PROMETHEUS
)
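A minimal usage sketch for the class above (not part of the diff; the ServiceLoggerPayload field names and the ServiceTypes.REDIS member are assumptions based on the imports shown):

from litellm.types.services import ServiceLoggerPayload, ServiceTypes

logger = PrometheusServicesLogger(mock_testing=True)

# hypothetical payload describing a fast redis cache write
payload = ServiceLoggerPayload(
    is_error=False,
    error=None,
    service=ServiceTypes.REDIS,
    duration=0.004,
    call_type="async_set_cache",
)

# records latency on the histogram and bumps the total_requests counter
logger.service_success_hook(payload)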

View file

@ -0,0 +1,486 @@
#### What this does ####
# Class for sending Slack Alerts #
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import copy
import traceback
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm
from typing import List, Literal, Any, Union, Optional, Dict
from litellm.caching import DualCache
import asyncio
import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class SlackAlerting:
# Class variables or attributes
def __init__(
self,
alerting_threshold: float = 300,
alerting: Optional[List] = [],
alert_types: Optional[
List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
]
] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
alert_to_webhook_url: Optional[
Dict
] = None, # if user wants to separate alerts to diff channels
):
self.alerting_threshold = alerting_threshold
self.alerting = alerting
self.alert_types = alert_types
self.internal_usage_cache = DualCache()
self.async_http_handler = AsyncHTTPHandler()
self.alert_to_webhook_url = alert_to_webhook_url
pass
def update_values(
self,
alerting: Optional[List] = None,
alerting_threshold: Optional[float] = None,
alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None,
):
if alerting is not None:
self.alerting = alerting
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
self.alert_types = alert_types
if alert_to_webhook_url is not None:
# update the dict
if self.alert_to_webhook_url is None:
self.alert_to_webhook_url = alert_to_webhook_url
else:
self.alert_to_webhook_url.update(alert_to_webhook_url)
async def deployment_in_cooldown(self):
pass
async def deployment_removed_from_cooldown(self):
pass
def _all_possible_alert_types(self):
# used by the UI to show all supported alert types
# Note: this is not the list of alerts the user has configured; it's all possible alert types a user can select
return [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
def _add_langfuse_trace_id_to_alert(
self,
request_info: str,
request_data: Optional[dict] = None,
kwargs: Optional[dict] = None,
):
import uuid
# For now: do nothing as we're debugging why this is not working as expected
return request_info
# if request_data is not None:
# trace_id = request_data.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# request_data["metadata"]["trace_id"] = trace_id
# elif kwargs is not None:
# _litellm_params = kwargs.get("litellm_params", {})
# trace_id = _litellm_params.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# _litellm_params["metadata"]["trace_id"] = trace_id
# _langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
# _langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
# # langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
# _langfuse_url = (
# f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
# )
# request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
# return request_info
def _response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
start_time,
end_time, # start/end time
):
try:
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
model = kwargs.get("model", "")
api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
messages = kwargs.get("messages", None)
# if messages does not exist fallback to "input"
if messages is None:
messages = kwargs.get("input", None)
# only use first 100 chars for alerting
_messages = str(messages)[:100]
return time_difference_float, model, api_base, _messages
except Exception as e:
raise e
def _get_deployment_latencies_to_alert(self, metadata=None):
if metadata is None:
return None
if "_latency_per_deployment" in metadata:
# Translate model_id to -> api_base
# _latency_per_deployment is a dictionary that looks like this:
"""
_latency_per_deployment: {
api_base: 0.01336697916666667
}
"""
_message_to_send = ""
_deployment_latencies = metadata["_latency_per_deployment"]
if len(_deployment_latencies) == 0:
return None
for api_base, latency in _deployment_latencies.items():
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
_message_to_send = "```" + _message_to_send + "```"
return _message_to_send
async def response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
if self.alerting is None or self.alert_types is None:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
)
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, kwargs=kwargs
)
# add deployment latencies to alert
if (
kwargs is not None
and "litellm_params" in kwargs
and "metadata" in kwargs["litellm_params"]
):
_metadata = kwargs["litellm_params"]["metadata"]
_deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=_metadata
)
if _deployment_latency_map is not None:
request_info += (
f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
)
await self.send_alert(
message=slow_message + request_info,
level="Low",
alert_type="llm_too_slow",
)
async def log_failure_event(self, original_exception: Exception):
pass
async def response_taking_too_long(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request",
request_data: Optional[dict] = None,
):
if self.alerting is None or self.alert_types is None:
return
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", None)
if messages is None:
# if messages does not exist fallback to "input"
messages = request_data.get("input", None)
# try casting messages to str and take the first 100 characters; on failure, fall back to an empty string
try:
messages = str(messages)
messages = messages[:100]
except:
messages = ""
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, request_data=request_data
)
else:
request_info = ""
if type == "hanging_request":
await asyncio.sleep(
self.alerting_threshold
) # Set it to 5 minutes - I'd imagine this might be different for streaming, non-streaming, and non-completion (embedding + img) requests
if (
request_data is not None
and request_data.get("litellm_status", "") != "success"
and request_data.get("litellm_status", "") != "fail"
):
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
_api_base = litellm.get_api_base(
model=model,
optional_params=request_data["deployment"].get(
"litellm_params", {}
),
)
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: {_api_base}"
elif request_data.get("metadata", None) is not None and isinstance(
request_data["metadata"], dict
):
# In hanging requests, sometimes the call has not made it to the point where the deployment is passed to `request_data`
# in that case we fallback to the api base set in the request metadata
_metadata = request_data["metadata"]
_api_base = _metadata.get("api_base", "")
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: `{_api_base}`"
# only alert hanging responses if they have not been marked as success
alerting_message = (
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
)
# add deployment latencies to alert
_deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=request_data.get("metadata", {})
)
if _deployment_latency_map is not None:
request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
await self.send_alert(
message=alerting_message + request_info,
level="Medium",
alert_type="llm_requests_hanging",
)
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
"user_and_proxy_budget",
"failed_budgets",
"failed_tracking",
"projected_limit_exceeded",
],
user_max_budget: float,
user_current_spend: float,
user_info=None,
error_message="",
):
if self.alerting is None or self.alert_types is None:
# do nothing if alerting is not switched on
return
if "budget_alerts" not in self.alert_types:
return
_id: str = "default_id" # used for caching
if type == "user_and_proxy_budget":
user_info = dict(user_info)
user_id = user_info["user_id"]
_id = user_id
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
_id = token
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
elif type == "failed_tracking":
user_id = str(user_info)
_id = user_id
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
message = "Failed Tracking Cost for" + user_info
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
elif type == "projected_limit_exceeded" and user_info is not None:
"""
Input variables:
user_info = {
"key_alias": key_alias,
"projected_spend": projected_spend,
"projected_exceeded_date": projected_exceeded_date,
}
user_max_budget=soft_limit,
user_current_spend=new_spend
"""
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
else:
user_info = str(user_info)
# percent of max_budget left to spend
if user_max_budget > 0:
percent_left = (user_max_budget - user_current_spend) / user_max_budget
else:
percent_left = 0
verbose_proxy_logger.debug(
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
)
## PREVENTATIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 28d period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
# check if crossed budget
if user_current_spend >= user_max_budget:
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
message = "Budget Crossed for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
# check if 5% of max budget is left
if percent_left <= 0.05:
message = "5% budget left for" + user_info
cache_key = "alerting:{}".format(_id)
result = await _cache.async_get_cache(key=cache_key)
if result is None:
await self.send_alert(
message=message, level="Medium", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
return
# check if 15% of max budget is left
if percent_left <= 0.15:
message = "15% budget left for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="Low", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
return
async def send_alert(
self,
message: str,
level: Literal["Low", "Medium", "High"],
alert_type: Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
):
"""
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
- Responses taking too long
- Requests are hanging
- Calls are failing
- DB Read/Writes are failing
- Proxy Close to max budget
- Key Close to max budget
Parameters:
level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
message: str - what is the alert about
"""
if self.alerting is None:
return
from datetime import datetime
import json
# Get the current timestamp
current_time = datetime.now().strftime("%H:%M:%S")
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
formatted_message = (
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
)
if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
# check if we find the slack webhook url in self.alert_to_webhook_url
if (
self.alert_to_webhook_url is not None
and alert_type in self.alert_to_webhook_url
):
slack_webhook_url = self.alert_to_webhook_url[alert_type]
else:
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
if slack_webhook_url is None:
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
payload = {"text": formatted_message}
headers = {"Content-type": "application/json"}
response = await self.async_http_handler.post(
url=slack_webhook_url,
headers=headers,
data=json.dumps(payload),
)
if response.status_code == 200:
pass
else:
print("Error sending slack alert. Error=", response.text) # noqa

View file

@ -298,7 +298,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -2,18 +2,13 @@ import os, types
import json
from enum import Enum
import requests, copy
import time, uuid
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx
@ -21,6 +16,8 @@ class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
class AnthropicError(Exception):
def __init__(self, status_code, message):
@ -37,12 +34,14 @@ class AnthropicError(Exception):
class AnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
Reference: https://docs.anthropic.com/claude/reference/messages_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@ -52,7 +51,9 @@ class AnthropicConfig:
def __init__(
self,
max_tokens: Optional[int] = 256, # anthropic requires a default
max_tokens: Optional[
int
] = 4096, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
@ -101,7 +102,222 @@ def validate_environment(api_key, user_headers):
return headers
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
print_verbose,
):
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
elif len(completion_response["content"]) == 0:
raise AnthropicError(
message="No content in response",
status_code=response.status_code,
)
else:
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
return model_response
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data), stream=True
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
@ -112,13 +328,13 @@ def completion(
api_key,
logging_obj,
optional_params=None,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
@ -162,17 +378,15 @@ def completion(
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
@ -193,11 +407,55 @@ def completion(
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes async anthropic streaming POST request")
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
if (
stream is not None and stream == True and _is_function_call == False
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
@ -211,136 +469,39 @@ def completion(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
completion_stream = response.iter_lines()
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streaming_response
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
elif len(completion_response["content"]) == 0:
raise AnthropicError(
message="No content in response",
status_code=response.status_code,
)
else:
text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
text_content # allow user to access raw anthropic tool calling response
)
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
return self.process_response(
model=model,
custom_llm_provider="cached_response",
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass
class ModelResponseIterator:
@ -367,8 +528,3 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
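The tool-calling change above replaces the old XML system-prompt workaround with Anthropic's native tools API by renaming each function's `parameters` key to `input_schema`; a small standalone sketch of that transformation (the example tool is made up):

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# same rename performed in completion(): Anthropic expects `input_schema`
anthropic_tool = dict(openai_tool["function"])
anthropic_tool["input_schema"] = anthropic_tool.pop("parameters")

print(anthropic_tool["name"], list(anthropic_tool.keys()))
# get_weather ['name', 'description', 'input_schema']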

View file

@ -4,10 +4,12 @@ from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
from .base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
class AnthropicConstants(Enum):
@ -94,10 +96,125 @@ def validate_environment(api_key, user_headers):
return headers
class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
):
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def async_completion(
self,
model: str,
model_response: ModelResponse,
api_base: str,
logging_obj,
encoding,
headers: dict,
data: dict,
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key=headers.get("x-api-key"),
original_response=response.text,
additional_args={"complete_input_dict": data},
)
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
async def async_streaming(
self,
model: str,
api_base: str,
logging_obj,
headers: dict,
data: Optional[dict],
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return streamwrapper
def completion(
self,
model: str,
messages: list,
api_base: str,
acompletion: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
@ -108,6 +225,7 @@ def completion(
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
@ -151,21 +269,53 @@ def completion(
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
if acompletion == True:
return self.async_streaming(
model=model,
api_base=api_base,
logging_obj=logging_obj,
headers=headers,
data=data,
client=None,
)
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
# stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return stream_response
elif acompletion == True:
return self.async_completion(
model=model,
model_response=model_response,
api_base=api_base,
logging_obj=logging_obj,
encoding=encoding,
headers=headers,
data=data,
client=client,
)
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
@ -179,44 +329,16 @@ def completion(
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
model_response.usage = usage
return model_response
return response
def embedding():
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -799,6 +799,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
@ -817,8 +818,6 @@ class AzureChatCompletion(BaseLLM):
"timeout": timeout,
}
max_retries = optional_params.pop("max_retries", None)
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)

View file

@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
@ -15,11 +16,11 @@ import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt
openai_text_completion = OpenAITextCompletion()
openai_text_completion_config = OpenAITextCompletionConfig()
class AzureOpenAIError(Exception):
@ -300,10 +301,12 @@ class AzureTextCompletion(BaseLLM):
"api_base": api_base,
},
)
return openai_text_completion.convert_to_model_response_object(
response_object=stringified_response,
return (
openai_text_completion_config.convert_to_chat_model_response_object(
response_object=TextCompletionResponse(**stringified_response),
model_response_object=model_response,
)
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
@ -373,7 +376,7 @@ class AzureTextCompletion(BaseLLM):
},
)
response = await azure_client.completions.create(**data, timeout=timeout)
return openai_text_completion.convert_to_model_response_object(
return openai_text_completion_config.convert_to_chat_model_response_object(
response_object=response.model_dump(),
model_response_object=model_response,
)

View file

@ -55,9 +55,11 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True
"stream": (
True
if "stream" in optional_params and optional_params["stream"] == True
else False,
else False
),
}
## LOGGING
@ -71,9 +73,11 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True
stream=(
True
if "stream" in optional_params and optional_params["stream"] == True
else False,
else False
),
)
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
@ -102,28 +106,28 @@ def completion(
and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]["data"][0]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]["data"][0]
)
elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]
model_response["choices"][0]["message"]["content"] = (
completion_response["model_output"]
)
elif "completion" in completion_response and isinstance(
completion_response["completion"], str
):
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code,
)
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS
if (
"details" in completion_response[0]
@ -155,7 +159,8 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -653,6 +653,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
for message in messages:
@ -746,7 +750,7 @@ def completion(
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
@ -1008,7 +1012,7 @@ def completion(
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator(
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
@ -1028,7 +1032,7 @@ def completion(
total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"],
)
model_response.usage = _usage
setattr(model_response, "usage", _usage)
else:
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
@ -1071,8 +1075,10 @@ def completion(
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
if getattr(model_response.usage, "total_tokens", None) is None:
## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
if not hasattr(model_response, "usage"):
setattr(model_response, "usage", Usage())
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
@ -1089,7 +1095,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
model_response["created"] = int(time.time())
model_response["model"] = model
@ -1109,8 +1115,30 @@ def completion(
raise BedrockError(status_code=500, message=traceback.format_exc())
async def model_response_iterator(model_response):
yield model_response
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response
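ModelResponseIterator simply replays one already-complete response through the streaming interfaces; a quick illustration (the string below stands in for a real ModelResponse):

import asyncio

# sync: yields the wrapped response exactly once
for chunk in ModelResponseIterator(model_response="<complete response>"):
    print(chunk)

# async: same single-yield behaviour
async def consume():
    async for chunk in ModelResponseIterator(model_response="<complete response>"):
        print(chunk)

asyncio.run(consume())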
def _embedding_func_single(

View file

@ -167,7 +167,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -237,7 +237,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -302,5 +305,5 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -0,0 +1,96 @@
import httpx, asyncio
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
async def close(self):
# Close the client when you're done with it
await self.client.aclose()
async def __aenter__(self):
return self.client
async def __aexit__(self, exc_type, exc_value, traceback):
# close the client when exiting
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response
async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
return response
def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass
class HTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
def close(self):
# Close the client when you're done with it
self.client.close()
def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = self.client.get(url, params=params, headers=headers)
return response
def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = self.client.post(url, data=data, params=params, headers=headers)
return response
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
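A short usage sketch for the two handlers above (not part of the diff; the URLs and payload are placeholders):

import asyncio

# synchronous client
sync_client = HTTPHandler()
resp = sync_client.get("https://example.com/health")
print(resp.status_code)
sync_client.close()

# asynchronous client
async def main():
    client = AsyncHTTPHandler()
    resp = await client.post(
        "https://example.com/v1/complete",
        data='{"prompt": "hi"}',
        headers={"Content-Type": "application/json"},
    )
    print(resp.status_code)
    await client.close()

asyncio.run(main())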

View file

@ -6,7 +6,8 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
class GeminiError(Exception):
@ -103,6 +104,13 @@ class TextStreamer:
break
def supports_system_instruction():
import google.generativeai as genai
gemini_pkg_version = Version(genai.__version__)
return gemini_pkg_version >= Version("0.5.0")
def completion(
model: str,
messages: list,
@ -124,7 +132,7 @@ def completion(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key)
system_prompt = ""
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -135,6 +143,7 @@ def completion(
messages=messages,
)
else:
system_prompt, messages = get_system_prompt(messages=messages)
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
@ -162,11 +171,20 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
additional_args={
"complete_input_dict": {
"inference_params": inference_params,
"system_prompt": system_prompt,
}
},
)
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f"models/{model}")
_params = {"model_name": "models/{}".format(model)}
_system_instruction = supports_system_instruction()
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True:
if acompletion == True:
@ -213,11 +231,12 @@ def completion(
encoding=encoding,
)
else:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
)
params = {
"contents": prompt,
"generation_config": genai.types.GenerationConfig(**inference_params),
"safety_settings": safety_settings,
}
response = _model.generate_content(**params)
except Exception as e:
raise GeminiError(
message=str(e),
@ -292,7 +311,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
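A sketch of how the new system-instruction gating above composes (not part of the diff; the helper, the version string, and the model name are illustrative stand-ins, and get_system_prompt is assumed to split system messages out of the message list as the import implies):

from packaging.version import Version

def _supports_system_instruction(genai_version: str) -> bool:
    # system_instruction was added to google-generativeai in 0.5.0
    return Version(genai_version) >= Version("0.5.0")

system_prompt = "You are a terse assistant."  # e.g. split out by get_system_prompt(messages=messages)
_params = {"model_name": "models/gemini-pro"}

if _supports_system_instruction("0.5.2") and len(system_prompt) > 0:
    _params["system_instruction"] = system_prompt

print(_params)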

View file

@ -152,9 +152,9 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
model_response["choices"][0]["message"]["content"] = (
completion_response["answer"]
)
except Exception as e:
raise MaritalkError(
message=response.text, status_code=response.status_code
@ -174,7 +174,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -185,9 +185,9 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
model_response["choices"][0]["message"]["content"] = (
completion_response["generated_text"]
)
except:
raise NLPCloudError(
message=json.dumps(completion_response),
@ -205,7 +205,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
@ -228,8 +228,8 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@ -330,8 +330,8 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore
completion_tokens = response_json["eval_count"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
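A tiny sketch of the token-count fallbacks introduced above (values are made up): prompt tokens fall back to tokenizing the prompt, and completion tokens fall back to the character length of the returned message content.

response_json = {"prompt_eval_count": 12}  # no eval_count returned

prompt_tokens = response_json.get("prompt_eval_count", 0)  # the real code tokenizes the prompt instead of using 0
completion_tokens = response_json.get(
    "eval_count",
    len(response_json.get("message", dict()).get("content", "")),
)

print(prompt_tokens, completion_tokens)  # 12 0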

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaChatConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaChatConfig:
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
)
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "frequency_penalty":
optional_params["repeat_penalty"] = param
optional_params["repeat_penalty"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object":
@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
messages=None,
optional_params=None,
@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -244,6 +246,7 @@ def get_ollama_response(
else:
response = ollama_acompletion(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -252,12 +255,17 @@ def get_ollama_response(
)
return response
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
response = requests.post(**_request) # type: ignore
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
@ -307,10 +315,16 @@ def get_ollama_response(
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
def ollama_completion_stream(url, api_key, data, logging_obj):
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
with httpx.stream(**_request) as response:
try:
if response.status_code != 200:
raise OllamaError(
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
async with client.stream(**_request) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name
url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
resp = await session.post(**_request)
if resp.status != 200:
text = await resp.text()
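A minimal sketch of the optional Bearer-auth request pattern this file adds, assuming a plain requests call and a hypothetical helper name:

from typing import Optional

import requests

def post_to_ollama(url: str, data: dict, api_key: Optional[str] = None) -> dict:
    # Build kwargs so the Authorization header is only sent when a key is supplied
    # (e.g. when an authenticating proxy sits in front of the Ollama server).
    request_kwargs: dict = {"url": url, "json": data}
    if api_key is not None:
        # requests expects a header mapping here, not a bare token string
        request_kwargs["headers"] = {"Authorization": f"Bearer {api_key}"}
    response = requests.post(**request_kwargs)
    response.raise_for_status()
    return response.json()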

View file

@ -99,9 +99,9 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"][
"content"
] = completion_response["choices"][0]["message"]["content"]
model_response["choices"][0]["message"]["content"] = (
completion_response["choices"][0]["message"]["content"]
)
except:
raise OobaboogaError(
message=json.dumps(completion_response),
@ -115,7 +115,7 @@ def completion(
completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"],
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response
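The switch from plain attribute assignment to setattr keeps the behavior identical while avoiding mypy complaints about fields that are not declared on the pydantic response model; a tiny sketch, with the import path assumed:

from litellm.utils import ModelResponse, Usage  # assumed import path

model_response = ModelResponse()
usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
setattr(model_response, "usage", usage)  # equivalent to model_response.usage = usage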

View file

@ -10,6 +10,7 @@ from litellm.utils import (
convert_to_model_response_object,
Usage,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional
import aiohttp, requests
@ -200,6 +201,43 @@ class OpenAITextCompletionConfig:
and v is not None
}
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
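A rough usage sketch for the converter added above; the import path and constructor fields are assumptions, not taken from the diff:

from litellm.llms.openai import OpenAITextCompletionConfig  # assumed module path
from litellm.utils import ModelResponse, TextCompletionResponse

text_response = TextCompletionResponse(
    id="cmpl-123",  # placeholder values throughout
    model="gpt-3.5-turbo-instruct",
    choices=[{"text": "Hello!", "index": 0, "finish_reason": "stop", "logprobs": None}],
    usage={"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
)
chat_response = OpenAITextCompletionConfig().convert_to_chat_model_response_object(
    response_object=text_response,
    model_response_object=ModelResponse(),
)
# Each text choice becomes an assistant message on the chat-style response.
print(chat_response.choices[0].message.content)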
@ -785,10 +823,10 @@ class OpenAIChatCompletion(BaseLLM):
optional_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
atranscription: bool = False,
):
@ -962,40 +1000,6 @@ class OpenAITextCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def convert_to_model_response_object(
self,
response_object: Optional[dict] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["text"], role="assistant")
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
model_response_object.usage = response_object["usage"]
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
def completion(
self,
model_response: ModelResponse,
@ -1010,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
super().completion()
@ -1020,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
api_base = f"{api_base}/completions"
if (
len(messages) > 0
and "content" in messages[0]
@ -1029,12 +1033,12 @@ class OpenAITextCompletion(BaseLLM):
):
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
prompt = [message["content"] for message in messages] # type: ignore
# don't send max retries to the api, if set
optional_params.pop("max_retries", None)
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
@ -1050,38 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
response = httpx.post(
url=f"{api_base}", json=data, headers=headers, timeout=timeout
)
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
response = openai_client.completions.create(**data) # type: ignore
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
@ -1089,10 +1108,7 @@ class OpenAITextCompletion(BaseLLM):
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(
response_object=response.json(),
model_response_object=model_response,
)
return TextCompletionResponse(**response_json)
except Exception as e:
raise e
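The non-streaming path now goes through the OpenAI SDK client instead of a raw httpx POST; a minimal sketch of that call pattern, with placeholder key, base URL, and model:

from openai import OpenAI

openai_client = OpenAI(
    api_key="sk-...",                      # placeholder credentials
    base_url="https://api.openai.com/v1",  # placeholder base URL
    timeout=600.0,
    max_retries=2,
)
response = openai_client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="Say hello",
    max_tokens=16,
)
response_json = response.model_dump()  # plain dict, as logged and returned above
print(response_json["choices"][0]["text"])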
@ -1107,21 +1123,25 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str,
timeout: float,
max_retries=None,
organization: Optional[str] = None,
client=None,
):
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(
api_base,
json=data,
headers=headers,
timeout=litellm.request_timeout,
)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
response = await openai_aclient.completions.create(**data)
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
input=prompt,
@ -1132,76 +1152,83 @@ class OpenAITextCompletion(BaseLLM):
"api_base": api_base,
},
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(
response_object=response_json, model_response_object=model_response
)
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
raise e
def streaming(
self,
logging_obj,
api_base: str,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
with httpx.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
) as response:
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
response = openai_client.completions.create(**data)
streamwrapper = CustomStreamWrapper(
completion_stream=response.iter_lines(),
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
for chunk in streamwrapper:
yield chunk
async def async_streaming(
self,
logging_obj,
api_base: str,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
client = httpx.AsyncClient()
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
) as response:
try:
if response.status_code != 200:
raise OpenAIError(
status_code=response.status_code, message=response.text
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
response = await openai_client.completions.create(**data)
streamwrapper = CustomStreamWrapper(
completion_stream=response.aiter_lines(),
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e
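The streaming paths now hand the SDK's stream object straight to CustomStreamWrapper instead of iterating raw HTTP lines; a hedged async sketch of the SDK side alone, with a placeholder key and model:

import asyncio

from openai import AsyncOpenAI

async def stream_completion() -> None:
    client = AsyncOpenAI(api_key="sk-...")  # placeholder key
    stream = await client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt="Count to three",
        max_tokens=16,
        stream=True,
    )
    # Each chunk carries an incremental text delta for the completion.
    async for chunk in stream:
        print(chunk.choices[0].text, end="", flush=True)

asyncio.run(stream_completion())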

View file

@ -191,7 +191,7 @@ def completion(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response

Some files were not shown because too many files have changed in this diff.