Merge branch 'BerriAI:main' into fix-anthropic-messages-api

Emir Ayar 2024-04-27 11:50:04 +02:00 committed by GitHub
commit 38b5f34c77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
366 changed files with 73092 additions and 56717 deletions


@ -8,6 +8,11 @@ jobs:
 steps:
 - checkout
+- run:
+name: Show git commit hash
+command: |
+echo "Git commit hash: $CIRCLE_SHA1"
 - run:
 name: Check if litellm dir was updated or if pyproject.toml was modified
 command: |
@ -31,16 +36,17 @@ jobs:
 pip install "google-generativeai==0.3.2"
 pip install "google-cloud-aiplatform==1.43.0"
 pip install pyarrow
-pip install "boto3>=1.28.57"
+pip install "boto3==1.34.34"
-pip install "aioboto3>=12.3.0"
+pip install "aioboto3==12.3.0"
 pip install langchain
 pip install lunary==0.2.5
-pip install "langfuse>=2.0.0"
+pip install "langfuse==2.7.3"
 pip install numpydoc
 pip install traceloop-sdk==0.0.69
 pip install openai
 pip install prisma
 pip install "httpx==0.24.1"
+pip install fastapi
 pip install "gunicorn==21.2.0"
 pip install "anyio==3.7.1"
 pip install "aiodynamo==23.10.1"
@ -51,6 +57,7 @@ jobs:
 pip install "pytest-mock==3.12.0"
 pip install python-multipart
 pip install google-cloud-aiplatform
+pip install prometheus-client==0.20.0
 - save_cache:
 paths:
 - ./venv
@ -73,7 +80,7 @@ jobs:
 name: Linting Testing
 command: |
 cd litellm
-python -m pip install types-requests types-setuptools types-redis
+python -m pip install types-requests types-setuptools types-redis types-PyYAML
 if ! python -m mypy . --ignore-missing-imports; then
 echo "mypy detected errors"
 exit 1
@ -123,6 +130,7 @@ jobs:
 build_and_test:
 machine:
 image: ubuntu-2204:2023.10.1
+resource_class: xlarge
 working_directory: ~/project
 steps:
 - checkout
@ -182,12 +190,19 @@ jobs:
 -p 4000:4000 \
 -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
 -e AZURE_API_KEY=$AZURE_API_KEY \
+-e REDIS_HOST=$REDIS_HOST \
+-e REDIS_PASSWORD=$REDIS_PASSWORD \
+-e REDIS_PORT=$REDIS_PORT \
 -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
 -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
 -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
 -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
 -e OPENAI_API_KEY=$OPENAI_API_KEY \
+-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
+-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
+-e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \
+-e LANGFUSE_PROJECT2_SECRET=$LANGFUSE_PROJECT2_SECRET \
 --name my-app \
 -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
 my-app:latest \
@ -292,7 +307,7 @@ jobs:
 -H "Accept: application/vnd.github.v3+json" \
 -H "Authorization: Bearer $GITHUB_TOKEN" \
 "https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
--d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\"}}"
+-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
 workflows:
 version: 2


@ -3,11 +3,10 @@ openai
 python-dotenv
 tiktoken
 importlib_metadata
-baseten
 cohere
 redis
 anthropic
 orjson
-pydantic
+pydantic==1.10.14
 google-cloud-aiplatform==1.43.0
 redisvl==0.0.7 # semantic caching


@ -1,5 +1,5 @@
-/docs
+docs
-/cookbook
+cookbook
-/.circleci
+.circleci
-/.github
+.github
-/tests
+tests


@ -5,6 +5,13 @@ on:
 inputs:
 tag:
 description: "The tag version you want to build"
+release_type:
+description: "The release type you want to build. Can be 'latest', 'stable', 'dev'"
+type: string
+default: "latest"
+commit_hash:
+description: "Commit hash"
+required: true
 # Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
 env:
@ -85,9 +92,9 @@ jobs:
 - name: Build and push Docker image
 uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8
 with:
-context: .
+context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
 push: true
-tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-latest # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
+tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
 labels: ${{ steps.meta.outputs.labels }}
 platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -121,10 +128,10 @@ jobs:
 - name: Build and push Database Docker image
 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
 with:
-context: .
+context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
 file: Dockerfile.database
 push: true
-tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-latest
+tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
 labels: ${{ steps.meta-database.outputs.labels }}
 platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
@ -158,11 +165,10 @@ jobs:
 - name: Build and push Database Docker image
 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
 with:
-context: .
+context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
 file: ./litellm-js/spend-logs/Dockerfile
 push: true
-tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-latest
+tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
-labels: ${{ steps.meta-spend-logs.outputs.labels }}
 platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
 build-and-push-helm-chart:
@ -236,10 +242,13 @@ jobs:
 with:
 github-token: "${{ secrets.GITHUB_TOKEN }}"
 script: |
+const commitHash = "${{ github.event.inputs.commit_hash}}";
+console.log("Commit Hash:", commitHash); // Add this line for debugging
 try {
 const response = await github.rest.repos.createRelease({
 draft: false,
 generate_release_notes: true,
+target_commitish: commitHash,
 name: process.env.RELEASE_TAG,
 owner: context.repo.owner,
 prerelease: false,
@ -288,4 +297,3 @@ jobs:
 }
 ]
 }' $WEBHOOK_URL


@ -77,6 +77,9 @@ if __name__ == "__main__":
 new_release_body = (
 existing_release_body
 + "\n\n"
++ "### Don't want to maintain your internal proxy? get in touch 🎉"
++ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
++ "\n\n"
 + "## Load Test LiteLLM Proxy Results"
 + "\n\n"
 + markdown_table


@ -10,7 +10,7 @@ class MyUser(HttpUser):
 def chat_completion(self):
 headers = {
 "Content-Type": "application/json",
-"Authorization": f"Bearer sk-gUvTeN9g0sgHBMf9HeCaqA",
+"Authorization": f"Bearer sk-S2-EZTUUDY0EmM6-Fy0Fyw",
 # Include any additional headers you may need for authentication, etc.
 }

.gitignore

@ -45,3 +45,9 @@ deploy/charts/litellm/charts/*
 deploy/charts/*.tgz
 litellm/proxy/vertex_key.json
 **/.vim/
+/node_modules
+kub.yaml
+loadtest_kub.yaml
+litellm/proxy/_new_secret_config.yaml
+litellm/proxy/_new_secret_config.yaml
+litellm/proxy/_super_secret_config.yaml


@ -70,5 +70,4 @@ EXPOSE 4000/tcp
 ENTRYPOINT ["litellm"]
 # Append "--detailed_debug" to the end of CMD to view detailed debug logs
-# CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
-CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
+CMD ["--port", "4000"]


@ -5,7 +5,7 @@
 <p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, etc.]
 <br>
 </p>
-<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
+<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/hosted" target="_blank"> Hosted Proxy (Preview)</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
 <h4 align="center">
 <a href="https://pypi.org/project/litellm/" target="_blank">
 <img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
@ -32,9 +32,9 @@ LiteLLM manages:
 - Set Budgets & Rate limits per project, api key, model [OpenAI Proxy Server](https://docs.litellm.ai/docs/simple_proxy)
 [**Jump to OpenAI Proxy Docs**](https://github.com/BerriAI/litellm?tab=readme-ov-file#openai-proxy---docs) <br>
-[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-provider-docs)
+[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs)
-🚨 **Stable Release:** v1.34.1
+🚨 **Stable Release:** Use docker images with: `main-stable` tag. These run through 12 hr load tests (1k req./min).
 Support for more providers. Missing a provider or LLM Platform, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
@ -128,7 +128,9 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
 # OpenAI Proxy - ([Docs](https://docs.litellm.ai/docs/simple_proxy))
-Set Budgets & Rate limits across multiple projects
+Track spend + Load Balance across multiple projects
+[Hosted Proxy (Preview)](https://docs.litellm.ai/docs/hosted)
 The proxy provides:
@ -205,7 +207,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [aws - bedrock](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [google - vertex_ai [Gemini]](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
 | [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
 | [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | | ✅ | | |
 | [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [cloudflare AI Workers](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | ✅ |
 | [cohere](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | ✅ |
@ -220,7 +222,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | ✅ |
 | [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | ✅ |
 | [petals](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | ✅ |
-| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ |
+| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
 | [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
 | [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |

cookbook/Proxy_Batch_Users.ipynb (new file)

@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "680oRk1af-xJ"
},
"source": [
"# Environment Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X7TgJFn8f88p"
},
"outputs": [],
"source": [
"import csv\n",
"from typing import Optional\n",
"import httpx, json\n",
"import asyncio\n",
"\n",
"proxy_base_url = \"http://0.0.0.0:4000\" # 👈 SET TO PROXY URL\n",
"master_key = \"sk-1234\" # 👈 SET TO PROXY MASTER KEY"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rauw8EOhgBz5"
},
"outputs": [],
"source": [
"## GLOBAL HTTP CLIENT ## - faster http calls\n",
"class HTTPHandler:\n",
" def __init__(self, concurrent_limit=1000):\n",
" # Create a client with a connection pool\n",
" self.client = httpx.AsyncClient(\n",
" limits=httpx.Limits(\n",
" max_connections=concurrent_limit,\n",
" max_keepalive_connections=concurrent_limit,\n",
" )\n",
" )\n",
"\n",
" async def close(self):\n",
" # Close the client when you're done with it\n",
" await self.client.aclose()\n",
"\n",
" async def get(\n",
" self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None\n",
" ):\n",
" response = await self.client.get(url, params=params, headers=headers)\n",
" return response\n",
"\n",
" async def post(\n",
" self,\n",
" url: str,\n",
" data: Optional[dict] = None,\n",
" params: Optional[dict] = None,\n",
" headers: Optional[dict] = None,\n",
" ):\n",
" try:\n",
" response = await self.client.post(\n",
" url, data=data, params=params, headers=headers\n",
" )\n",
" return response\n",
" except Exception as e:\n",
" raise e\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7LXN8zaLgOie"
},
"source": [
"# Import Sheet\n",
"\n",
"\n",
"Format: | ID | Name | Max Budget |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oiED0usegPGf"
},
"outputs": [],
"source": [
"async def import_sheet():\n",
" tasks = []\n",
" http_client = HTTPHandler()\n",
" with open('my-batch-sheet.csv', 'r') as file:\n",
" csv_reader = csv.DictReader(file)\n",
" for row in csv_reader:\n",
" task = create_user(client=http_client, user_id=row['ID'], max_budget=row['Max Budget'], user_name=row['Name'])\n",
" tasks.append(task)\n",
" # print(f\"ID: {row['ID']}, Name: {row['Name']}, Max Budget: {row['Max Budget']}\")\n",
"\n",
" keys = await asyncio.gather(*tasks)\n",
"\n",
" with open('my-batch-sheet_new.csv', 'w', newline='') as new_file:\n",
" fieldnames = ['ID', 'Name', 'Max Budget', 'keys']\n",
" csv_writer = csv.DictWriter(new_file, fieldnames=fieldnames)\n",
" csv_writer.writeheader()\n",
"\n",
" with open('my-batch-sheet.csv', 'r') as file:\n",
" csv_reader = csv.DictReader(file)\n",
" for i, row in enumerate(csv_reader):\n",
" row['keys'] = keys[i] # Add the 'keys' value from the corresponding task result\n",
" csv_writer.writerow(row)\n",
"\n",
" await http_client.close()\n",
"\n",
"asyncio.run(import_sheet())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E7M0Li_UgJeZ"
},
"source": [
"# Create Users + Keys\n",
"\n",
"- Creates a user\n",
"- Creates a key with max budget"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NZudRFujf7j-"
},
"outputs": [],
"source": [
"\n",
"async def create_key_with_alias(client: HTTPHandler, user_id: str, max_budget: float):\n",
" global proxy_base_url\n",
" if not proxy_base_url.endswith(\"/\"):\n",
" proxy_base_url += \"/\"\n",
" url = proxy_base_url + \"key/generate\"\n",
"\n",
" # call /key/generate\n",
" print(\"CALLING /KEY/GENERATE\")\n",
" response = await client.post(\n",
" url=url,\n",
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
" data=json.dumps({\n",
" \"user_id\": user_id,\n",
" \"key_alias\": f\"{user_id}-key\",\n",
" \"max_budget\": max_budget # 👈 KEY CHANGE: SETS MAX BUDGET PER KEY\n",
" })\n",
" )\n",
" print(f\"response: {response.text}\")\n",
" return response.json()[\"key\"]\n",
"\n",
"async def create_user(client: HTTPHandler, user_id: str, max_budget: float, user_name: str):\n",
" \"\"\"\n",
" - call /user/new\n",
" - create key for user\n",
" \"\"\"\n",
" global proxy_base_url\n",
" if not proxy_base_url.endswith(\"/\"):\n",
" proxy_base_url += \"/\"\n",
" url = proxy_base_url + \"user/new\"\n",
"\n",
" # call /user/new\n",
" await client.post(\n",
" url=url,\n",
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
" data=json.dumps({\n",
" \"user_id\": user_id,\n",
" \"user_alias\": user_name,\n",
" \"auto_create_key\": False,\n",
" # \"max_budget\": max_budget # 👈 [OPTIONAL] Sets max budget per user (if you want to set a max budget across keys)\n",
" })\n",
" )\n",
"\n",
" # create key for user\n",
" return await create_key_with_alias(client=client, user_id=user_id, max_budget=max_budget)\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
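Once the notebook has written `my-batch-sheet_new.csv`, the generated keys can be smoke-tested against the proxy. This is a minimal sketch, not part of the notebook itself; it assumes the proxy from the notebook is still running at `proxy_base_url`, that the `keys` column was written as above, and that a model such as `gpt-3.5-turbo` is configured on the proxy.

```python
import csv
import httpx

proxy_base_url = "http://0.0.0.0:4000"  # same proxy the notebook targeted

# Read back the sheet the notebook produced and try one chat completion per generated key
with open("my-batch-sheet_new.csv", "r") as f:
    for row in csv.DictReader(f):
        resp = httpx.post(
            f"{proxy_base_url}/chat/completions",
            headers={"Authorization": f"Bearer {row['keys']}"},
            json={
                "model": "gpt-3.5-turbo",  # assumed to exist in the proxy's model_list
                "messages": [{"role": "user", "content": f"hello from {row['Name']}"}],
            },
        )
        print(row["ID"], resp.status_code)
```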


@ -87,6 +87,7 @@
 | command-light | cohere | 0.00003 |
 | command-medium-beta | cohere | 0.00003 |
 | command-xlarge-beta | cohere | 0.00003 |
+| command-r-plus| cohere | 0.000018 |
 | j2-ultra | ai21 | 0.00003 |
 | ai21.j2-ultra-v1 | bedrock | 0.0000376 |
 | gpt-4-1106-preview | openai | 0.00004 |

cookbook/misc/config.yaml (new file)

@ -0,0 +1,73 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
- model_name: gpt-3.5-turbo-large
litellm_params:
model: "gpt-3.5-turbo-1106"
api_key: os.environ/OPENAI_API_KEY
rpm: 480
timeout: 300
stream_timeout: 60
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
rpm: 480
timeout: 300
stream_timeout: 60
- model_name: sagemaker-completion-model
litellm_params:
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
input_cost_per_second: 0.000420
- model_name: text-embedding-ada-002
litellm_params:
model: azure/azure-embedding-model
api_key: os.environ/AZURE_API_KEY
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
model_info:
mode: embedding
base_model: text-embedding-ada-002
- model_name: dall-e-2
litellm_params:
model: azure/
api_version: 2023-06-01-preview
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY
- model_name: openai-dall-e-3
litellm_params:
model: dall-e-3
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
drop_params: True
# max_budget: 100
# budget_duration: 30d
num_retries: 5
request_timeout: 600
telemetry: False
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
store_model_in_db: True
proxy_budget_rescheduler_min_time: 60
proxy_budget_rescheduler_max_time: 64
proxy_batch_write_at: 1
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
# environment_variables:
# settings for using redis caching
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
# REDIS_PORT: "16337"
# REDIS_PASSWORD:
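For a quick end-to-end check of a config like this, the proxy can be started with it and exercised through the OpenAI client. A sketch, assuming the proxy was launched locally with `litellm --config config.yaml --port 4000` and the `master_key` shown in `general_settings` above:

```python
import openai

# Point the OpenAI client at the local LiteLLM proxy started with this config
client = openai.OpenAI(
    api_key="sk-1234",               # the master_key from general_settings above
    base_url="http://0.0.0.0:4000",
)

# "fake-openai-endpoint" is one of the model_name entries defined in this config
response = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```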


@ -0,0 +1,92 @@
"""
LiteLLM Migration Script!
Takes a config.yaml and calls /model/new
Inputs:
- File path to config.yaml
- Proxy base url to your hosted proxy
Step 1: Reads your config.yaml
Step 2: reads `model_list` and loops through all models
Step 3: calls `<proxy-base-url>/model/new` for each model
"""
import yaml
import requests
_in_memory_os_variables = {}
def migrate_models(config_file, proxy_base_url):
# Step 1: Read the config.yaml file
with open(config_file, "r") as f:
config = yaml.safe_load(f)
# Step 2: Read the model_list and loop through all models
model_list = config.get("model_list", [])
print("model_list: ", model_list)
for model in model_list:
model_name = model.get("model_name")
print("\nAdding model: ", model_name)
litellm_params = model.get("litellm_params", {})
api_base = litellm_params.get("api_base", "")
print("api_base on config.yaml: ", api_base)
litellm_model_name = litellm_params.get("model", "") or ""
if "vertex_ai/" in litellm_model_name:
print(f"\033[91m\nSkipping Vertex AI model\033[0m", model)
continue
for param, value in litellm_params.items():
if isinstance(value, str) and value.startswith("os.environ/"):
# check if value is in _in_memory_os_variables
if value in _in_memory_os_variables:
new_value = _in_memory_os_variables[value]
print(
"\033[92mAlready entered value for \033[0m",
value,
"\033[92musing \033[0m",
new_value,
)
else:
new_value = input(f"Enter value for {value}: ")
_in_memory_os_variables[value] = new_value
litellm_params[param] = new_value
print("\nlitellm_params: ", litellm_params)
# Confirm before sending POST request
confirm = input(
"\033[92mDo you want to send the POST request with the above parameters? (y/n): \033[0m"
)
if confirm.lower() != "y":
print("Aborting POST request.")
exit()
# Step 3: Call <proxy-base-url>/model/new for each model
url = f"{proxy_base_url}/model/new"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {master_key}",
}
data = {"model_name": model_name, "litellm_params": litellm_params}
print("POSTING data to proxy url", url)
response = requests.post(url, headers=headers, json=data)
if response.status_code != 200:
print(f"Error: {response.status_code} - {response.text}")
raise Exception(f"Error: {response.status_code} - {response.text}")
# Print the response for each model
print(
f"Response for model '{model_name}': Status Code:{response.status_code} - {response.text}"
)
# Usage
config_file = "config.yaml"
proxy_base_url = "http://0.0.0.0:4000"
master_key = "sk-1234"
print(f"config_file: {config_file}")
print(f"proxy_base_url: {proxy_base_url}")
migrate_models(config_file, proxy_base_url)


@ -15,15 +15,16 @@ spec:
 containers:
 - name: litellm-container
 image: ghcr.io/berriai/litellm:main-latest
+imagePullPolicy: Always
 env:
 - name: AZURE_API_KEY
 value: "d6f****"
 - name: AZURE_API_BASE
-value: "https://openai
+value: "https://openai"
 - name: LITELLM_MASTER_KEY
 value: "sk-1234"
 - name: DATABASE_URL
-value: "postgresql://ishaan:*********""
+value: "postgresql://ishaan*********"
 args:
 - "--config"
 - "/app/proxy_config.yaml" # Update the path to mount the config file


@ -1,10 +1,16 @@
 version: "3.9"
 services:
 litellm:
+build:
+context: .
+args:
+target: runtime
 image: ghcr.io/berriai/litellm:main-latest
-volumes:
-- ./proxy_server_config.yaml:/app/proxy_server_config.yaml # mount your litellm config.yaml
 ports:
-- "4000:4000"
+- "4000:4000" # Map the container port to the host, change the host port if necessary
-environment:
-- AZURE_API_KEY=sk-123
+volumes:
+- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
+# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
+command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
+# ...rest of your docker-compose config if any


@ -72,7 +72,7 @@ Here's the code for how we format all providers. Let us know how we can improve
 | Anthropic | `claude-instant-1`, `claude-instant-1.2`, `claude-2` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/anthropic.py#L84)
 | OpenAI Text Completion | `text-davinci-003`, `text-curie-001`, `text-babbage-001`, `text-ada-001`, `babbage-002`, `davinci-002`, | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L442)
 | Replicate | all model names starting with `replicate/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/replicate.py#L180)
-| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
+| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta`, `command-r-plus` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
 | Huggingface | all model names starting with `huggingface/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/huggingface_restapi.py#L186)
 | OpenRouter | all model names starting with `openrouter/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L611)
 | AI21 | `j2-mid`, `j2-light`, `j2-ultra` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/ai21.py#L107)


@ -0,0 +1,45 @@
# Using Vision Models
## Quick Start
Example passing images to a model
```python
import os
from litellm import completion
os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call
response = completion(
model = "gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
)
```
## Checking if a model supports `vision`
Use `litellm.supports_vision(model="")` -> returns `True` if model supports `vision` and `False` if not
```python
assert litellm.supports_vision(model="gpt-4-vision-preview") == True
assert litellm.supports_vision(model="gemini-1.0-pro-visionn") == True
assert litellm.supports_vision(model="gpt-3.5-turbo") == False
```
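The check can also be used to gate requests at runtime, e.g. only attaching image content when the target model reports vision support. A small sketch (the model name and image URL are just the ones from the quick start above; assumes `OPENAI_API_KEY` is set as shown there):

```python
import litellm
from litellm import completion

model = "gpt-4-vision-preview"
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

if litellm.supports_vision(model=model):
    # model accepts image_url content blocks
    content = [
        {"type": "text", "text": "Whats in this image?"},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
else:
    # fall back to a text-only prompt for non-vision models
    content = "Describe the image at " + image_url

response = completion(model=model, messages=[{"role": "user", "content": content}])
print(response)
```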


@ -339,6 +339,8 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
 | textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
 | textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
 | textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
+| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
+| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
 ## Voyage AI Embedding Models


@ -1,5 +1,5 @@
 # Enterprise
-For companies that need better security, user management and professional support
+For companies that need SSO, user management and professional support for LiteLLM Proxy
 :::info
@ -8,12 +8,13 @@ For companies that need better security, user management and professional suppor
 :::
 This covers:
-- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
+- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
 - ✅ **Feature Prioritization**
 - ✅ **Custom Integrations**
 - ✅ **Professional Support - Dedicated discord + slack**
 - ✅ **Custom SLAs**
-- ✅ **Secure access with Single Sign-On**
+- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
+- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
 ## Frequently Asked Questions


@ -0,0 +1,49 @@
import Image from '@theme/IdealImage';
# Hosted LiteLLM Proxy
LiteLLM maintains the proxy, so you can focus on your core products.
## [**Get Onboarded**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
This is in alpha. Schedule a call with us, and we'll give you a hosted proxy within 30 minutes.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
### **Status**: Alpha
Our proxy is already used in production by customers.
See our status page for [**live reliability**](https://status.litellm.ai/)
### **Benefits**
- **No Maintenance, No Infra**: We'll maintain the proxy, and spin up any additional infrastructure (e.g.: separate server for spend logs) to make sure you can load balance + track spend across multiple LLM projects.
- **Reliable**: Our hosted proxy is tested on 1k requests per second, making it reliable for high load.
- **Secure**: LiteLLM is currently undergoing SOC-2 compliance, to make sure your data is as secure as possible.
### Pricing
Pricing is based on usage. We can figure out a price that works for your team, on the call.
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## **Screenshots**
### 1. Create keys
<Image img={require('../img/litellm_hosted_ui_create_key.png')} />
### 2. Add Models
<Image img={require('../img/litellm_hosted_ui_add_models.png')}/>
### 3. Track spend
<Image img={require('../img/litellm_hosted_usage_dashboard.png')} />
### 4. Configure load balancing
<Image img={require('../img/litellm_hosted_ui_router.png')} />
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)


@ -8,6 +8,7 @@ liteLLM supports:
 - [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
 - [Lunary](https://lunary.ai/docs)
+- [Langfuse](https://langfuse.com/docs)
 - [Helicone](https://docs.helicone.ai/introduction)
 - [Traceloop](https://traceloop.com/docs)
 - [Athina](https://docs.athina.ai/)
@ -22,8 +23,8 @@ from litellm import completion
 # set callbacks
 litellm.input_callback=["sentry"] # for sentry breadcrumbing - logs the input being sent to the api
-litellm.success_callback=["posthog", "helicone", "lunary", "athina"]
+litellm.success_callback=["posthog", "helicone", "langfuse", "lunary", "athina"]
-litellm.failure_callback=["sentry", "lunary"]
+litellm.failure_callback=["sentry", "lunary", "langfuse"]
 ## set env variables
 os.environ['SENTRY_DSN'], os.environ['SENTRY_API_TRACE_RATE']= ""
@ -32,6 +33,9 @@ os.environ["HELICONE_API_KEY"] = ""
 os.environ["TRACELOOP_API_KEY"] = ""
 os.environ["LUNARY_PUBLIC_KEY"] = ""
 os.environ["ATHINA_API_KEY"] = ""
+os.environ["LANGFUSE_PUBLIC_KEY"] = ""
+os.environ["LANGFUSE_SECRET_KEY"] = ""
+os.environ["LANGFUSE_HOST"] = ""
 response = completion(model="gpt-3.5-turbo", messages=messages)
 ```


@ -0,0 +1,68 @@
# Greenscale Tutorial
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).
## Getting Started
Use Greenscale to log requests across all LLM Providers
liteLLM provides `callbacks`, making it easy for you to log data depending on the status of your responses.
## Using Callbacks
First, email `hello@greenscale.ai` to get an API_KEY.
Use just 1 line of code, to instantly log your responses **across all providers** with Greenscale:
```python
litellm.success_callback = ["greenscale"]
```
### Complete code
```python
from litellm import completion
## set env variables
os.environ['GREENSCALE_API_KEY'] = 'your-greenscale-api-key'
os.environ['GREENSCALE_ENDPOINT'] = 'greenscale-endpoint'
os.environ["OPENAI_API_KEY"]= ""
# set callback
litellm.success_callback = ["greenscale"]
#openai call
response = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
metadata={
"greenscale_project": "acme-project",
"greenscale_application": "acme-application"
}
)
```
## Additional information in metadata
You can send any additional information to Greenscale by using the `metadata` field in completion and the `greenscale_` prefix. This can be useful for sending metadata about the request, such as the project and application name, customer_id, environment, or any other information you want to use to track usage. `greenscale_project` and `greenscale_application` are required fields.
```python
#openai call with additional metadata
response = completion(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Hi 👋 - i'm openai"}
],
metadata={
"greenscale_project": "acme-project",
"greenscale_application": "acme-application",
"greenscale_customer_id": "customer-123"
}
)
```
## Support & Talk with Greenscale Team
- [Schedule Demo 👋](https://calendly.com/nandesh/greenscale)
- [Website 💻](https://greenscale.ai)
- Our email ✉️ `hello@greenscale.ai`

View file

@ -57,7 +57,7 @@ os.environ["LANGSMITH_API_KEY"] = ""
 os.environ['OPENAI_API_KEY']=""
 # set langfuse as a callback, litellm will send the data to langfuse
-litellm.success_callback = ["langfuse"]
+litellm.success_callback = ["langsmith"]
 response = litellm.completion(
 model="gpt-3.5-turbo",


@ -177,11 +177,7 @@ print(response)
 :::info
-Claude returns it's output as an XML Tree. [Here is how we translate it](https://github.com/BerriAI/litellm/blob/49642a5b00a53b1babc1a753426a8afcac85dbbe/litellm/llms/prompt_templates/factory.py#L734).
-You can see the raw response via `response._hidden_params["original_response"]`.
-Claude hallucinates, e.g. returning the list param `value` as `<value>\n<item>apple</item>\n<item>banana</item>\n</value>` or `<value>\n<list>\n<item>apple</item>\n<item>banana</item>\n</list>\n</value>`.
+LiteLLM now uses Anthropic's 'tool' param 🎉 (v1.34.29+)
 :::
 ```python
@ -228,6 +224,91 @@ assert isinstance(
 ```
### Parallel Function Calling
Here's how to pass the result of a function call back to an anthropic model:
```python
from litellm import completion
import os
os.environ["ANTHROPIC_API_KEY"] = "sk-ant.."
litellm.set_verbose = True
### 1ST FUNCTION CALL ###
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
### 2ND FUNCTION CALL ###
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(second_response)
except Exception as e:
print(f"An error occurred - {str(e)}")
```
s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
## Usage - Vision ## Usage - Vision
```python ```python


@ -1,55 +1,215 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
 # Azure AI Studio
-## Using Mistral models deployed on Azure AI Studio
-### Sample Usage - setting env vars
-Set `MISTRAL_AZURE_API_KEY` and `MISTRAL_AZURE_API_BASE` in your env
-```shell
-MISTRAL_AZURE_API_KEY = "zE************""
-MISTRAL_AZURE_API_BASE = "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1"
+**Ensure the following:**
+1. The API Base passed ends in the `/v1/` prefix
+example:
+```python
+api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
+```
+2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** Need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
+## Usage
+<Tabs>
+<TabItem value="sdk" label="SDK">
```python
import litellm
response = litellm.completion(
model="azure/command-r-plus",
api_base="<your-deployment-base>/v1/"
api_key="eskk******"
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
``` ```
</TabItem>
<TabItem value="proxy" label="PROXY">
## Sample Usage - LiteLLM Proxy
1. Add models to your config.yaml
```yaml
model_list:
- model_name: mistral
litellm_params:
model: azure/mistral-large-latest
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
api_key: JGbKodRcTp****
- model_name: command-r-plus
litellm_params:
model: azure/command-r-plus
api_key: os.environ/AZURE_COHERE_API_KEY
api_base: os.environ/AZURE_COHERE_API_BASE
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="mistral",
messages = [
{
"role": "user",
"content": "what llm are you"
}
],
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Function Calling
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
from litellm import completion from litellm import completion
import os
# set env
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion( response = completion(
model="mistral/Mistral-large-dfgfj", model="azure/mistral-large-latest",
messages=[ api_base=os.getenv("AZURE_MISTRAL_API_BASE")
{"role": "user", "content": "hello from litellm"} api_key=os.getenv("AZURE_MISTRAL_API_KEY")
], messages=messages,
tools=tools,
tool_choice="auto",
) )
# Add any assertions, here to check response args
print(response) print(response)
``` assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
### Sample Usage - passing `api_base` and `api_key` to `litellm.completion` response.choices[0].message.tool_calls[0].function.arguments, str
```python
from litellm import completion
import os
response = completion(
model="mistral/Mistral-large-dfgfj",
api_base="https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com",
api_key = "JGbKodRcTp****"
messages=[
{"role": "user", "content": "hello from litellm"}
],
) )
print(response)
``` ```
### [LiteLLM Proxy] Using Mistral Models </TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $YOUR_API_KEY" \
-d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What'\''s the weather like in Boston today?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}
],
"tool_choice": "auto"
}'
Set this on your litellm proxy config.yaml
```yaml
model_list:
- model_name: mistral
litellm_params:
model: mistral/Mistral-large-dfgfj
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com
api_key: JGbKodRcTp****
``` ```
</TabItem>
</Tabs>
## Supported Models
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |


@ -47,6 +47,7 @@ for chunk in response:
 |------------|----------------|
 | command-r | `completion('command-r', messages)` |
 | command-light | `completion('command-light', messages)` |
+| command-r-plus | `completion('command-r-plus', messages)` |
 | command-medium | `completion('command-medium', messages)` |
 | command-medium-beta | `completion('command-medium-beta', messages)` |
 | command-xlarge-nightly | `completion('command-xlarge-nightly', messages)` |


@ -23,7 +23,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
 ```python
 response = completion(
 model="gemini/gemini-pro",
-messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
+messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
 safety_settings=[
 {
 "category": "HARM_CATEGORY_HARASSMENT",
@ -95,9 +95,8 @@ print(content)
 ```
 ## Chat Models
 | Model Name | Function Call | Required OS Variables |
-|------------------|--------------------------------------|-------------------------|
+|-----------------------|--------------------------------------------------------|--------------------------------|
 | gemini-pro | `completion('gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
-| gemini-1.5-pro | `completion('gemini/gemini-1.5-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
+| gemini-1.5-pro-latest | `completion('gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
 | gemini-pro-vision | `completion('gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
+| gemini-1.5-pro-vision | `completion('gemini/gemini-1.5-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
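With the renamed id, a call looks the same as the earlier examples; only the model string changes. A sketch (requires `GEMINI_API_KEY` as listed in the table; the key value is a placeholder):

```python
import os
from litellm import completion

os.environ["GEMINI_API_KEY"] = "your-api-key"  # placeholder

# gemini-1.5-pro is addressed via the -latest alias shown in the table above
response = completion(
    model="gemini/gemini-1.5-pro-latest",
    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
)
print(response)
```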


@ -48,6 +48,109 @@ We support ALL Groq models, just set `groq/` as a prefix when sending completion
 | Model Name | Function Call |
 |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| llama3-8b-8192 | `completion(model="groq/llama3-8b-8192", messages)` |
+| llama3-70b-8192 | `completion(model="groq/llama3-70b-8192", messages)` |
 | llama2-70b-4096 | `completion(model="groq/llama2-70b-4096", messages)` |
 | mixtral-8x7b-32768 | `completion(model="groq/mixtral-8x7b-32768", messages)` |
 | gemma-7b-it | `completion(model="groq/gemma-7b-it", messages)` |
## Groq - Tool / Function Calling Example
```python
# Example dummy function hard coded to return the current weather
import json
def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
return json.dumps(
{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
)
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
else:
return json.dumps({"location": location, "temperature": "unknown"})
# Step 1: send the conversation and available functions to the model
messages = [
{
"role": "system",
"content": "You are a function calling LLM that uses the data extracted from get_current_weather to answer questions about the weather in San Francisco.",
},
{
"role": "user",
"content": "What's the weather like in San Francisco?",
},
]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
response = litellm.completion(
model="groq/llama2-70b-4096",
messages=messages,
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)
print("Response\n", response)
response_message = response.choices[0].message
tool_calls = response_message.tool_calls
# Step 2: check if the model wanted to call a function
if tool_calls:
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
available_functions = {
"get_current_weather": get_current_weather,
}
messages.append(
response_message
) # extend conversation with assistant's reply
print("Response message\n", response_message)
# Step 4: send the info for each function call and function response to the model
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = available_functions[function_name]
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(
location=function_args.get("location"),
unit=function_args.get("unit"),
)
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response,
}
) # extend conversation with function response
print(f"messages: {messages}")
second_response = litellm.completion(
model="groq/llama2-70b-4096", messages=messages
) # get a new response from the model where it can see the function response
print("second response\n", second_response)
```


@ -50,8 +50,53 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.
 | mistral-small | `completion(model="mistral/mistral-small", messages)` |
 | mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
 | mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
+| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
## Function Calling
```python
from litellm import completion
# set env
os.environ["MISTRAL_API_KEY"] = "your-api-key"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="mistral/mistral-large-latest",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
```
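If the model returns a tool call, you can execute it and send the result back in a follow-up `completion` call. A minimal sketch of that step - the `get_current_weather` stub below is illustrative, not part of the Mistral API:

```python
import json

def get_current_weather(location, unit="fahrenheit"):
    # illustrative stub - swap in a real weather lookup
    return json.dumps({"location": location, "temperature": "72", "unit": unit})

tool_call = response.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)

messages.append(response.choices[0].message)  # assistant's tool-call turn
messages.append(
    {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": tool_call.function.name,
        "content": get_current_weather(**args),
    }
)

second_response = completion(model="mistral/mistral-large-latest", messages=messages)
print(second_response)
```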
## Sample Usage - Embedding ## Sample Usage - Embedding
```python ```python
from litellm import embedding from litellm import embedding

View file

@ -1,5 +1,5 @@
# Ollama # Ollama
LiteLLM supports all models from [Ollama](https://github.com/jmorganca/ollama) LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama)
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_Ollama.ipynb"> <a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_Ollama.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@ -97,7 +97,7 @@ response = completion(
print(response) print(response)
``` ```
## Ollama Models ## Ollama Models
Ollama supported models: https://github.com/jmorganca/ollama Ollama supported models: https://github.com/ollama/ollama
| Model Name | Function Call | | Model Name | Function Call |
|----------------------|----------------------------------------------------------------------------------- |----------------------|-----------------------------------------------------------------------------------

View file

@ -1,5 +1,8 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenAI # OpenAI
LiteLLM supports OpenAI Chat + Text completion and embedding calls. LiteLLM supports OpenAI Chat + Embedding calls.
### Required API Keys ### Required API Keys
@ -22,6 +25,132 @@ response = completion(
) )
``` ```
### Usage - LiteLLM Proxy Server
Here's how to call OpenAI models with the LiteLLM Proxy Server
### 1. Save key in your environment
```bash
export OPENAI_API_KEY=""
```
### 2. Start the proxy
<Tabs>
<TabItem value="config" label="config.yaml">
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-3.5-turbo-instruct
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
Use this to add all openai models with one API Key. **WARNING: This will not do any load balancing**
This means requests to `gpt-4`, `gpt-3.5-turbo`, `gpt-4-turbo-preview` will all go through this route
```yaml
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
litellm_params:
model: openai/* # set `openai/` to use the openai route
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="cli" label="CLI">
```bash
$ litellm --model gpt-3.5-turbo
# Server running on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
### Optional Keys - OpenAI Organization, OpenAI API Base ### Optional Keys - OpenAI Organization, OpenAI API Base
```python ```python
@ -34,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` | | gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
| gpt-4-1106-preview | `response = completion(model="gpt-4-1106-preview", messages=messages)` | | gpt-4-1106-preview | `response = completion(model="gpt-4-1106-preview", messages=messages)` |
| gpt-3.5-turbo-1106 | `response = completion(model="gpt-3.5-turbo-1106", messages=messages)` | | gpt-3.5-turbo-1106 | `response = completion(model="gpt-3.5-turbo-1106", messages=messages)` |
@ -55,6 +186,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
## OpenAI Vision Models ## OpenAI Vision Models
| Model Name | Function Call | | Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------| |-----------------------|-----------------------------------------------------------------|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` | | gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
#### Usage #### Usage
@ -88,19 +220,6 @@ response = completion(
``` ```
## OpenAI Text Completion Models / Instruct Models
| Model Name | Function Call |
|---------------------|----------------------------------------------------|
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-091", messages=messages)` |
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |
## Advanced ## Advanced
### Parallel Function calling ### Parallel Function calling

View file

@ -5,7 +5,9 @@ import TabItem from '@theme/TabItem';
To call models hosted behind an openai proxy, make 2 changes: To call models hosted behind an openai proxy, make 2 changes:
1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint. 1. For `/chat/completions`: Put `openai/` in front of your model name, so litellm knows you're trying to call an openai `/chat/completions` endpoint.
2. For `/completions`: Put `text-completion-openai/` in front of your model name, so litellm knows you're trying to call an openai `/completions` endpoint.
2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints. 3. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.
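For example, a minimal SDK sketch - the server URL and model names below are placeholders for your own openai-compatible deployment:

```python
import litellm

# `openai/` prefix -> request goes to the server's /chat/completions endpoint
chat_response = litellm.completion(
    model="openai/my-chat-model",
    api_base="http://localhost:8080/v1",  # base url only - litellm appends the endpoint path
    api_key="anything",
    messages=[{"role": "user", "content": "Hello!"}],
)

# `text-completion-openai/` prefix -> request goes to the server's /completions endpoint
text_response = litellm.completion(
    model="text-completion-openai/my-instruct-model",
    api_base="http://localhost:8080/v1",
    api_key="anything",
    messages=[{"role": "user", "content": "Say hi"}],
)

print(chat_response)
print(text_response)
```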

View file

@ -1,7 +1,16 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Replicate # Replicate
LiteLLM supports all models on Replicate LiteLLM supports all models on Replicate
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### API KEYS ### API KEYS
```python ```python
import os import os
@ -16,14 +25,175 @@ import os
## set ENV variables ## set ENV variables
os.environ["REPLICATE_API_KEY"] = "replicate key" os.environ["REPLICATE_API_KEY"] = "replicate key"
# replicate llama-2 call # replicate llama-3 call
response = completion( response = completion(
model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", model="replicate/meta/meta-llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}] messages = [{ "content": "Hello, how are you?","role": "user"}]
) )
``` ```
### Example - Calling Replicate Deployments </TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="llama-3",
messages = [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3",
"messages": [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
],
}'
```
</TabItem>
</Tabs>
### Expected Replicate Call
This is the call litellm will make to replicate, from the above example:
```bash
POST Request Sent from LiteLLM:
curl -X POST \
https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct \
-H 'Authorization: Token your-api-key' -H 'Content-Type: application/json' \
-d '{'version': 'meta/meta-llama-3-8b-instruct', 'input': {'prompt': '<|start_header_id|>system<|end_header_id|>\n\nBe a good human!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat do you know about earth?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'}}'
```
</TabItem>
</Tabs>
## Advanced Usage - Prompt Formatting
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
To apply a custom prompt template:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import completion
import os
os.environ["REPLICATE_API_KEY"] = ""

# Create your own custom prompt template
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    initial_prompt_value="You are a good assistant", # [OPTIONAL]
    roles={
        "system": {
            "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
            "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
        },
        "user": {
            "pre_message": "[INST] ", # [OPTIONAL]
            "post_message": " [/INST]" # [OPTIONAL]
        },
        "assistant": {
            "pre_message": "\n", # [OPTIONAL]
            "post_message": "\n" # [OPTIONAL]
        }
    },
    final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)

def test_replicate_custom_model():
    model = "replicate/togethercomputer/LLaMA-2-7B-32K"
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    response = completion(model=model, messages=messages)
    print(response['choices'][0]['message']['content'])
    return response

test_replicate_custom_model()
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
# Model-specific parameters
model_list:
- model_name: mistral-7b # model alias
litellm_params: # actual params for litellm.completion()
model: "replicate/mistralai/Mistral-7B-Instruct-v0.1"
api_key: os.environ/REPLICATE_API_KEY
initial_prompt_value: "\n"
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
final_prompt_value: "\n"
bos_token: "<s>"
eos_token: "</s>"
max_tokens: 4096
```
</TabItem>
</Tabs>
## Advanced Usage - Calling Replicate Deployments
Calling a [deployed replicate LLM](https://replicate.com/deployments) Calling a [deployed replicate LLM](https://replicate.com/deployments)
Add the `replicate/deployments/` prefix to your model, so litellm will call the `deployments` endpoint. This will call `ishaan-jaff/ishaan-mistral` deployment on replicate Add the `replicate/deployments/` prefix to your model, so litellm will call the `deployments` endpoint. This will call `ishaan-jaff/ishaan-mistral` deployment on replicate
@ -40,7 +210,7 @@ Replicate responses can take 3-5 mins due to replicate cold boots, if you're try
::: :::
### Replicate Models ## Replicate Models
liteLLM supports all replicate LLMs liteLLM supports all replicate LLMs
For replicate models ensure to add a `replicate/` prefix to the `model` arg. liteLLM detects it using this arg. For replicate models ensure to add a `replicate/` prefix to the `model` arg. liteLLM detects it using this arg.
@ -49,15 +219,15 @@ Below are examples on how to call replicate LLMs using liteLLM
Model Name | Function Call | Required OS Variables | Model Name | Function Call | Required OS Variables |
-----------------------------|----------------------------------------------------------------|--------------------------------------| -----------------------------|----------------------------------------------------------------|--------------------------------------|
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` | replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` | a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` | replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` | daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` | custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
replicate deployment | `completion(model='replicate/deployments/ishaan-jaff/ishaan-mistral', messages)` | `os.environ['REPLICATE_API_KEY']` | replicate deployment | `completion(model='replicate/deployments/ishaan-jaff/ishaan-mistral', messages)` | `os.environ['REPLICATE_API_KEY']` |
### Passing additional params - max_tokens, temperature ## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input) See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
```python ```python
@ -73,11 +243,22 @@ response = completion(
messages = [{ "content": "Hello, how are you?","role": "user"}], messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20, max_tokens=20,
temperature=0.5 temperature=0.5
) )
``` ```
### Passings Replicate specific params **proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
max_tokens: 20
temperature: 0.5
```
## Passing Replicate specific params
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Replicate by passing them to `litellm.completion` Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Replicate by passing them to `litellm.completion`
Example `seed`, `min_tokens` are Replicate specific param Example `seed`, `min_tokens` are Replicate specific param
@ -98,3 +279,15 @@ response = completion(
top_k=20, top_k=20,
) )
``` ```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: replicate/meta/meta-llama-3-8b-instruct
api_key: os.environ/REPLICATE_API_KEY
min_tokens: 2
top_k: 20
```

View file

@ -0,0 +1,163 @@
# OpenAI (Text Completion)
LiteLLM supports OpenAI text completion models
### Required API Keys
```python
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
```
### Usage
```python
import os
from litellm import completion
os.environ["OPENAI_API_KEY"] = "your-api-key"
# openai call
response = completion(
model = "gpt-3.5-turbo-instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}]
)
```
### Usage - LiteLLM Proxy Server
Here's how to call OpenAI models with the LiteLLM Proxy Server
### 1. Save key in your environment
```bash
export OPENAI_API_KEY=""
```
### 2. Start the proxy
<Tabs>
<TabItem value="config" label="config.yaml">
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-3.5-turbo-instruct
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
Use this to add all openai models with one API Key. **WARNING: This will not do any load balancing**
This means requests to `gpt-4`, `gpt-3.5-turbo`, `gpt-4-turbo-preview` will all go through this route
```yaml
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
litellm_params:
model: openai/* # set `openai/` to use the openai route
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
<TabItem value="cli" label="CLI">
```bash
$ litellm --model gpt-3.5-turbo-instruct
# Server running on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo-instruct",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo-instruct", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo-instruct",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
## OpenAI Text Completion Models / Instruct Models
| Model Name | Function Call |
|---------------------|----------------------------------------------------|
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-0914", messages=messages)` |
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |
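If you want prompt-in / text-out semantics instead of chat messages, litellm also exposes a `text_completion` helper; a minimal sketch (the prompt text is illustrative):

```python
import os
from litellm import text_completion

os.environ["OPENAI_API_KEY"] = "your-api-key"

response = text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Say hello in one short sentence.",
)
print(response)
```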

View file

@ -1,18 +1,25 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# VertexAI - Google [Gemini, Model Garden] # VertexAI [Anthropic, Gemini, Model Garden]
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb"> <a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a> </a>
## Pre-requisites ## Pre-requisites
* `pip install google-cloud-aiplatform` * `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
* Authentication: * Authentication:
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc) * run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
* Alternatively you can set `application_default_credentials.json` * Alternatively you can set `GOOGLE_APPLICATION_CREDENTIALS`
Here's how: [**Jump to Code**](#extra)
- Create a service account on GCP
- Export the credentials as a JSON file
- Load the JSON and `json.dump` it as a string
- Store the JSON string in your environment as `GOOGLE_APPLICATION_CREDENTIALS`
## Sample Usage ## Sample Usage
```python ```python
@ -123,6 +130,100 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
</Tabs> </Tabs>
## Specifying Safety Settings
In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = completion(
model="gemini/gemini-pro",
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
safety_settings=[
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
]
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**Option 1: Set in config**
```yaml
model_list:
- model_name: gemini-experimental
litellm_params:
model: vertex_ai/gemini-experimental
vertex_project: litellm-epic
vertex_location: us-central1
safety_settings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_NONE
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_NONE
```
**Option 2: Set on call**
```python
response = client.chat.completions.create(
model="gemini-experimental",
messages=[
{
"role": "user",
"content": "Can you write exploits?",
}
],
max_tokens=8192,
stream=False,
temperature=0.0,
extra_body={
"safety_settings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
],
}
)
```
</TabItem>
</Tabs>
## Set Vertex Project & Vertex Location ## Set Vertex Project & Vertex Location
All calls using Vertex AI require the following parameters: All calls using Vertex AI require the following parameters:
* Your Project ID * Your Project ID
@ -149,6 +250,85 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
# set directly on module # set directly on module
litellm.vertex_location = "us-central1 # Your Location litellm.vertex_location = "us-central1 # Your Location
``` ```
## Anthropic
| Model Name | Function Call |
|------------------|--------------------------------------|
| claude-3-opus@20240229 | `completion('vertex_ai/claude-3-opus@20240229', messages)` |
| claude-3-sonnet@20240229 | `completion('vertex_ai/claude-3-sonnet@20240229', messages)` |
| claude-3-haiku@20240307 | `completion('vertex_ai/claude-3-haiku@20240307', messages)` |
### Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
model = "claude-3-sonnet@20240229"
vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]
response = completion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: anthropic-vertex
litellm_params:
model: vertex_ai/claude-3-sonnet@20240229
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-east-1"
- model_name: anthropic-vertex
litellm_params:
model: vertex_ai/claude-3-sonnet@20240229
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-west-1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "anthropic-vertex", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}'
```
</TabItem>
</Tabs>
## Model Garden ## Model Garden
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
@ -175,18 +355,15 @@ response = completion(
|------------------|--------------------------------------| |------------------|--------------------------------------|
| gemini-pro | `completion('gemini-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` | | gemini-pro | `completion('gemini-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
| Model Name | Function Call |
|------------------|--------------------------------------|
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
## Gemini Pro Vision ## Gemini Pro Vision
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
| gemini-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`| | gemini-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`|
## Gemini 1.5 Pro (and Vision)
| Model Name | Function Call | | Model Name | Function Call |
|------------------|--------------------------------------| |------------------|--------------------------------------|
| gemini-1.5-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`| | gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
@ -298,3 +475,75 @@ print(response)
| code-bison@001 | `completion('code-bison@001', messages)` | | code-bison@001 | `completion('code-bison@001', messages)` |
| code-gecko@001 | `completion('code-gecko@001', messages)` | | code-gecko@001 | `completion('code-gecko@001', messages)` |
| code-gecko@latest| `completion('code-gecko@latest', messages)` | | code-gecko@latest| `completion('code-gecko@latest', messages)` |
## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS`
Here's the code for storing your service account credentials as `GOOGLE_APPLICATION_CREDENTIALS` environment variable:
```python
import json
import os
import tempfile

def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
```
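A usage sketch, assuming a `vertex_key.json` sits next to this script and the project / location values below are replaced with your own:

```python
import os
from litellm import completion

load_vertex_ai_credentials()  # exports GOOGLE_APPLICATION_CREDENTIALS for this process

os.environ["VERTEXAI_PROJECT"] = "your-vertex-project"   # placeholder
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "hi"}],
)
print(response)
```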
### Using GCP Service Account
1. Figure out the Service Account bound to the Google Cloud Run service
<Image img={require('../../img/gcp_acc_1.png')} />
2. Get the FULL EMAIL address of the corresponding Service Account
3. Next, go to IAM & Admin > Manage Resources , select your top-level project that houses your Google Cloud Run Service
Click `Add Principal`
<Image img={require('../../img/gcp_acc_2.png')}/>
4. Specify the Service Account as the principal and Vertex AI User as the role
<Image img={require('../../img/gcp_acc_3.png')}/>
Once that's done, when you deploy the new container in the Google Cloud Run service, LiteLLM will have automatic access to all Vertex AI endpoints.
s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial

View file

@ -25,8 +25,11 @@ All models listed here https://docs.voyageai.com/embeddings/#models-and-specific
| Model Name | Function Call | | Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| voyage-2 | `embedding(model="voyage/voyage-2", input)` |
| voyage-large-2 | `embedding(model="voyage/voyage-large-2", input)` |
| voyage-law-2 | `embedding(model="voyage/voyage-law-2", input)` |
| voyage-code-2 | `embedding(model="voyage/voyage-code-2", input)` |
| voyage-lite-02-instruct | `embedding(model="voyage/voyage-lite-02-instruct", input)` |
| voyage-01 | `embedding(model="voyage/voyage-01", input)` | | voyage-01 | `embedding(model="voyage/voyage-01", input)` |
| voyage-lite-01 | `embedding(model="voyage/voyage-lite-01", input)` | | voyage-lite-01 | `embedding(model="voyage/voyage-lite-01", input)` |
| voyage-lite-01-instruct | `embedding(model="voyage/voyage-lite-01-instruct", input)` | | voyage-lite-01-instruct | `embedding(model="voyage/voyage-lite-01-instruct", input)` |
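A minimal embedding sketch for one of the newly listed models (the input text is illustrative):

```python
import os
from litellm import embedding

os.environ["VOYAGE_API_KEY"] = "your-api-key"

response = embedding(
    model="voyage/voyage-law-2",
    input=["What is the statute of limitations for contract claims?"],
)
print(response)
```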

View file

@ -61,6 +61,22 @@ litellm_settings:
ttl: 600 # will be cached on redis for 600s ttl: 600 # will be cached on redis for 600s
``` ```
## SSL
just set `REDIS_SSL="True"` in your .env, and LiteLLM will pick this up.
```env
REDIS_SSL="True"
```
For quick testing, you can also use `REDIS_URL`, e.g.:
```
REDIS_URL="rediss://.."
```
but we **don't** recommend using `REDIS_URL` in prod. We've noticed it performs worse than passing `redis_host`, `redis_port`, etc. individually.
#### Step 2: Add Redis Credentials to .env #### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching. Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
@ -265,32 +281,6 @@ litellm_settings:
supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
``` ```
### Turn on `batch_redis_requests`
**What it does?**
When a request is made:
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory, if no - get the last 100 cached requests for this key and store it
- New requests are stored with this `litellm:..` as the namespace
**Why?**
Reduce number of redis GET requests. This improved latency by 46% in prod load tests.
**Usage**
```yaml
litellm_settings:
cache: true
cache_params:
type: redis
... # remaining redis args (host, port, etc.)
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
```
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
### Turn on / off caching per request. ### Turn on / off caching per request.
The proxy support 3 cache-controls: The proxy support 3 cache-controls:
@ -384,6 +374,87 @@ chat_completion = client.chat.completions.create(
) )
``` ```
### Deleting Cache Keys - `/cache/delete`
In order to delete a cache key, send a request to `/cache/delete` with the `keys` you want to delete
Example
```shell
curl -X POST "http://0.0.0.0:4000/cache/delete" \
-H "Authorization: Bearer sk-1234" \
-d '{"keys": ["586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d", "key2"]}'
```
```shell
# {"status":"success"}
```
#### Viewing Cache Keys from responses
You can view the cache_key in the response headers; on cache hits, the cache key is sent in the `x-litellm-cache-key` response header
```shell
curl -i --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"user": "ishan",
"messages": [
{
"role": "user",
"content": "what is litellm"
}
],
}'
```
Response from litellm proxy
```json
date: Thu, 04 Apr 2024 17:37:21 GMT
content-type: application/json
x-litellm-cache-key: 586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d
{
"id": "chatcmpl-9ALJTzsBlXR9zTxPvzfFFtFbFtG6T",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "I'm sorr.."
"role": "assistant"
}
}
],
"created": 1712252235,
}
```
### Turn on `batch_redis_requests`
**What does it do?**
When a request is made:
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory; if not, get the last 100 cached requests for this key and store them
- New requests are stored with this `litellm:..` as the namespace
**Why?**
Reduces the number of Redis GET requests. This improved latency by 46% in prod load tests.
**Usage**
```yaml
litellm_settings:
cache: true
cache_params:
type: redis
... # remaining redis args (host, port, etc.)
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
```
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
## Supported `cache_params` on proxy config.yaml ## Supported `cache_params` on proxy config.yaml
```yaml ```yaml

View file

@ -600,6 +600,7 @@ general_settings:
"general_settings": { "general_settings": {
"completion_model": "string", "completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db "disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
"disable_reset_budget": "boolean", # turn off reset budget scheduled task "disable_reset_budget": "boolean", # turn off reset budget scheduled task
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param "enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param

View file

@ -0,0 +1,9 @@
# 🎉 Demo App
Here is a demo of the proxy. To log in, pass in:
- Username: admin
- Password: sk-1234
[Demo UI](https://demo.litellm.ai/ui)

View file

@ -231,13 +231,16 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
| Docs | When to Use | | Docs | When to Use |
| --- | --- | | --- | --- |
| [Quick Start](#quick-start) | call 100+ LLMs + Load Balancing | | [Quick Start](#quick-start) | call 100+ LLMs + Load Balancing |
| [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend | | [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend (Note: When deploying with a database providing a `DATABASE_URL` and `LITELLM_MASTER_KEY` are required in your env ) |
| [LiteLLM container + Redis](#litellm-container--redis) | + load balance across multiple litellm containers | | [LiteLLM container + Redis](#litellm-container--redis) | + load balance across multiple litellm containers |
| [LiteLLM Database container + PostgresDB + Redis](#litellm-database-container--postgresdb--redis) | + use Virtual Keys + Track Spend + load balance across multiple litellm containers | | [LiteLLM Database container + PostgresDB + Redis](#litellm-database-container--postgresdb--redis) | + use Virtual Keys + Track Spend + load balance across multiple litellm containers |
## Deploy with Database ## Deploy with Database
### Docker, Kubernetes, Helm Chart ### Docker, Kubernetes, Helm Chart
Requirements:
- Need a postgres database (e.g. [Supabase](https://supabase.com/), [Neon](https://neon.tech/), etc.). Set `DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname>` in your env
- Set a `LITELLM_MASTER_KEY` - this is your Proxy Admin key; you can use it to create other keys (🚨 must start with `sk-`)
<Tabs> <Tabs>
@ -246,12 +249,14 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
We maintain a [seperate Dockerfile](https://github.com/BerriAI/litellm/pkgs/container/litellm-database) for reducing build time when running LiteLLM proxy with a connected Postgres Database We maintain a [seperate Dockerfile](https://github.com/BerriAI/litellm/pkgs/container/litellm-database) for reducing build time when running LiteLLM proxy with a connected Postgres Database
```shell ```shell
docker pull docker pull ghcr.io/berriai/litellm-database:main-latest docker pull ghcr.io/berriai/litellm-database:main-latest
``` ```
```shell ```shell
docker run \ docker run \
-v $(pwd)/litellm_config.yaml:/app/config.yaml \ -v $(pwd)/litellm_config.yaml:/app/config.yaml \
-e LITELLM_MASTER_KEY=sk-1234 \
-e DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname> \
-e AZURE_API_KEY=d6*********** \ -e AZURE_API_KEY=d6*********** \
-e AZURE_API_BASE=https://openai-***********/ \ -e AZURE_API_BASE=https://openai-***********/ \
-p 4000:4000 \ -p 4000:4000 \
@ -666,8 +671,8 @@ services:
litellm: litellm:
build: build:
context: . context: .
args: args:
target: runtime target: runtime
image: ghcr.io/berriai/litellm:main-latest image: ghcr.io/berriai/litellm:main-latest
ports: ports:
- "4000:4000" # Map the container port to the host, change the host port if necessary - "4000:4000" # Map the container port to the host, change the host port if necessary

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# ✨ Enterprise Features - Content Mod # ✨ Enterprise Features - Content Mod, SSO
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise) Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
@ -12,16 +12,18 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
::: :::
Features: Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ Content Moderation with LLM Guard - ✅ Content Moderation with LLM Guard
- ✅ Content Moderation with LlamaGuard - ✅ Content Moderation with LlamaGuard
- ✅ Content Moderation with Google Text Moderations - ✅ Content Moderation with Google Text Moderations
- ✅ Reject calls from Blocked User list - ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- ✅ Don't log/store specific requests (eg confidential LLM requests) - ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (eg confidential LLM requests)
- ✅ Tracking Spend for Custom Tags - ✅ Tracking Spend for Custom Tags
## Content Moderation ## Content Moderation
### Content Moderation with LLM Guard ### Content Moderation with LLM Guard
@ -74,7 +76,7 @@ curl --location 'http://localhost:4000/key/generate' \
# Returns {..'key': 'my-new-key'} # Returns {..'key': 'my-new-key'}
``` ```
**2. Test it!** **3. Test it!**
```bash ```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
@ -87,6 +89,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
}' }'
``` ```
#### Turn on/off per request
**1. Update config**
```yaml
litellm_settings:
callbacks: ["llmguard_moderations"]
llm_guard_mode: "request-specific"
```
**2. Create new key**
```bash
curl --location 'http://localhost:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"models": ["fake-openai-endpoint"],
}'
# Returns {..'key': 'my-new-key'}
```
**3. Test it!**
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
"metadata": {
"permissions": {
"enable_llm_guard_check": True # 👈 KEY CHANGE
},
}
}
)
print(response)
```
</TabItem>
<TabItem value="curl" label="Curl Request">
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
--data '{"model": "fake-openai-endpoint", "messages": [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"}
]
}'
```
</TabItem>
</Tabs>
### Content Moderation with LlamaGuard ### Content Moderation with LlamaGuard

View file

@ -1,4 +1,4 @@
# Load Balancing - Config Setup # Multiple Instances
Load balance multiple instances of the same model Load balance multiple instances of the same model
The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want maximize throughput** The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want maximize throughput**
@ -10,75 +10,6 @@ For more details on routing strategies / params, see [Routing](../routing.md)
::: :::
## Quick Start - Load Balancing
### Step 1 - Set deployments on config
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 6
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
```
### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
### Step 3: Use proxy - Call a model group [Load Balancing]
Curl Command
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
### Usage - Call a specific model deployment
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "azure/gpt-turbo-small-ca",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
## Load Balancing using multiple litellm instances (Kubernetes, Auto Scaling) ## Load Balancing using multiple litellm instances (Kubernetes, Auto Scaling)
LiteLLM Proxy supports sharing rpm/tpm shared across multiple litellm instances, pass `redis_host`, `redis_password` and `redis_port` to enable this. (LiteLLM will use Redis to track rpm/tpm usage ) LiteLLM Proxy supports sharing rpm/tpm shared across multiple litellm instances, pass `redis_host`, `redis_password` and `redis_port` to enable this. (LiteLLM will use Redis to track rpm/tpm usage )
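A minimal config sketch for this - the redis values are placeholders; see [Routing](../routing.md) for the full set of router params:

```yaml
router_settings:
  redis_host: <your-redis-host>
  redis_port: <your-redis-port>
  redis_password: <your-redis-password>
```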

View file

@ -9,9 +9,9 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
- [Async Custom Callbacks](#custom-callback-class-async) - [Async Custom Callbacks](#custom-callback-class-async)
- [Async Custom Callback APIs](#custom-callback-apis-async) - [Async Custom Callback APIs](#custom-callback-apis-async)
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
- [Logging to Langfuse](#logging-proxy-inputoutput---langfuse) - [Logging to Langfuse](#logging-proxy-inputoutput---langfuse)
- [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets) - [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
- [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb) - [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb)
- [Logging to Sentry](#logging-proxy-inputoutput---sentry) - [Logging to Sentry](#logging-proxy-inputoutput---sentry)
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry) - [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
@ -539,6 +539,36 @@ print(response)
</Tabs> </Tabs>
### Team based Logging to Langfuse
**Example:**
This config would send langfuse logs to 2 different langfuse projects, based on the team id
```yaml
litellm_settings:
default_team_settings:
- team_id: my-secret-project
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
- team_id: ishaans-secret-project
success_callback: ["langfuse"]
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_2 # Project 2
langfuse_secret: os.environ/LANGFUSE_SECRET_2 # Project 2
```
Now, when you [generate keys](./virtual_keys.md) for this team-id
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"team_id": "ishaans-secret-project"}'
```
All requests made with these keys will log data to their team-specific Langfuse project.
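For example, a request made with a key generated for `ishaans-secret-project` - the key value below is a placeholder for whatever `/key/generate` returned:

```python
import openai

client = openai.OpenAI(
    api_key="sk-my-team-key",        # 👈 key generated for the team above (placeholder)
    base_url="http://0.0.0.0:4000"
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
print(response)  # logged to Project 2's Langfuse keys
```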
## Logging Proxy Input/Output - DataDog ## Logging Proxy Input/Output - DataDog
We will use the `--config` to set `litellm.success_callback = ["datadog"]` this will log all successfull LLM calls to DataDog We will use the `--config` to set `litellm.success_callback = ["datadog"]` this will log all successfull LLM calls to DataDog

View file

@ -16,7 +16,7 @@ Expected Performance in Production
| `/chat/completions` Requests/hour | `126K` | | `/chat/completions` Requests/hour | `126K` |
## 1. Switch of Debug Logging ## 1. Switch off Debug Logging
Remove `set_verbose: True` from your config.yaml Remove `set_verbose: True` from your config.yaml
```yaml ```yaml
@ -40,7 +40,7 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"] CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
``` ```
## 2. Batch write spend updates every 60s ## 3. Batch write spend updates every 60s
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally. The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
@ -49,11 +49,35 @@ In production, we recommend using a longer interval period of 60s. This reduces
```yaml ```yaml
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds) proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
``` ```
## 4. Use Redis 'port', 'host', 'password'. NOT 'redis_url'
## 3. Move spend logs to separate server When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen an 80 RPS difference between these 2 approaches when using the async redis client.
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
Recommended to do this for prod:
```yaml
router_settings:
routing_strategy: usage-based-routing-v2
# redis_url: "os.environ/REDIS_URL"
redis_host: os.environ/REDIS_HOST
redis_port: os.environ/REDIS_PORT
redis_password: os.environ/REDIS_PASSWORD
```
## 5. Switch off resetting budgets
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
```yaml
general_settings:
disable_reset_budget: true
```
## 6. Move spend logs to separate server (BETA)
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server. Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server.
@ -141,24 +165,6 @@ A t2.micro should be sufficient to handle 1k logs / minute on this server.
This consumes at max 120MB, and <0.1 vCPU. This consumes at max 120MB, and <0.1 vCPU.
## 4. Switch off resetting budgets
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
```yaml
general_settings:
disable_spend_logs: true
disable_reset_budget: true
```
## 5. Switch of `litellm.telemetry`
Switch of all telemetry tracking done by litellm
```yaml
litellm_settings:
telemetry: False
```
## Machine Specifications to Deploy LiteLLM ## Machine Specifications to Deploy LiteLLM
| Service | Spec | CPUs | Memory | Architecture | Version| | Service | Spec | CPUs | Memory | Architecture | Version|

View file

@ -14,6 +14,7 @@ model_list:
model: gpt-3.5-turbo model: gpt-3.5-turbo
litellm_settings: litellm_settings:
success_callback: ["prometheus"] success_callback: ["prometheus"]
failure_callback: ["prometheus"]
``` ```
Start the proxy Start the proxy
@ -48,6 +49,26 @@ http://localhost:4000/metrics
| Metric Name | Description | | Metric Name | Description |
|----------------------|--------------------------------------| |----------------------|--------------------------------------|
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model"` | | `litellm_requests_metric` | Number of requests made, per `"user", "key", "model", "team", "end-user"` |
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model"` | | `litellm_spend_metric` | Total Spend, per `"user", "key", "model", "team", "end-user"` |
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model"` | | `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do:
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
service_callback: ["prometheus_system"]
```
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_redis_latency` | histogram latency for redis calls |
| `litellm_redis_fails` | Number of failed redis calls |
| `litellm_self_latency` | Histogram latency for successful litellm api call |

View file

@ -348,6 +348,29 @@ query_result = embeddings.embed_query(text)
print(f"TITAN EMBEDDINGS") print(f"TITAN EMBEDDINGS")
print(query_result[:5]) print(query_result[:5])
```
</TabItem>
<TabItem value="litellm" label="LiteLLM SDK">
This is **not recommended**. There is duplicate logic as the proxy also uses the sdk, which might lead to unexpected errors.
```python
from litellm import completion
response = completion(
model="openai/gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
api_key="anything",
base_url="http://0.0.0.0:4000"
)
print(response)
``` ```
</TabItem> </TabItem>
</Tabs> </Tabs>
@ -438,7 +461,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
), ),
``` ```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem> </TabItem>
<TabItem value="aider" label="Aider"> <TabItem value="aider" label="Aider">
@ -551,4 +574,3 @@ No Logs
```shell ```shell
export LITELLM_LOG=None export LITELLM_LOG=None
``` ```

View file

@ -2,7 +2,9 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs'; import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem'; import TabItem from '@theme/TabItem';
# Fallbacks, Retries, Timeouts, Cooldowns # 🔥 Fallbacks, Retries, Timeouts, Load Balancing
Retry calls with multiple instances of the same model.
If a call fails after num_retries, fall back to another model group. If a call fails after num_retries, fall back to another model group.
@ -10,6 +12,77 @@ If the error is a context window exceeded error, fall back to a larger model gro
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py) [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
## Quick Start - Load Balancing
### Step 1 - Set deployments on config
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 6
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-large
api_base: https://openai-france-1234.openai.azure.com/
api_key: <your-azure-api-key>
rpm: 1440
```
### Step 2: Start Proxy with config
```shell
$ litellm --config /path/to/config.yaml
```
### Step 3: Use proxy - Call a model group [Load Balancing]
Curl Command
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
    ]
}
'
```
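You can also call the load-balanced model group from Python with the OpenAI SDK pointed at the proxy. A minimal sketch, assuming the proxy from Step 2 is running on `http://0.0.0.0:4000`:
```python
import openai

# point the OpenAI client at the LiteLLM proxy; the proxy holds the real provider keys
client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # model group name from the config
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```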
### Usage - Call a specific model deployment
If you want to call a specific model defined in the `config.yaml`, call it using the value of its `litellm_params: model`
In this example it will call `azure/gpt-turbo-small-ca`, as defined in the config in Step 1
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "azure/gpt-turbo-small-ca",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
    ]
}
'
```
## Fallbacks + Retries + Timeouts + Cooldowns
**Set via config** **Set via config**
```yaml ```yaml
model_list: model_list:
@ -63,7 +136,158 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
' '
``` ```
## Custom Timeouts, Stream Timeouts - Per Model ### Test it!
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
"model": "zephyr-beta", # 👈 MODEL NAME to fallback from
"messages": [
{"role": "user", "content": "what color is red"}
],
"mock_testing_fallbacks": true
}'
```
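The same fallback test can be run from the OpenAI SDK by forwarding the flag through `extra_body`. A minimal sketch, assuming the proxy is running on `http://0.0.0.0:4000`:
```python
import openai

client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="zephyr-beta",  # model name to fall back from
    messages=[{"role": "user", "content": "what color is red"}],
    extra_body={"mock_testing_fallbacks": True},  # forwarded to the proxy as a top-level field
)
print(response)
```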
## Advanced - Context Window Fallbacks
**Before a call is made**, check whether the request fits within the model's context window by setting **`enable_pre_call_checks: true`**.
[**See Code**](https://github.com/BerriAI/litellm/blob/c9e6b05cfb20dfb17272218e2555d6b496c47f6f/litellm/router.py#L2163)
**1. Setup config**
For Azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json); all the Azure models start with `azure/`.
<Tabs>
<TabItem value="same-group" label="Same Group">
Filter older instances of a model (e.g. gpt-3.5-turbo) with smaller context windows
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
Fall back to larger models if the current model's context window is too small.
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo-small
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo-large
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-opus
litellm_params:
model: claude-3-opus-20240229
api_key: os.environ/ANTHROPIC_API_KEY
litellm_settings:
context_window_fallbacks: [{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}]
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
</TabItem>
</Tabs>
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params` For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml ```yaml
model_list: model_list:
@ -92,7 +316,7 @@ $ litellm --config /path/to/config.yaml
``` ```
## Setting Dynamic Timeouts - Per Request ## Advanced - Setting Dynamic Timeouts - Per Request
LiteLLM Proxy supports setting a `timeout` per request LiteLLM Proxy supports setting a `timeout` per request

View file

@ -99,7 +99,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
curl -X POST 'http://0.0.0.0:4000/key/generate' \ curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \ -H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
-D '{"team_id": "ishaans-secret-project"}' -d '{"team_id": "ishaans-secret-project"}'
``` ```
All requests made with these keys will log data to their team-specific logging. All requests made with these keys will log data to their team-specific logging.

View file

@ -9,6 +9,7 @@ Use JWT's to auth admins / projects into the proxy.
This is a new feature, and subject to changes based on feedback. This is a new feature, and subject to changes based on feedback.
*UPDATE*: This will be moving to the [enterprise tier](./enterprise.md) once it's out of beta (~by end of April).
::: :::
## Usage ## Usage
@ -107,6 +108,34 @@ general_settings:
litellm_jwtauth: litellm_jwtauth:
admin_jwt_scope: "litellm-proxy-admin" admin_jwt_scope: "litellm-proxy-admin"
``` ```
## Advanced - Spend Tracking (User / Team / Org)
Set the fields in the JWT token that correspond to a litellm user / team / org.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
admin_jwt_scope: "litellm-proxy-admin"
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
```
Expected JWT:
```
{
"client_id": "my-unique-team",
"sub": "my-unique-user",
"org_id": "my-unique-org"
}
```
Now litellm will automatically update the spend for the user/team/org in the db for each call.
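For example, a request authenticated with such a JWT looks like any other OpenAI-compatible call; the proxy reads the `client_id` / `sub` / `org_id` claims from the bearer token. A minimal sketch with a placeholder token value, assuming the proxy runs on `http://0.0.0.0:4000`:
```python
import openai

# the JWT is sent as the Bearer token (placeholder value below)
client = openai.OpenAI(
    api_key="<YOUR_JWT>",
    base_url="http://0.0.0.0:4000",
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
```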
### JWT Scopes ### JWT Scopes
Here's what scopes on JWT-Auth tokens look like Here's what scopes on JWT-Auth tokens look like
@ -149,7 +178,7 @@ general_settings:
enable_jwt_auth: True enable_jwt_auth: True
litellm_jwtauth: litellm_jwtauth:
... ...
team_jwt_scope: "litellm-team" # 👈 Set JWT Scope string team_id_jwt_field: "litellm-team" # 👈 Set field in the JWT token that stores the team ID
team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes
``` ```

View file

@ -56,6 +56,9 @@ On accessing the LiteLLM UI, you will be prompted to enter your username, passwo
## ✨ Enterprise Features ## ✨ Enterprise Features
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
### Setup SSO/Auth for UI ### Setup SSO/Auth for UI
#### Step 1: Set upperbounds for keys #### Step 1: Set upperbounds for keys

View file

@ -121,6 +121,9 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI( chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", openai_api_base="http://0.0.0.0:4000",

View file

@ -435,7 +435,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
), ),
``` ```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem> </TabItem>
<TabItem value="aider" label="Aider"> <TabItem value="aider" label="Aider">
@ -815,4 +815,3 @@ Thread Stats Avg Stdev Max +/- Stdev
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw) - [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238 - Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai - Our emails ✉️ ishaan@berri.ai / krrish@berri.ai

View file

@ -95,12 +95,129 @@ print(response)
- `router.image_generation()` - completion calls in OpenAI `/v1/images/generations` endpoint format - `router.image_generation()` - completion calls in OpenAI `/v1/images/generations` endpoint format
- `router.aimage_generation()` - async image generation calls - `router.aimage_generation()` - async image generation calls
### Advanced ## Advanced - Routing Strategies
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based #### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
Router provides 4 strategies for routing your calls across multiple deployments: Router provides 4 strategies for routing your calls across multiple deployments:
<Tabs> <Tabs>
<TabItem value="usage-based-v2" label="Rate-Limit Aware v2 (ASYNC)">
**🎉 NEW** This is an async implementation of usage-based-routing.
**Filters out a deployment if its tpm/rpm limit is exceeded** - requires you to pass in each deployment's tpm/rpm limits.
Routes to **deployment with lowest TPM usage** for that minute.
In production, we use Redis to track usage (TPM/RPM) across multiple deployments. This implementation uses **async redis calls** (redis.incr and redis.mget).
For Azure, your RPM = TPM/6.
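For example, an Azure deployment configured with `tpm: 240000` corresponds to `rpm: 40000`.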
<Tabs>
<TabItem value="sdk" label="sdk">
```python
from litellm import Router
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 100000,
"rpm": 10000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 100000,
"rpm": 1000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 1000,
}]
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}]
print(response)
```
</TabItem>
<TabItem value="proxy" label="proxy">
**1. Set strategy in config**
```yaml
model_list:
- model_name: gpt-3.5-turbo # model alias
litellm_params: # params for litellm completion/embedding call
model: azure/chatgpt-v-2 # actual model name
api_key: os.environ/AZURE_API_KEY
api_version: os.environ/AZURE_API_VERSION
api_base: os.environ/AZURE_API_BASE
tpm: 100000
rpm: 10000
- model_name: gpt-3.5-turbo
litellm_params: # params for litellm completion/embedding call
model: gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
tpm: 100000
rpm: 1000
router_settings:
routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
redis_host: <your-redis-host>
redis_password: <your-redis-password>
redis_port: <your-redis-port>
enable_pre_call_check: true
general_settings:
master_key: sk-1234
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
```
**3. Test it!**
```bash
curl --location 'http://localhost:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="latency-based" label="Latency-Based"> <TabItem value="latency-based" label="Latency-Based">
@ -117,7 +234,10 @@ import asyncio
model_list = [{ ... }] model_list = [{ ... }]
# init router # init router
router = Router(model_list=model_list, routing_strategy="latency-based-routing") # 👈 set routing strategy router = Router(model_list=model_list,
routing_strategy="latency-based-routing",# 👈 set routing strategy
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
## CALL 1+2 ## CALL 1+2
tasks = [] tasks = []
@ -159,7 +279,7 @@ router_settings:
``` ```
</TabItem> </TabItem>
<TabItem value="simple-shuffle" label="(Default) Weighted Pick"> <TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">
**Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)** **Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)**
@ -257,8 +377,9 @@ router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"], redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"], redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"], redis_port=os.environ["REDIS_PORT"],
routing_strategy="usage-based-routing") routing_strategy="usage-based-routing"
enable_pre_call_check=True, # enables router rate limits for concurrent calls
)
response = await router.acompletion(model="gpt-3.5-turbo", response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}] messages=[{"role": "user", "content": "Hey, how's it going?"}]
@ -555,7 +676,11 @@ router = Router(model_list: Optional[list] = None,
## Pre-Call Checks (Context Window) ## Pre-Call Checks (Context Window)
Enable pre-call checks to filter out deployments with context window limit < messages for a call. Enable pre-call checks to filter out:
1. deployments with context window limit < messages for a call.
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
])`)
<Tabs> <Tabs>
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
@ -567,10 +692,14 @@ from litellm import Router
router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set to True router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set to True
``` ```
**2. (Azure-only) Set base model**
**2. Set Model List**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`. For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
<Tabs>
<TabItem value="same-group" label="Same Group">
```python ```python
model_list = [ model_list = [
{ {
@ -582,7 +711,7 @@ model_list = [
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
}, },
"model_info": { "model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 SET BASE MODEL "base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
} }
}, },
{ {
@ -593,8 +722,51 @@ model_list = [
}, },
}, },
] ]
router = Router(model_list=model_list, enable_pre_call_checks=True)
``` ```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
```python
model_list = [
{
"model_name": "gpt-3.5-turbo-small", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
"model_name": "gpt-3.5-turbo-large", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-opus",
"litellm_params": { call
"model": "claude-3-opus-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
```
</TabItem>
</Tabs>
**3. Test it!** **3. Test it!**
```python ```python
@ -646,60 +818,9 @@ print(f"response: {response}")
</TabItem> </TabItem>
<TabItem value="proxy" label="Proxy"> <TabItem value="proxy" label="Proxy">
**1. Setup config** :::info
Go [here](./proxy/reliability.md#advanced---context-window-fallbacks) for how to do this on the proxy
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with azure/. :::
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
text = "What is the meaning of 42?" * 5000
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{"role": "system", "content": text},
{"role": "user", "content": "Who was Alexander?"},
],
)
print(response)
```
</TabItem> </TabItem>
</Tabs> </Tabs>

View file

@ -310,7 +310,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
), ),
``` ```
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial. Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
</TabItem> </TabItem>
<TabItem value="aider" label="Aider"> <TabItem value="aider" label="Aider">
@ -1351,5 +1351,3 @@ LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw Open
```shell ```shell
litellm --telemetry False litellm --telemetry False
``` ```

View file

@ -105,6 +105,12 @@ const config = {
label: 'Enterprise', label: 'Enterprise',
to: "docs/enterprise" to: "docs/enterprise"
}, },
{
sidebarId: 'tutorialSidebar',
position: 'left',
label: '🚀 Hosted',
to: "docs/hosted"
},
{ {
href: 'https://github.com/BerriAI/litellm', href: 'https://github.com/BerriAI/litellm',
label: 'GitHub', label: 'GitHub',

Binary image files added (not shown): 91 KiB, 298 KiB, 208 KiB, 398 KiB, 496 KiB, 348 KiB, 460 KiB.

View file

@ -31,24 +31,26 @@ const sidebars = {
"proxy/quick_start", "proxy/quick_start",
"proxy/deploy", "proxy/deploy",
"proxy/prod", "proxy/prod",
"proxy/configs",
{ {
type: "link", type: "link",
label: "📖 All Endpoints", label: "📖 All Endpoints (Swagger)",
href: "https://litellm-api.up.railway.app/", href: "https://litellm-api.up.railway.app/",
}, },
"proxy/enterprise", "proxy/demo",
"proxy/user_keys", "proxy/configs",
"proxy/virtual_keys", "proxy/reliability",
"proxy/users", "proxy/users",
"proxy/user_keys",
"proxy/enterprise",
"proxy/virtual_keys",
"proxy/team_based_routing", "proxy/team_based_routing",
"proxy/ui", "proxy/ui",
"proxy/cost_tracking", "proxy/cost_tracking",
"proxy/token_auth", "proxy/token_auth",
{ {
type: "category", type: "category",
label: "🔥 Load Balancing", label: "Extra Load Balancing",
items: ["proxy/load_balancing", "proxy/reliability"], items: ["proxy/load_balancing"],
}, },
"proxy/model_management", "proxy/model_management",
"proxy/health", "proxy/health",
@ -61,7 +63,7 @@ const sidebars = {
label: "Logging, Alerting", label: "Logging, Alerting",
items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"], items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"],
}, },
"proxy/grafana_metrics", "proxy/prometheus",
"proxy/call_hooks", "proxy/call_hooks",
"proxy/rules", "proxy/rules",
"proxy/cli", "proxy/cli",
@ -84,6 +86,7 @@ const sidebars = {
"completion/stream", "completion/stream",
"completion/message_trimming", "completion/message_trimming",
"completion/function_call", "completion/function_call",
"completion/vision",
"completion/model_alias", "completion/model_alias",
"completion/batching", "completion/batching",
"completion/mock_requests", "completion/mock_requests",
@ -113,6 +116,7 @@ const sidebars = {
}, },
items: [ items: [
"providers/openai", "providers/openai",
"providers/text_completion_openai",
"providers/openai_compatible", "providers/openai_compatible",
"providers/azure", "providers/azure",
"providers/azure_ai", "providers/azure_ai",
@ -162,7 +166,6 @@ const sidebars = {
"debugging/local_debugging", "debugging/local_debugging",
"observability/callbacks", "observability/callbacks",
"observability/custom_callback", "observability/custom_callback",
"observability/lunary_integration",
"observability/langfuse_integration", "observability/langfuse_integration",
"observability/sentry", "observability/sentry",
"observability/promptlayer_integration", "observability/promptlayer_integration",
@ -171,6 +174,8 @@ const sidebars = {
"observability/slack_integration", "observability/slack_integration",
"observability/traceloop_integration", "observability/traceloop_integration",
"observability/athina_integration", "observability/athina_integration",
"observability/lunary_integration",
"observability/athina_integration",
"observability/helicone_integration", "observability/helicone_integration",
"observability/supabase_integration", "observability/supabase_integration",
`observability/telemetry`, `observability/telemetry`,

View file

@ -16,7 +16,7 @@ However, we also expose 3 public helper functions to calculate token usage acros
```python ```python
from litellm import token_counter from litellm import token_counter
messages = [{"user": "role", "content": "Hey, how's it going"}] messages = [{"role": "user", "content": "Hey, how's it going"}]
print(token_counter(model="gpt-3.5-turbo", messages=messages)) print(token_counter(model="gpt-3.5-turbo", messages=messages))
``` ```

View file

@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
traceback.print_exc() traceback.print_exc()
raise e raise e
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool: def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
if self.llm_guard_mode == "key-specific": if self.llm_guard_mode == "key-specific":
# check if llm guard enabled for specific keys only # check if llm guard enabled for specific keys only
self.print_verbose( self.print_verbose(
@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
return True return True
elif self.llm_guard_mode == "all": elif self.llm_guard_mode == "all":
return True return True
elif self.llm_guard_mode == "request-specific":
self.print_verbose(f"received metadata: {data.get('metadata', {})}")
metadata = data.get("metadata", {})
permissions = metadata.get("permissions", {})
if (
"enable_llm_guard_check" in permissions
and permissions["enable_llm_guard_check"] == True
):
return True
return False return False
async def async_moderation_hook( async def async_moderation_hook(
@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
) )
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict) _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
if _proceed == False: if _proceed == False:
return return

View file

@ -1,5 +1,6 @@
# Enterprise Proxy Util Endpoints # Enterprise Proxy Util Endpoints
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
import collections
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
@ -17,6 +18,48 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
return response return response
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
response = await prisma_client.db.query_raw(
"""
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s
WHERE s."startTime" >= current_date - interval '30 days'
GROUP BY individual_request_tag, spend_date
ORDER BY spend_date;
"""
)
# print("tags - spend")
# print(response)
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
total_spend_per_tag = collections.defaultdict(float)
total_requests_per_tag = collections.defaultdict(int)
for row in response:
tag_name = row["individual_request_tag"]
tag_spend = row["total_spend"]
total_spend_per_tag[tag_name] += tag_spend
total_requests_per_tag[tag_name] += row["log_count"]
sorted_tags = sorted(total_spend_per_tag.items(), key=lambda x: x[1], reverse=True)
# convert to ui format
ui_tags = []
for tag in sorted_tags:
ui_tags.append(
{
"name": tag[0],
"value": tag[1],
"log_count": total_requests_per_tag[tag[0]],
}
)
return {"top_10_tags": ui_tags}
async def view_spend_logs_from_clickhouse( async def view_spend_logs_from_clickhouse(
api_key=None, user_id=None, request_id=None, start_date=None, end_date=None api_key=None, user_id=None, request_id=None, start_date=None, end_date=None
): ):

View file

@ -6,7 +6,7 @@
"": { "": {
"dependencies": { "dependencies": {
"@hono/node-server": "^1.9.0", "@hono/node-server": "^1.9.0",
"hono": "^4.1.5" "hono": "^4.2.7"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "^20.11.17", "@types/node": "^20.11.17",
@ -463,9 +463,9 @@
} }
}, },
"node_modules/hono": { "node_modules/hono": {
"version": "4.1.5", "version": "4.2.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.1.5.tgz", "resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
"integrity": "sha512-3ChJiIoeCxvkt6vnkxJagplrt1YZg3NyNob7ssVeK2PUqEINp4q1F94HzFnvY9QE8asVmbW5kkTDlyWylfg2vg==", "integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
"engines": { "engines": {
"node": ">=16.0.0" "node": ">=16.0.0"
} }

View file

@ -4,7 +4,7 @@
}, },
"dependencies": { "dependencies": {
"@hono/node-server": "^1.9.0", "@hono/node-server": "^1.9.0",
"hono": "^4.1.5" "hono": "^4.2.7"
}, },
"devDependencies": { "devDependencies": {
"@types/node": "^20.11.17", "@types/node": "^20.11.17",

View file

@ -3,7 +3,11 @@ import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
)
import httpx import httpx
import dotenv import dotenv
@ -12,10 +16,24 @@ dotenv.load_dotenv()
if set_verbose == True: if set_verbose == True:
_turn_on_debug() _turn_on_debug()
############################################# #############################################
### Callbacks /Logging / Success / Failure Handlers ###
input_callback: List[Union[str, Callable]] = [] input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] callbacks: List[Callable] = []
_langfuse_default_tags: Optional[
List[
Literal[
"user_api_key_alias",
"user_api_key_user_id",
"user_api_key_user_email",
"user_api_key_team_alias",
"semantic-similarity",
"proxy_base_url",
]
]
] = None
_async_input_callback: List[Callable] = ( _async_input_callback: List[Callable] = (
[] []
) # internal variable - async custom callbacks are routed here. ) # internal variable - async custom callbacks are routed here.
@ -27,6 +45,8 @@ _async_failure_callback: List[Callable] = (
) # internal variable - async custom callbacks are routed here. ) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
## end of callbacks #############
email: Optional[str] = ( email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
) )
@ -46,6 +66,7 @@ replicate_key: Optional[str] = None
cohere_key: Optional[str] = None cohere_key: Optional[str] = None
maritalk_key: Optional[str] = None maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None vertex_project: Optional[str] = None
@ -56,6 +77,7 @@ baseten_key: Optional[str] = None
aleph_alpha_key: Optional[str] = None aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None nlp_cloud_key: Optional[str] = None
use_client: bool = False use_client: bool = False
ssl_verify: bool = True
disable_streaming_logging: bool = False disable_streaming_logging: bool = False
### GUARDRAILS ### ### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None llamaguard_model_name: Optional[str] = None
@ -64,7 +86,7 @@ google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
################## ##################
logging: bool = True logging: bool = True
caching: bool = ( caching: bool = (
@ -76,6 +98,8 @@ caching_with_models: bool = (
cache: Optional[Cache] = ( cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching None # cache object <- use this - https://docs.litellm.ai/docs/caching
) )
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
model_alias_map: Dict[str, str] = {} model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers max_budget: float = 0.0 # set the max budget across all providers
@ -170,7 +194,7 @@ dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None s3_callback_params: Optional[Dict] = None
generic_logger_headers: Optional[Dict] = None generic_logger_headers: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[Dict] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
default_user_params: Optional[Dict] = None default_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None max_user_budget: Optional[float] = None
@ -258,6 +282,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = [] open_ai_text_completion_models: List = []
cohere_models: List = [] cohere_models: List = []
cohere_chat_models: List = [] cohere_chat_models: List = []
mistral_chat_models: List = []
anthropic_models: List = [] anthropic_models: List = []
openrouter_models: List = [] openrouter_models: List = []
vertex_language_models: List = [] vertex_language_models: List = []
@ -267,6 +292,7 @@ vertex_code_chat_models: List = []
vertex_text_models: List = [] vertex_text_models: List = []
vertex_code_text_models: List = [] vertex_code_text_models: List = []
vertex_embedding_models: List = [] vertex_embedding_models: List = []
vertex_anthropic_models: List = []
ai21_models: List = [] ai21_models: List = []
nlp_cloud_models: List = [] nlp_cloud_models: List = []
aleph_alpha_models: List = [] aleph_alpha_models: List = []
@ -282,6 +308,8 @@ for key, value in model_cost.items():
cohere_models.append(key) cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat": elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key) cohere_chat_models.append(key)
elif value.get("litellm_provider") == "mistral":
mistral_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic": elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key) anthropic_models.append(key)
elif value.get("litellm_provider") == "openrouter": elif value.get("litellm_provider") == "openrouter":
@ -300,6 +328,9 @@ for key, value in model_cost.items():
vertex_code_chat_models.append(key) vertex_code_chat_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-embedding-models": elif value.get("litellm_provider") == "vertex_ai-embedding-models":
vertex_embedding_models.append(key) vertex_embedding_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
key = key.replace("vertex_ai/", "")
vertex_anthropic_models.append(key)
elif value.get("litellm_provider") == "ai21": elif value.get("litellm_provider") == "ai21":
ai21_models.append(key) ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud": elif value.get("litellm_provider") == "nlp_cloud":
@ -346,7 +377,7 @@ replicate_models: List = [
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b", "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe", "joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
# Flan T-5 # Flan T-5
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f" "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
# Others # Others
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5", "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
@ -447,6 +478,7 @@ model_list = (
+ deepinfra_models + deepinfra_models
+ perplexity_models + perplexity_models
+ maritalk_models + maritalk_models
+ vertex_language_models
) )
provider_list: List = [ provider_list: List = [
@ -568,6 +600,7 @@ from .utils import (
completion_cost, completion_cost,
supports_function_calling, supports_function_calling,
supports_parallel_function_calling, supports_parallel_function_calling,
supports_vision,
get_litellm_params, get_litellm_params,
Logging, Logging,
acreate, acreate,
@ -585,6 +618,7 @@ from .utils import (
_should_retry, _should_retry,
get_secret, get_secret,
get_supported_openai_params, get_supported_openai_params,
get_api_base,
) )
from .llms.huggingface_restapi import HuggingfaceConfig from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig from .llms.anthropic import AnthropicConfig
@ -600,6 +634,7 @@ from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig from .llms.vertex_ai import VertexAIConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig

View file

@ -32,6 +32,25 @@ def _get_redis_kwargs():
return available_args return available_args
def _get_redis_url_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_env_kwarg_mapping(): def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_" PREFIX = "REDIS_"
@ -91,27 +110,39 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None) redis_kwargs.pop("password", None)
elif "host" not in redis_kwargs or redis_kwargs["host"] is None: elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.") raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs return redis_kwargs
def get_redis_client(**env_overrides): def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides) redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None: if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop( args = _get_redis_url_kwargs()
"connection_pool", None url_kwargs = {}
) # redis.from_url doesn't support setting your own connection pool for arg in redis_kwargs:
return redis.Redis.from_url(**redis_kwargs) if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
return redis.Redis(**redis_kwargs) return redis.Redis(**redis_kwargs)
def get_redis_async_client(**env_overrides): def get_redis_async_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides) redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None: if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop( args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
"connection_pool", None url_kwargs = {}
) # redis.from_url doesn't support setting your own connection pool for arg in redis_kwargs:
return async_redis.Redis.from_url(**redis_kwargs) if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
litellm.print_verbose(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
)
return async_redis.Redis.from_url(**url_kwargs)
return async_redis.Redis( return async_redis.Redis(
socket_timeout=5, socket_timeout=5,
**redis_kwargs, **redis_kwargs,
@ -124,4 +155,9 @@ def get_redis_connection_pool(**env_overrides):
return async_redis.BlockingConnectionPool.from_url( return async_redis.BlockingConnectionPool.from_url(
timeout=5, url=redis_kwargs["url"] timeout=5, url=redis_kwargs["url"]
) )
connection_class = async_redis.Connection
if "ssl" in redis_kwargs and redis_kwargs["ssl"] is not None:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)

litellm/_service_logger.py (new file, 130 lines)
View file

@ -0,0 +1,130 @@
import litellm, traceback
from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger
from datetime import timedelta
from typing import Union
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
def __init__(self, mock_testing: bool = False) -> None:
self.mock_testing = mock_testing
self.mock_testing_sync_success_hook = 0
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
def service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_success_hook += 1
def service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(
self, service: ServiceTypes, duration: float, call_type: str
):
"""
- For counting if the redis, postgres call is successful
"""
if self.mock_testing:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.prometheusServicesLogger.async_service_success_hook(
payload=payload
)
async def async_service_failure_hook(
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
):
"""
- For counting if the redis, postgres call is unsuccessful
"""
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
if self.prometheusServicesLogger is None:
self.prometheusServicesLogger = PrometheusServicesLogger()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload
)
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Hook to track failed litellm-service calls
"""
return await super().async_post_call_failure_hook(
original_exception, user_api_key_dict
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
)
except Exception as e:
raise e

View file

@ -13,6 +13,7 @@ import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any, BinaryIO from typing import Optional, Literal, List, Union, Any, BinaryIO
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
import traceback import traceback
@ -81,9 +82,37 @@ class InMemoryCache(BaseCache):
return cached_response return cached_response
return None return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs) return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self): def flush_cache(self):
self.cache_dict.clear() self.cache_dict.clear()
self.ttl_dict.clear() self.ttl_dict.clear()
@ -109,6 +138,8 @@ class RedisCache(BaseCache):
**kwargs, **kwargs,
): ):
from ._redis import get_redis_client, get_redis_connection_pool from ._redis import get_redis_client, get_redis_connection_pool
from litellm._service_logger import ServiceLogging
import redis
redis_kwargs = {} redis_kwargs = {}
if host is not None: if host is not None:
@ -118,10 +149,19 @@ class RedisCache(BaseCache):
if password is not None: if password is not None:
redis_kwargs["password"] = password redis_kwargs["password"] = password
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
):
self.service_logger_obj = kwargs.pop("service_logger_obj")
else:
self.service_logger_obj = ServiceLogging()
redis_kwargs.update(kwargs) redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs) self.redis_client = get_redis_client(**redis_kwargs)
self.redis_kwargs = redis_kwargs self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
# redis namespaces # redis namespaces
self.namespace = namespace self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis # for high traffic, we store the redis results in memory and then batch write to redis
@ -133,6 +173,16 @@ class RedisCache(BaseCache):
except Exception as e: except Exception as e:
pass pass
### ASYNC HEALTH PING ###
try:
# asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping())
except Exception:
pass
### SYNC HEALTH PING ###
self.redis_client.ping()
def init_async_client(self): def init_async_client(self):
from ._redis import get_redis_async_client from ._redis import get_redis_async_client
@ -163,18 +213,101 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
) )
def increment_cache(self, key, value: int, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="increment_cache",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="increment_cache",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list: async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
keys = [] start_time = time.time()
_redis_client = self.init_async_client() try:
async with _redis_client as redis_client: keys = []
async for key in redis_client.scan_iter(match=pattern + "*", count=count): _redis_client = self.init_async_client()
keys.append(key) async with _redis_client as redis_client:
if len(keys) >= count: async for key in redis_client.scan_iter(
break match=pattern + "*", count=count
return keys ):
keys.append(key)
if len(keys) >= count:
break
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_scan_iter",
)
) # DO NOT SLOW DOWN CALL B/C OF THIS
return keys
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_scan_iter",
)
)
raise e
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
_redis_client = self.init_async_client() start_time = time.time()
try:
_redis_client = self.init_async_client()
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS, duration=_duration, error=e
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
key = self.check_and_fix_namespace(key=key) key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client: async with _redis_client as redis_client:
ttl = kwargs.get("ttl", None) ttl = kwargs.get("ttl", None)
@ -186,7 +319,26 @@ class RedisCache(BaseCache):
print_verbose( print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
) )
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache",
)
)
except Exception as e: except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache",
)
)
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
verbose_logger.error( verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
@ -200,6 +352,11 @@ class RedisCache(BaseCache):
Use Redis Pipelines for bulk write operations Use Redis Pipelines for bulk write operations
""" """
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
try: try:
async with _redis_client as redis_client: async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe: async with redis_client.pipeline(transaction=True) as pipe:
@ -219,8 +376,30 @@ class RedisCache(BaseCache):
print_verbose(f"pipeline results: {results}") print_verbose(f"pipeline results: {results}")
# Optionally, you could process 'results' to make sure that all set operations were successful. # Optionally, you could process 'results' to make sure that all set operations were successful.
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_pipeline",
)
)
return results return results
except Exception as e: except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_pipeline",
)
)
verbose_logger.error( verbose_logger.error(
"LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s", "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
str(e), str(e),
@ -235,7 +414,44 @@ class RedisCache(BaseCache):
key = self.check_and_fix_namespace(key=key) key = self.check_and_fix_namespace(key=key)
self.redis_batch_writing_buffer.append((key, value)) self.redis_batch_writing_buffer.append((key, value))
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() await self.flush_cache_buffer() # logging done in here
async def async_increment(self, key, value: int, **kwargs) -> int:
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
result = await redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_increment",
)
)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_increment",
)
)
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
raise e
async def flush_cache_buffer(self): async def flush_cache_buffer(self):
print_verbose( print_verbose(
@ -274,40 +490,17 @@ class RedisCache(BaseCache):
traceback.print_exc() traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
async def async_get_cache(self, key, **kwargs): def batch_get_cache(self, key_list) -> dict:
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
try:
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_get_cache_pipeline(self, key_list) -> dict:
""" """
Use Redis for bulk read operations Use Redis for bulk read operations
""" """
_redis_client = await self.init_async_client()
key_value_dict = {} key_value_dict = {}
try: try:
async with _redis_client as redis_client: _keys = []
async with redis_client.pipeline(transaction=True) as pipe: for cache_key in key_list:
# Queue the get operations in the pipeline for all keys. cache_key = self.check_and_fix_namespace(key=cache_key)
for cache_key in key_list: _keys.append(cache_key)
cache_key = self.check_and_fix_namespace(key=cache_key) results = self.redis_client.mget(keys=_keys)
pipe.get(cache_key) # Queue GET command in pipeline
# Execute the pipeline and await the results.
results = await pipe.execute()
# Associate the results back with their keys. # Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'. # 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -323,21 +516,185 @@ class RedisCache(BaseCache):
print_verbose(f"Error occurred in pipeline read - {str(e)}") print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict return key_value_dict
async def ping(self): async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
start_time = time.time()
async with _redis_client as redis_client:
try:
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await redis_client.get(key)
print_verbose(
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
)
response = self._get_cache_logic(cached_response=cached_response)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_get_cache",
)
)
return response
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_get_cache",
)
)
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()
try:
async with _redis_client as redis_client:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = await redis_client.mget(keys=_keys)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_batch_get_cache",
)
)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {}
for k, v in key_value_dict.items():
if isinstance(k, bytes):
k = k.decode("utf-8")
v = self._get_cache_logic(v)
decoded_results[k] = v
return decoded_results
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_batch_get_cache",
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
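A hedged usage sketch for the bulk read above; the RedisCache connection arguments are placeholders, only async_batch_get_cache itself comes from this hunk.

from litellm.caching import RedisCache

redis_cache = RedisCache(host="localhost", port=6379, password=None)  # placeholder connection
# inside an async context:
#   values = await redis_cache.async_batch_get_cache(["key-1", "key-2"])
#   values -> {"key-1": <decoded cached value>, "key-2": None}   # misses come back as None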
def sync_ping(self) -> bool:
"""
Tests if the sync redis client is correctly setup.
"""
print_verbose(f"Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="sync_ping",
)
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
self.service_logger_obj.service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="sync_ping",
)
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def ping(self) -> bool:
_redis_client = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client: async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache") print_verbose(f"Pinging Async Redis Cache")
try: try:
response = await redis_client.ping() response = await redis_client.ping()
print_verbose(f"Redis Cache PING: {response}") ## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_ping",
)
)
return response
except Exception as e: except Exception as e:
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_ping",
)
)
print_verbose( print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}" f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
) )
traceback.print_exc() traceback.print_exc()
raise e raise e
async def delete_cache_keys(self, keys):
_redis_client = self.init_async_client()
# keys is a list, unpack it so it gets passed as individual elements to delete
async with _redis_client as redis_client:
await redis_client.delete(*keys)
def client_list(self):
client_list = self.redis_client.client_list()
return client_list
def info(self):
info = self.redis_client.info()
return info
def flush_cache(self): def flush_cache(self):
self.redis_client.flushall() self.redis_client.flushall()
@ -828,8 +1185,10 @@ class DualCache(BaseCache):
# If redis_cache is not provided, use the default RedisCache # If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache self.redis_cache = redis_cache
self.default_in_memory_ttl = default_in_memory_ttl self.default_in_memory_ttl = (
self.default_redis_ttl = default_redis_ttl default_in_memory_ttl or litellm.default_in_memory_ttl
)
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
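A minimal sketch of the fallback above, assuming the module-level defaults can simply be assigned before the cache is constructed; the TTL values are illustrative.

import litellm
from litellm.caching import DualCache

litellm.default_in_memory_ttl = 60   # seconds, used when no default_in_memory_ttl is passed
litellm.default_redis_ttl = 600      # seconds, used when no default_redis_ttl is passed

dual_cache = DualCache()             # no explicit TTLs -> falls back to the module-level values
assert dual_cache.default_in_memory_ttl == 60
assert dual_cache.default_redis_ttl == 600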
def set_cache(self, key, value, local_only: bool = False, **kwargs): def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache # Update both Redis and in-memory cache
@ -846,6 +1205,30 @@ class DualCache(BaseCache):
except Exception as e: except Exception as e:
print_verbose(e) print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
def get_cache(self, key, local_only: bool = False, **kwargs): def get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first # Try to fetch from in-memory cache first
try: try:
@ -872,6 +1255,39 @@ class DualCache(BaseCache):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
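A hedged sketch of the read-through merge implemented above; the keys and values are made up for illustration.

keys = ["a", "b", "c"]
# suppose the in-memory layer only knows "a" and Redis only knows "b":
#   in-memory pass -> [1, None, None]
#   redis pass fills the None slots it can and warms memory -> [1, 2, None]
# dual_cache.batch_get_cache(keys) would therefore return [1, 2, None]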
async def async_get_cache(self, key, local_only: bool = False, **kwargs): async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first # Try to fetch from in-memory cache first
try: try:
@ -905,7 +1321,50 @@ class DualCache(BaseCache):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
if value is not None:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try: try:
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs) await self.in_memory_cache.async_set_cache(key, value, **kwargs)
@ -916,6 +1375,32 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() traceback.print_exc()
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only == False:
result = await self.redis_cache.async_increment(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
raise e
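A hedged usage sketch for the increment helpers (the synchronous increment_cache behaves the same way, minus the await); the key naming and rate-limit framing are illustrative, not taken from the source.

import asyncio
from litellm.caching import DualCache

async def count_request(cache: DualCache, api_key_hash: str) -> int:
    # bumps the in-memory counter and, unless local_only=True, the Redis counter too
    return await cache.async_increment_cache(key=f"rpm:{api_key_hash}", value=1)

# usage (illustrative): asyncio.run(count_request(DualCache(), "hashed-key"))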
def flush_cache(self): def flush_cache(self):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache() self.in_memory_cache.flush_cache()
@ -939,6 +1424,8 @@ class Cache:
password: Optional[str] = None, password: Optional[str] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
ttl: Optional[float] = None, ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None, similarity_threshold: Optional[float] = None,
supported_call_types: Optional[ supported_call_types: Optional[
List[ List[
@ -1038,6 +1525,14 @@ class Cache:
self.redis_flush_size = redis_flush_size self.redis_flush_size = redis_flush_size
self.ttl = ttl self.ttl = ttl
if self.type == "local" and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == "redis" or self.type == "redis-semantic"
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
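A short sketch of the new TTL knobs, assuming the usual litellm.Cache setup; the Redis connection details are placeholders, only the parameter names come from the constructor above.

import litellm
from litellm.caching import Cache

# local cache: in-memory entries expire after 2 minutes
litellm.cache = Cache(type="local", default_in_memory_ttl=120)

# redis cache: entries expire after 10 minutes (connection values are placeholders)
# litellm.cache = Cache(type="redis", host="localhost", port="6379", password="...", default_in_redis_ttl=600)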
if self.namespace is not None and isinstance(self.cache, RedisCache): if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace self.cache.namespace = self.namespace
@ -1379,6 +1874,11 @@ class Cache:
return await self.cache.ping() return await self.cache.ping()
return None return None
async def delete_cache_keys(self, keys):
if hasattr(self.cache, "delete_cache_keys"):
return await self.cache.delete_cache_keys(keys)
return None
async def disconnect(self): async def disconnect(self):
if hasattr(self.cache, "disconnect"): if hasattr(self.cache, "disconnect"):
await self.cache.disconnect() await self.cache.disconnect()

View file

@ -82,14 +82,18 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408
self.message = message
self.model = model
self.llm_provider = llm_provider
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__( super().__init__(
request=request request=request
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
self.status_code = 408
self.message = message
self.model = model
self.llm_provider = llm_provider
# custom function to convert to str
def __str__(self):
return str(self.message)
class PermissionDeniedError(PermissionDeniedError): # type:ignore class PermissionDeniedError(PermissionDeniedError): # type:ignore

View file

@ -6,7 +6,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass pass
#### PRE-CALL CHECKS - router/proxy only ####
"""
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
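A minimal sketch of a callback overriding the new pre-call hook; the import path and the deployment dict key used below are assumptions for illustration.

from typing import Optional
from litellm.integrations.custom_logger import CustomLogger  # assumed import path

class MyPreCallChecker(CustomLogger):
    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
        # runs inside the picked deployment's semaphore, so tpm/rpm reads here are concurrency-safe
        print(f"pre-call check for deployment: {deployment.get('model_name')}")  # 'model_name' is an assumed key
        return deployment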
#### CALL HOOKS - proxy only #### #### CALL HOOKS - proxy only ####
""" """
Control / modify the incoming and outgoing data before calling the model Control / modify the incoming and outgoing data before calling the model

View file

@ -0,0 +1,51 @@
import requests
import json
import traceback
from datetime import datetime, timezone
class GreenscaleLogger:
def __init__(self):
import os
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
self.headers = {
"api-key": self.greenscale_api_key,
"Content-Type": "application/json"
}
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try:
response_json = response_obj.model_dump() if response_obj else {}
data = {
"modelId": kwargs.get("model"),
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
}
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
if type(end_time) == datetime and type(start_time) == datetime:
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
# Add additional metadata keys to tags
tags = []
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
for key, value in metadata.items():
if key.startswith("greenscale"):
if key == "greenscale_project":
data["project"] = value
elif key == "greenscale_application":
data["application"] = value
else:
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
data["tags"] = tags
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
if response.status_code != 200:
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
else:
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
except Exception as e:
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
pass
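A hedged configuration sketch based only on the environment variables and metadata keys read above; the endpoint URL is a placeholder, and forwarding metadata this way from the caller is an assumption.

import os

# read in GreenscaleLogger.__init__
os.environ["GREENSCALE_API_KEY"] = "greenscale-api-key"                       # placeholder
os.environ["GREENSCALE_ENDPOINT"] = "https://ingest.greenscale.example/logs"  # placeholder URL

# keys the logger looks for under litellm_params["metadata"]
metadata = {
    "greenscale_project": "my-project",      # mapped to data["project"]
    "greenscale_application": "my-app",      # mapped to data["application"]
    "greenscale_team": "platform",           # any other greenscale_* key becomes a tag
}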

View file

@ -17,7 +17,7 @@ class LangFuseLogger:
from langfuse import Langfuse from langfuse import Langfuse
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m" f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
) )
# Instance variables # Instance variables
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY") self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
@ -34,6 +34,14 @@ class LangFuseLogger:
flush_interval=1, # flush interval in seconds flush_interval=1, # flush interval in seconds
) )
# set the current langfuse project id in the environ
# this is used by Alerting to link to the correct project
try:
project_id = self.Langfuse.client.projects.get().data[0].id
os.environ["LANGFUSE_PROJECT_ID"] = project_id
except:
project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None: if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
self.upstream_langfuse_secret_key = os.getenv( self.upstream_langfuse_secret_key = os.getenv(
"UPSTREAM_LANGFUSE_SECRET_KEY" "UPSTREAM_LANGFUSE_SECRET_KEY"
@ -76,6 +84,7 @@ class LangFuseLogger:
print_verbose( print_verbose(
f"Langfuse Logging - Enters logging function for model {kwargs}" f"Langfuse Logging - Enters logging function for model {kwargs}"
) )
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = ( metadata = (
litellm_params.get("metadata", {}) or {} litellm_params.get("metadata", {}) or {}
@ -118,6 +127,11 @@ class LangFuseLogger:
): ):
input = prompt input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse response_obj, litellm.ImageResponse
): ):
@ -128,6 +142,7 @@ class LangFuseLogger:
self._log_langfuse_v2( self._log_langfuse_v2(
user_id, user_id,
metadata, metadata,
litellm_params,
output, output,
start_time, start_time,
end_time, end_time,
@ -156,7 +171,7 @@ class LangFuseLogger:
verbose_logger.info(f"Langfuse Layer Logging - logging success") verbose_logger.info(f"Langfuse Layer Logging - logging success")
except: except:
traceback.print_exc() traceback.print_exc()
print(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
pass pass
async def _async_log_event( async def _async_log_event(
@ -185,7 +200,7 @@ class LangFuseLogger:
): ):
from langfuse.model import CreateTrace, CreateGeneration from langfuse.model import CreateTrace, CreateGeneration
print( verbose_logger.warning(
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1" "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
) )
@ -219,6 +234,7 @@ class LangFuseLogger:
self, self,
user_id, user_id,
metadata, metadata,
litellm_params,
output, output,
start_time, start_time,
end_time, end_time,
@ -273,13 +289,13 @@ class LangFuseLogger:
clean_metadata = {} clean_metadata = {}
if isinstance(metadata, dict): if isinstance(metadata, dict):
for key, value in metadata.items(): for key, value in metadata.items():
# generate langfuse tags
if key in [ # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
"user_api_key", if (
"user_api_key_user_id", litellm._langfuse_default_tags is not None
"user_api_key_team_id", and isinstance(litellm._langfuse_default_tags, list)
"semantic-similarity", and key in litellm._langfuse_default_tags
]: ):
tags.append(f"{key}:{value}") tags.append(f"{key}:{value}")
# clean litellm metadata before logging # clean litellm metadata before logging
@ -293,13 +309,55 @@ class LangFuseLogger:
else: else:
clean_metadata[key] = value clean_metadata[key] = value
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
and "proxy_base_url" in litellm._langfuse_default_tags
):
proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
if proxy_base_url is not None:
tags.append(f"proxy_base_url:{proxy_base_url}")
api_base = litellm_params.get("api_base", None)
if api_base:
clean_metadata["api_base"] = api_base
vertex_location = kwargs.get("vertex_location", None)
if vertex_location:
clean_metadata["vertex_location"] = vertex_location
aws_region_name = kwargs.get("aws_region_name", None)
if aws_region_name:
clean_metadata["aws_region_name"] = aws_region_name
if supports_tags: if supports_tags:
if "cache_hit" in kwargs: if "cache_hit" in kwargs:
if kwargs["cache_hit"] is None: if kwargs["cache_hit"] is None:
kwargs["cache_hit"] = False kwargs["cache_hit"] = False
tags.append(f"cache_hit:{kwargs['cache_hit']}") tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"]
trace_params.update({"tags": tags}) trace_params.update({"tags": tags})
proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request:
method = proxy_server_request.get("method", None)
url = proxy_server_request.get("url", None)
headers = proxy_server_request.get("headers", None)
clean_headers = {}
if headers:
for key, value in headers.items():
# these headers can leak our API keys and/or JWT tokens
if key.lower() not in ["authorization", "cookie", "referer"]:
clean_headers[key] = value
clean_metadata["request"] = {
"method": method,
"url": url,
"headers": clean_headers,
}
print_verbose(f"trace_params: {trace_params}")
trace = self.Langfuse.trace(**trace_params) trace = self.Langfuse.trace(**trace_params)
generation_id = None generation_id = None
@ -316,13 +374,21 @@ class LangFuseLogger:
# just log `litellm-{call_type}` as the generation name # just log `litellm-{call_type}` as the generation name
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}" generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if response_obj is not None and "system_fingerprint" in response_obj:
system_fingerprint = response_obj.get("system_fingerprint", None)
else:
system_fingerprint = None
if system_fingerprint is not None:
optional_params["system_fingerprint"] = system_fingerprint
generation_params = { generation_params = {
"name": generation_name, "name": generation_name,
"id": metadata.get("generation_id", generation_id), "id": metadata.get("generation_id", generation_id),
"startTime": start_time, "start_time": start_time,
"endTime": end_time, "end_time": end_time,
"model": kwargs["model"], "model": kwargs["model"],
"modelParameters": optional_params, "model_parameters": optional_params,
"input": input, "input": input,
"output": output, "output": output,
"usage": usage, "usage": usage,
@ -334,13 +400,15 @@ class LangFuseLogger:
generation_params["prompt"] = metadata.get("prompt", None) generation_params["prompt"] = metadata.get("prompt", None)
if output is not None and isinstance(output, str) and level == "ERROR": if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["statusMessage"] = output generation_params["status_message"] = output
if supports_completion_start_time: if supports_completion_start_time:
generation_params["completion_start_time"] = kwargs.get( generation_params["completion_start_time"] = kwargs.get(
"completion_start_time", None "completion_start_time", None
) )
print_verbose(f"generation_params: {generation_params}")
trace.generation(**generation_params) trace.generation(**generation_params)
except Exception as e: except Exception as e:
print(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")

View file

@ -7,6 +7,19 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import asyncio
import types
from pydantic import BaseModel
def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
types.FunctionType,
types.GeneratorType,
BaseModel,
)
return not isinstance(value, non_serializable_types)
class LangsmithLogger: class LangsmithLogger:
@ -21,7 +34,9 @@ class LangsmithLogger:
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition # Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb # inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
metadata = kwargs.get('litellm_params', {}).get("metadata", {}) or {} # if metadata is None metadata = (
kwargs.get("litellm_params", {}).get("metadata", {}) or {}
) # if metadata is None
# set project name and run_name for langsmith logging # set project name and run_name for langsmith logging
# users can pass project_name and run name to litellm.completion() # users can pass project_name and run name to litellm.completion()
@ -51,24 +66,46 @@ class LangsmithLogger:
new_kwargs = {} new_kwargs = {}
for key in kwargs: for key in kwargs:
value = kwargs[key] value = kwargs[key]
if key == "start_time" or key == "end_time": if key == "start_time" or key == "end_time" or value is None:
pass pass
elif type(value) != dict: elif isinstance(value, datetime):
new_kwargs[key] = value.isoformat()
elif type(value) != dict and is_serializable(value=value):
new_kwargs[key] = value new_kwargs[key] = value
requests.post( print(f"type of response: {type(response_obj)}")
for k, v in new_kwargs.items():
print(f"key={k}, type of arg: {type(v)}, value={v}")
if isinstance(response_obj, BaseModel):
try:
response_obj = response_obj.model_dump()
except:
response_obj = response_obj.dict() # type: ignore
print(f"response_obj: {response_obj}")
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": response_obj,
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
}
print(f"data: {data}")
response = requests.post(
"https://api.smith.langchain.com/runs", "https://api.smith.langchain.com/runs",
json={ json=data,
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {**new_kwargs},
"outputs": response_obj.json(),
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
},
headers={"x-api-key": self.langsmith_api_key}, headers={"x-api-key": self.langsmith_api_key},
) )
if response.status_code >= 300:
print_verbose(f"Error: {response.status_code}")
else:
print_verbose("Run successfully created")
print_verbose( print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}" f"Langsmith Layer Logging - final response object: {response_obj}"
) )
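A hedged sketch of supplying the project and run name mentioned in the comments above; the exact metadata key names are assumptions, since that lookup sits outside this hunk.

import litellm

litellm.success_callback = ["langsmith"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={
        "project_name": "litellm-dev",  # assumed key -> Langsmith session_name
        "run_name": "chat-test",        # assumed key -> run name
    },
)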

View file

@ -4,11 +4,13 @@ from datetime import datetime, timezone
import traceback import traceback
import dotenv import dotenv
import importlib import importlib
from pkg_resources import parse_version
import sys import sys
import packaging
dotenv.load_dotenv() dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx} # convert to {completion: xx, tokens: xx}
def parse_usage(usage): def parse_usage(usage):
return { return {
@ -16,6 +18,7 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_messages(input): def parse_messages(input):
if input is None: if input is None:
return None return None
@ -28,7 +31,6 @@ def parse_messages(input):
if "message" in message: if "message" in message:
return clean_message(message["message"]) return clean_message(message["message"])
serialized = { serialized = {
"role": message.get("role"), "role": message.get("role"),
"content": message.get("content"), "content": message.get("content"),
@ -56,10 +58,13 @@ class LunaryLogger:
def __init__(self): def __init__(self):
try: try:
import lunary import lunary
version = importlib.metadata.version("lunary") version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError # if version < 0.1.43 then raise ImportError
if parse_version(version) < parse_version("0.1.43"): if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print("Lunary version outdated. Required: > 0.1.43. Upgrade via 'pip install lunary --upgrade'") print(
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
raise ImportError raise ImportError
self.lunary_client = lunary self.lunary_client = lunary
@ -88,9 +93,7 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = ( metadata = litellm_params.get("metadata", {}) or {}
litellm_params.get("metadata", {}) or {}
)
tags = litellm_params.pop("tags", None) or [] tags = litellm_params.pop("tags", None) or []
@ -148,7 +151,7 @@ class LunaryLogger:
runtime="litellm", runtime="litellm",
error=error_obj, error=error_obj,
output=parse_messages(output), output=parse_messages(output),
token_usage=usage token_usage=usage,
) )
except: except:

View file

@ -1,6 +1,6 @@
# used for /metrics endpoint on LiteLLM Proxy # used for /metrics endpoint on LiteLLM Proxy
#### What this does #### #### What this does ####
# On success + failure, log events to Supabase # On success, log events to Prometheus
import dotenv, os import dotenv, os
import requests import requests
@ -19,27 +19,33 @@ class PrometheusLogger:
**kwargs, **kwargs,
): ):
try: try:
verbose_logger.debug(f"in init prometheus metrics") print(f"in init prometheus metrics")
from prometheus_client import Counter from prometheus_client import Counter
self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm",
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
)
self.litellm_requests_metric = Counter( self.litellm_requests_metric = Counter(
name="litellm_requests_metric", name="litellm_requests_metric",
documentation="Total number of LLM calls to litellm", documentation="Total number of LLM calls to litellm",
labelnames=["user", "key", "model"], labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
) )
# Counter for spend # Counter for spend
self.litellm_spend_metric = Counter( self.litellm_spend_metric = Counter(
"litellm_spend_metric", "litellm_spend_metric",
"Total spend on LLM requests", "Total spend on LLM requests",
labelnames=["user", "key", "model"], labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
) )
# Counter for total_output_tokens # Counter for total_output_tokens
self.litellm_tokens_metric = Counter( self.litellm_tokens_metric = Counter(
"litellm_total_tokens", "litellm_total_tokens",
"Total number of input + output tokens from LLM requests", "Total number of input + output tokens from LLM requests",
labelnames=["user", "key", "model"], labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
) )
except Exception as e: except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}") print_verbose(f"Got exception on init prometheus client {str(e)}")
@ -61,24 +67,50 @@ class PrometheusLogger:
# unpack kwargs # unpack kwargs
model = kwargs.get("model", "") model = kwargs.get("model", "")
response_cost = kwargs.get("response_cost", 0.0) response_cost = kwargs.get("response_cost", 0.0) or 0
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {} proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None) end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params.get("metadata", {}).get(
"user_api_key_user_id", None
)
user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None) user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0) user_api_team = litellm_params.get("metadata", {}).get(
"user_api_key_team_id", None
)
if response_obj is not None:
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
else:
tokens_used = 0
print_verbose( print_verbose(
f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}" f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
) )
self.litellm_requests_metric.labels(end_user_id, user_api_key, model).inc() if (
self.litellm_spend_metric.labels(end_user_id, user_api_key, model).inc( user_api_key is not None
response_cost and isinstance(user_api_key, str)
) and user_api_key.startswith("sk-")
self.litellm_tokens_metric.labels(end_user_id, user_api_key, model).inc( ):
tokens_used from litellm.proxy.utils import hash_token
)
user_api_key = hash_token(user_api_key)
self.litellm_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
self.litellm_spend_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc(response_cost)
self.litellm_tokens_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc(tokens_used)
### FAILURE INCREMENT ###
if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels(
end_user_id, user_api_key, model, user_api_team, user_id
).inc()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
verbose_logger.debug( verbose_logger.debug(

View file

@ -0,0 +1,198 @@
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
#### What this does ####
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
class PrometheusServicesLogger:
# Class variables or attributes
litellm_service_latency = None # Class-level attribute to store the Histogram
def __init__(
self,
mock_testing: bool = False,
**kwargs,
):
try:
try:
from prometheus_client import Counter, Histogram, REGISTRY
except ImportError:
raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`"
)
self.Histogram = Histogram
self.Counter = Counter
self.REGISTRY = REGISTRY
verbose_logger.debug(f"in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes]
self.payload_to_prometheus_map = (
{}
) # store the prometheus histogram/counter we need to call for each field in payload
for service in self.services:
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service, type_of_request="failed_requests"
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
)
self.payload_to_prometheus_map[service] = [
histogram,
counter_failed_request,
counter_total_requests,
]
self.prometheus_to_amount_map: dict = (
{}
) # the field / value in ServiceLoggerPayload the object needs to be incremented by
### MOCK TESTING ###
self.mock_testing = mock_testing
self.mock_testing_success_calls = 0
self.mock_testing_failure_calls = 0
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
def is_metric_registered(self, metric_name) -> bool:
for metric in self.REGISTRY.collect():
if metric_name == metric.name:
return True
return False
def get_metric(self, metric_name):
for metric in self.REGISTRY.collect():
for sample in metric.samples:
if metric_name == sample.name:
return metric
return None
def create_histogram(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Histogram(
metric_name,
"Latency for {} service".format(service),
labelnames=[service],
)
def create_counter(self, service: str, type_of_request: str):
metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Counter(
metric_name,
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
)
def observe_histogram(
self,
histogram,
labels: str,
amount: float,
):
assert isinstance(histogram, self.Histogram)
histogram.labels(labels).observe(amount)
def increment_counter(
self,
counter,
labels: str,
amount: float,
):
assert isinstance(counter, self.Counter)
counter.labels(labels).inc(amount)
def service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_success_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Histogram):
self.observe_histogram(
histogram=obj,
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
def service_failure_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_failure_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_success_hook(self, payload: ServiceLoggerPayload):
"""
Log successful call to prometheus
"""
if self.mock_testing:
self.mock_testing_success_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Histogram):
self.observe_histogram(
histogram=obj,
labels=payload.service.value,
amount=payload.duration,
)
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")
if self.mock_testing:
self.mock_testing_failure_calls += 1
if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
amount=1, # LOG ERROR COUNT TO PROMETHEUS
)
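A small sketch of the metric names this class registers, derived from the naming scheme in create_histogram / create_counter; "redis" is one of the ServiceTypes values used elsewhere in this diff.

# each service gets three collectors named litellm_<service>_<type_of_request>
service = "redis"
for type_of_request in ["latency", "failed_requests", "total_requests"]:
    print("litellm_{}_{}".format(service, type_of_request))
# -> litellm_redis_latency, litellm_redis_failed_requests, litellm_redis_total_requests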

View file

@ -0,0 +1,486 @@
#### What this does ####
# Class for sending Slack Alerts #
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import copy
import traceback
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm
from typing import List, Literal, Any, Union, Optional, Dict
from litellm.caching import DualCache
import asyncio
import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class SlackAlerting:
# Class variables or attributes
def __init__(
self,
alerting_threshold: float = 300,
alerting: Optional[List] = [],
alert_types: Optional[
List[
Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
]
] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
alert_to_webhook_url: Optional[
Dict
] = None, # if user wants to separate alerts to diff channels
):
self.alerting_threshold = alerting_threshold
self.alerting = alerting
self.alert_types = alert_types
self.internal_usage_cache = DualCache()
self.async_http_handler = AsyncHTTPHandler()
self.alert_to_webhook_url = alert_to_webhook_url
pass
def update_values(
self,
alerting: Optional[List] = None,
alerting_threshold: Optional[float] = None,
alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None,
):
if alerting is not None:
self.alerting = alerting
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
self.alert_types = alert_types
if alert_to_webhook_url is not None:
# update the dict
if self.alert_to_webhook_url is None:
self.alert_to_webhook_url = alert_to_webhook_url
else:
self.alert_to_webhook_url.update(alert_to_webhook_url)
async def deployment_in_cooldown(self):
pass
async def deployment_removed_from_cooldown(self):
pass
def _all_possible_alert_types(self):
# used by the UI to show all supported alert types
# Note: this is not the set of alerts the user has configured; it's every alert type a user can select
return [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
]
def _add_langfuse_trace_id_to_alert(
self,
request_info: str,
request_data: Optional[dict] = None,
kwargs: Optional[dict] = None,
):
import uuid
# For now: do nothing as we're debugging why this is not working as expected
return request_info
# if request_data is not None:
# trace_id = request_data.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# request_data["metadata"]["trace_id"] = trace_id
# elif kwargs is not None:
# _litellm_params = kwargs.get("litellm_params", {})
# trace_id = _litellm_params.get("metadata", {}).get(
# "trace_id", None
# ) # get langfuse trace id
# if trace_id is None:
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
# _litellm_params["metadata"]["trace_id"] = trace_id
# _langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
# _langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
# # langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
# _langfuse_url = (
# f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
# )
# request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
# return request_info
def _response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
start_time,
end_time, # start/end time
):
try:
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
model = kwargs.get("model", "")
api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
messages = kwargs.get("messages", None)
# if messages does not exist fallback to "input"
if messages is None:
messages = kwargs.get("input", None)
# only use first 100 chars for alerting
_messages = str(messages)[:100]
return time_difference_float, model, api_base, _messages
except Exception as e:
raise e
def _get_deployment_latencies_to_alert(self, metadata=None):
if metadata is None:
return None
if "_latency_per_deployment" in metadata:
# Translate model_id to -> api_base
# _latency_per_deployment is a dictionary that looks like this:
"""
_latency_per_deployment: {
api_base: 0.01336697916666667
}
"""
_message_to_send = ""
_deployment_latencies = metadata["_latency_per_deployment"]
if len(_deployment_latencies) == 0:
return None
for api_base, latency in _deployment_latencies.items():
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
_message_to_send = "```" + _message_to_send + "```"
return _message_to_send
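A worked example of the formatting above, reusing the illustrative shape from the docstring.

metadata = {
    "_latency_per_deployment": {"https://my-endpoint.example": 0.01336697916666667}
}
# _get_deployment_latencies_to_alert(metadata) returns a code-fenced slack snippet like:
# "```\nhttps://my-endpoint.example: 0.01s```"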
async def response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
if self.alerting is None or self.alert_types is None:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
)
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, kwargs=kwargs
)
# add deployment latencies to alert
if (
kwargs is not None
and "litellm_params" in kwargs
and "metadata" in kwargs["litellm_params"]
):
_metadata = kwargs["litellm_params"]["metadata"]
_deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=_metadata
)
if _deployment_latency_map is not None:
request_info += (
f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
)
await self.send_alert(
message=slow_message + request_info,
level="Low",
alert_type="llm_too_slow",
)
async def log_failure_event(self, original_exception: Exception):
pass
async def response_taking_too_long(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
type: Literal["hanging_request", "slow_response"] = "hanging_request",
request_data: Optional[dict] = None,
):
if self.alerting is None or self.alert_types is None:
return
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", None)
if messages is None:
# if messages does not exist fallback to "input"
messages = request_data.get("input", None)
# try casting messages to str and get the first 100 characters, else mark as None
try:
messages = str(messages)
messages = messages[:100]
except:
messages = ""
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, request_data=request_data
)
else:
request_info = ""
if type == "hanging_request":
await asyncio.sleep(
self.alerting_threshold
) # sleep for the alerting threshold (default 5 minutes); this may differ for streaming, non-streaming, and non-completion (embedding + image) requests
if (
request_data is not None
and request_data.get("litellm_status", "") != "success"
and request_data.get("litellm_status", "") != "fail"
):
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
_api_base = litellm.get_api_base(
model=model,
optional_params=request_data["deployment"].get(
"litellm_params", {}
),
)
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: {_api_base}"
elif request_data.get("metadata", None) is not None and isinstance(
request_data["metadata"], dict
):
# In hanging requests, the call sometimes has not reached the point where the deployment is passed into `request_data`;
# in that case we fall back to the api base set in the request metadata
_metadata = request_data["metadata"]
_api_base = _metadata.get("api_base", "")
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: `{_api_base}`"
# only alert hanging responses if they have not been marked as success
alerting_message = (
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
)
# add deployment latencies to alert
_deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=request_data.get("metadata", {})
)
if _deployment_latency_map is not None:
request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
await self.send_alert(
message=alerting_message + request_info,
level="Medium",
alert_type="llm_requests_hanging",
)
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
"user_and_proxy_budget",
"failed_budgets",
"failed_tracking",
"projected_limit_exceeded",
],
user_max_budget: float,
user_current_spend: float,
user_info=None,
error_message="",
):
if self.alerting is None or self.alert_types is None:
# do nothing if alerting is not switched on
return
if "budget_alerts" not in self.alert_types:
return
_id: str = "default_id" # used for caching
if type == "user_and_proxy_budget":
user_info = dict(user_info)
user_id = user_info["user_id"]
_id = user_id
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
_id = token
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
elif type == "failed_tracking":
user_id = str(user_info)
_id = user_id
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
message = "Failed Tracking Cost for" + user_info
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
elif type == "projected_limit_exceeded" and user_info is not None:
"""
Input variables:
user_info = {
"key_alias": key_alias,
"projected_spend": projected_spend,
"projected_exceeded_date": projected_exceeded_date,
}
user_max_budget=soft_limit,
user_current_spend=new_spend
"""
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
else:
user_info = str(user_info)
# percent of max_budget left to spend
if user_max_budget > 0:
percent_left = (user_max_budget - user_current_spend) / user_max_budget
else:
percent_left = 0
verbose_proxy_logger.debug(
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
)
## PREVENTIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 28d period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
# check if crossed budget
if user_current_spend >= user_max_budget:
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
message = "Budget Crossed for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
# check if 5% of max budget is left
if percent_left <= 0.05:
message = "5% budget left for" + user_info
cache_key = "alerting:{}".format(_id)
result = await _cache.async_get_cache(key=cache_key)
if result is None:
await self.send_alert(
message=message, level="Medium", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
return
# check if 15% of max budget is left
if percent_left <= 0.15:
message = "15% budget left for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="Low", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
return
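A quick arithmetic check of the thresholds above.

user_max_budget, user_current_spend = 100.0, 96.0
percent_left = (user_max_budget - user_current_spend) / user_max_budget  # 0.04
# 0.04 <= 0.05, so the "5% budget left" alert fires (at most once per 28 days via the ttl=2419200 cache entry)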
async def send_alert(
self,
message: str,
level: Literal["Low", "Medium", "High"],
alert_type: Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
):
"""
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
- Responses taking too long
- Requests are hanging
- Calls are failing
- DB Read/Writes are failing
- Proxy Close to max budget
- Key Close to max budget
Parameters:
level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
message: str - what is the alert about
"""
if self.alerting is None:
return
from datetime import datetime
import json
# Get the current timestamp
current_time = datetime.now().strftime("%H:%M:%S")
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
formatted_message = (
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
)
if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
# check if we find the slack webhook url in self.alert_to_webhook_url
if (
self.alert_to_webhook_url is not None
and alert_type in self.alert_to_webhook_url
):
slack_webhook_url = self.alert_to_webhook_url[alert_type]
else:
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
if slack_webhook_url is None:
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
payload = {"text": formatted_message}
headers = {"Content-type": "application/json"}
response = await self.async_http_handler.post(
url=slack_webhook_url,
headers=headers,
data=json.dumps(payload),
)
if response.status_code != 200:
print("Error sending slack alert. Error=", response.text)  # noqa

View file

@ -298,7 +298,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -2,18 +2,13 @@ import os, types
import json import json
from enum import Enum from enum import Enum
import requests, copy import requests, copy
import time, uuid import time
from typing import Callable, Optional, List from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import ( from .prompt_templates.factory import prompt_factory, custom_prompt
contains_tag, from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
prompt_factory, from .base import BaseLLM
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx import httpx
@ -21,6 +16,8 @@ class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
class AnthropicError(Exception): class AnthropicError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
@ -37,12 +34,14 @@ class AnthropicError(Exception):
class AnthropicConfig: class AnthropicConfig:
""" """
Reference: https://docs.anthropic.com/claude/reference/complete_post Reference: https://docs.anthropic.com/claude/reference/messages_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
""" """
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
stop_sequences: Optional[list] = None stop_sequences: Optional[list] = None
temperature: Optional[int] = None temperature: Optional[int] = None
top_p: Optional[int] = None top_p: Optional[int] = None
@ -52,7 +51,9 @@ class AnthropicConfig:
def __init__( def __init__(
self, self,
max_tokens: Optional[int] = 256, # anthropic requires a default max_tokens: Optional[
int
] = 4096, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None, stop_sequences: Optional[list] = None,
temperature: Optional[int] = None, temperature: Optional[int] = None,
top_p: Optional[int] = None, top_p: Optional[int] = None,
@ -101,124 +102,23 @@ def validate_environment(api_key, user_headers):
return headers return headers
def completion( class AnthropicChatCompletion(BaseLLM):
model: str, def __init__(self) -> None:
messages: list, super().__init__()
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
json_schemas: dict = {}
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if (
stream is not None and stream == True and _is_function_call == False
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
print_verbose,
):
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=messages, input=messages,
@ -245,46 +145,40 @@ def completion(
status_code=response.status_code, status_code=response.status_code,
) )
else: else:
text_content = completion_response["content"][0].get("text", None) text_content = ""
## TOOL CALLING - OUTPUT PARSE tool_calls = []
if text_content is not None and contains_tag("invoke", text_content): for content in completion_response["content"]:
function_name = extract_between_tags("tool_name", text_content)[0] if content["type"] == "text":
function_arguments_str = extract_between_tags("invoke", text_content)[ text_content += content["text"]
0 ## TOOL CALLING
].strip() elif content["type"] == "tool_use":
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>" tool_calls.append(
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name
)
_message = litellm.Message(
tool_calls=[
{ {
"id": f"call_{uuid.uuid4()}", "id": content["id"],
"type": "function", "type": "function",
"function": { "function": {
"name": function_name, "name": content["name"],
"arguments": json.dumps(function_arguments), "arguments": json.dumps(content["input"]),
}, },
} }
], )
content=None,
) _message = litellm.Message(
model_response.choices[0].message = _message # type: ignore tool_calls=tool_calls,
model_response._hidden_params["original_response"] = ( content=text_content or None,
text_content # allow user to access raw anthropic tool calling response )
) model_response.choices[0].message = _message # type: ignore
else: model_response._hidden_params["original_response"] = completion_response[
model_response.choices[0].message.content = text_content # type: ignore "content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"] completion_response["stop_reason"]
) )
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}") print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True: if _is_function_call and stream:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator # return an iterator
streaming_model_response = ModelResponse(stream=True) streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ streaming_model_response.choices[0].finish_reason = model_response.choices[
@ -318,7 +212,7 @@ def completion(
model_response=streaming_model_response model_response=streaming_model_response
) )
print_verbose( print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
) )
return CustomStreamWrapper( return CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -337,11 +231,278 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=total_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
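The content-block handling in process_response above can be summarized as a small helper. This is a sketch only, assuming the Anthropic messages-API response shape where "content" is a list of {"type": "text" | "tool_use", ...} blocks:

import json

def parse_anthropic_content(content_blocks: list):
    # Split messages-API content blocks into plain text and OpenAI-style tool calls,
    # mirroring the loop in process_response above.
    text_content = ""
    tool_calls = []
    for block in content_blocks:
        if block["type"] == "text":
            text_content += block["text"]
        elif block["type"] == "tool_use":
            tool_calls.append(
                {
                    "id": block["id"],
                    "type": "function",
                    "function": {
                        "name": block["name"],
                        "arguments": json.dumps(block["input"]),
                    },
                }
            )
    return text_content or None, tool_calls

# parse_anthropic_content([
#     {"type": "text", "text": "Let me check the weather."},
#     {"type": "tool_use", "id": "toolu_01", "name": "get_weather", "input": {"city": "SF"}},
# ])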
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data), stream=True
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes async anthropic streaming POST request")
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streaming_response
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass
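The tool handling in completion() above renames the OpenAI "parameters" field to Anthropic's "input_schema" and sets the anthropic-beta: tools-2024-04-04 header. A sketch of just the conversion step, with an illustrative tool spec in the usage comment:

def to_anthropic_tools(openai_tools: list) -> list:
    # OpenAI tool spec:    {"type": "function", "function": {"name", "description", "parameters"}}
    # Anthropic tool spec: {"name", "description", "input_schema"}
    anthropic_tools = []
    for tool in openai_tools:
        new_tool = dict(tool["function"])
        new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
        anthropic_tools.append(new_tool)
    return anthropic_tools

# to_anthropic_tools([{
#     "type": "function",
#     "function": {
#         "name": "get_weather",
#         "description": "Get current weather",
#         "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
#     },
# }])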
class ModelResponseIterator: class ModelResponseIterator:
def __init__(self, model_response): def __init__(self, model_response):
@ -367,8 +528,3 @@ class ModelResponseIterator:
raise StopAsyncIteration raise StopAsyncIteration
self.is_done = True self.is_done = True
return self.model_response return self.model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -4,10 +4,12 @@ from enum import Enum
import requests import requests
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
from .base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
@ -94,91 +96,13 @@ def validate_environment(api_key, user_headers):
return headers return headers
def completion( class AnthropicTextCompletion(BaseLLM):
model: str, def __init__(self) -> None:
messages: list, super().__init__()
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config def process_response(
config = litellm.AnthropicTextConfig.get_config() self, model_response: ModelResponse, response, encoding, prompt: str, model: str
for k, v in config.items(): ):
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
@ -213,10 +137,208 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response return model_response
async def async_completion(
self,
model: str,
model_response: ModelResponse,
api_base: str,
logging_obj,
encoding,
headers: dict,
data: dict,
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
def embedding(): response = await client.post(api_base, headers=headers, data=json.dumps(data))
# logic for parsing in - calling - parsing out model embedding calls
pass if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key=headers.get("x-api-key"),
original_response=response.text,
additional_args={"complete_input_dict": data},
)
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
async def async_streaming(
self,
model: str,
api_base: str,
logging_obj,
headers: dict,
data: Optional[dict],
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return streamwrapper
def completion(
self,
model: str,
messages: list,
api_base: str,
acompletion: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if acompletion == True:
return self.async_streaming(
model=model,
api_base=api_base,
logging_obj=logging_obj,
headers=headers,
data=data,
client=None,
)
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(
api_base,
headers=headers,
data=json.dumps(data),
# stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic_text",
logging_obj=logging_obj,
)
return stream_response
elif acompletion == True:
return self.async_completion(
model=model,
model_response=model_response,
api_base=api_base,
logging_obj=logging_obj,
encoding=encoding,
headers=headers,
data=data,
client=client,
)
else:
if client is None:
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = client.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
response = self.process_response(
model_response=model_response,
response=response,
encoding=encoding,
prompt=data["prompt"],
model=model,
)
return response
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -799,6 +799,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict, optional_params: dict,
model_response: TranscriptionResponse, model_response: TranscriptionResponse,
timeout: float, timeout: float,
max_retries: int,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
@ -817,8 +818,6 @@ class AzureChatCompletion(BaseLLM):
"timeout": timeout, "timeout": timeout,
} }
max_retries = optional_params.pop("max_retries", None)
azure_client_params = select_azure_base_url_or_endpoint( azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params azure_client_params=azure_client_params
) )

View file

@ -8,6 +8,7 @@ from litellm.utils import (
CustomStreamWrapper, CustomStreamWrapper,
convert_to_model_response_object, convert_to_model_response_object,
TranscriptionResponse, TranscriptionResponse,
TextCompletionResponse,
) )
from typing import Callable, Optional, BinaryIO from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig from litellm import OpenAIConfig
@ -15,11 +16,11 @@ import litellm, json
import httpx import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
import uuid import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
openai_text_completion = OpenAITextCompletion() openai_text_completion_config = OpenAITextCompletionConfig()
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -300,9 +301,11 @@ class AzureTextCompletion(BaseLLM):
"api_base": api_base, "api_base": api_base,
}, },
) )
return openai_text_completion.convert_to_model_response_object( return (
response_object=stringified_response, openai_text_completion_config.convert_to_chat_model_response_object(
model_response_object=model_response, response_object=TextCompletionResponse(**stringified_response),
model_response_object=model_response,
)
) )
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
@ -373,7 +376,7 @@ class AzureTextCompletion(BaseLLM):
}, },
) )
response = await azure_client.completions.create(**data, timeout=timeout) response = await azure_client.completions.create(**data, timeout=timeout)
return openai_text_completion.convert_to_model_response_object( return openai_text_completion_config.convert_to_chat_model_response_object(
response_object=response.model_dump(), response_object=response.model_dump(),
model_response_object=model_response, model_response_object=model_response,
) )

View file

@ -55,9 +55,11 @@ def completion(
"inputs": prompt, "inputs": prompt,
"prompt": prompt, "prompt": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": True "stream": (
if "stream" in optional_params and optional_params["stream"] == True True
else False, if "stream" in optional_params and optional_params["stream"] == True
else False
),
} }
## LOGGING ## LOGGING
@ -71,9 +73,11 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2, completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=True stream=(
if "stream" in optional_params and optional_params["stream"] == True True
else False, if "stream" in optional_params and optional_params["stream"] == True
else False
),
) )
if "text/event-stream" in response.headers["Content-Type"] or ( if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] == True
@ -102,28 +106,28 @@ def completion(
and "data" in completion_response["model_output"] and "data" in completion_response["model_output"]
and isinstance(completion_response["model_output"]["data"], list) and isinstance(completion_response["model_output"]["data"], list)
): ):
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["model_output"]["data"][0]
] = completion_response["model_output"]["data"][0] )
elif isinstance(completion_response["model_output"], str): elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["model_output"]
] = completion_response["model_output"] )
elif "completion" in completion_response and isinstance( elif "completion" in completion_response and isinstance(
completion_response["completion"], str completion_response["completion"], str
): ):
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["completion"]
] = completion_response["completion"] )
elif isinstance(completion_response, list) and len(completion_response) > 0: elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response: if "generated_text" not in completion_response:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code, status_code=response.status_code,
) )
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response[0]["generated_text"]
] = completion_response[0]["generated_text"] )
## GETTING LOGPROBS ## GETTING LOGPROBS
if ( if (
"details" in completion_response[0] "details" in completion_response[0]
@ -155,7 +159,8 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage
setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -653,6 +653,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory( prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock" model=model, messages=messages, custom_llm_provider="bedrock"
) )
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else: else:
prompt = "" prompt = ""
for message in messages: for message in messages:
@ -746,7 +750,7 @@ def completion(
] ]
# Format rest of message according to anthropic guidelines # Format rest of message according to anthropic guidelines
messages = prompt_factory( messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic" model=model, messages=messages, custom_llm_provider="anthropic_xml"
) )
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config() config = litellm.AmazonAnthropicClaude3Config.get_config()
@ -1008,7 +1012,7 @@ def completion(
) )
streaming_choice.delta = delta_obj streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice] streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator( completion_stream = ModelResponseIterator(
model_response=streaming_model_response model_response=streaming_model_response
) )
print_verbose( print_verbose(
@ -1028,7 +1032,7 @@ def completion(
total_tokens=response_body["usage"]["input_tokens"] total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"], + response_body["usage"]["output_tokens"],
) )
model_response.usage = _usage setattr(model_response, "usage", _usage)
else: else:
outputText = response_body["completion"] outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"] model_response["finish_reason"] = response_body["stop_reason"]
@ -1071,8 +1075,10 @@ def completion(
status_code=response_metadata.get("HTTPStatusCode", 500), status_code=response_metadata.get("HTTPStatusCode", 500),
) )
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
if getattr(model_response.usage, "total_tokens", None) is None: if not hasattr(model_response, "usage"):
setattr(model_response, "usage", Usage())
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
prompt_tokens = response_metadata.get( prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt)) "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
) )
@ -1089,7 +1095,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
@ -1109,8 +1115,30 @@ def completion(
raise BedrockError(status_code=500, message=traceback.format_exc()) raise BedrockError(status_code=500, message=traceback.format_exc())
async def model_response_iterator(model_response): class ModelResponseIterator:
yield model_response def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response
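A quick demo of the one-shot iterator above: each instance yields the wrapped response exactly once, whether iterated synchronously or asynchronously (the dict payload here is a stand-in for a real ModelResponse):

import asyncio

print(list(ModelResponseIterator(model_response={"id": "demo"})))  # one item

async def _demo():
    async for chunk in ModelResponseIterator(model_response={"id": "demo"}):
        print(chunk)  # also prints exactly once

asyncio.run(_demo())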
def _embedding_func_single( def _embedding_func_single(

View file

@ -167,7 +167,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -237,7 +237,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens. presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking. tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools. tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
""" """
preamble: Optional[str] = None preamble: Optional[str] = None
@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None presence_penalty: Optional[int] = None
tools: Optional[list] = None tools: Optional[list] = None
tool_results: Optional[list] = None tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__( def __init__(
self, self,
@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None, presence_penalty: Optional[int] = None,
tools: Optional[list] = None, tools: Optional[list] = None,
tool_results: Optional[list] = None, tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
@ -302,5 +305,5 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response
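A short sketch of the new seed option flowing through the config above. It assumes CohereChatConfig is importable from this cohere_chat module and follows the same get_config() classmethod pattern as the other config classes in this diff; values are illustrative:

# Setting class-level defaults via the config, then reading them back.
CohereChatConfig(seed=42)
config = CohereChatConfig.get_config()  # now includes {"seed": 42}; per the docstring,
# seed assists reproducibility of the model's response across calls.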

View file

@ -0,0 +1,96 @@
import httpx, asyncio
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
async def close(self):
# Close the client when you're done with it
await self.client.aclose()
async def __aenter__(self):
return self.client
async def __aexit__(self):
# close the client when exiting
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response
async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = await self.client.send(req, stream=stream)
return response
def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass
class HTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
def close(self):
# Close the client when you're done with it
self.client.close()
def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = self.client.get(url, params=params, headers=headers)
return response
def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = self.client.post(url, data=data, params=params, headers=headers)
return response
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
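A usage sketch for the handlers above; the URL and payload are placeholders:

import asyncio
import json

async def _demo():
    handler = AsyncHTTPHandler()
    response = await handler.post(
        "https://api.example.com/v1/messages",
        headers={"content-type": "application/json"},
        data=json.dumps({"ping": "pong"}),
    )
    print(response.status_code)
    await handler.close()

asyncio.run(_demo())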

View file

@ -6,7 +6,8 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm import litellm
import sys, httpx import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
class GeminiError(Exception): class GeminiError(Exception):
@ -103,6 +104,13 @@ class TextStreamer:
break break
def supports_system_instruction():
import google.generativeai as genai
gemini_pkg_version = Version(genai.__version__)
return gemini_pkg_version >= Version("0.5.0")
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -124,7 +132,7 @@ def completion(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai" "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
) )
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
system_prompt = ""
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
@ -135,6 +143,7 @@ def completion(
messages=messages, messages=messages,
) )
else: else:
system_prompt, messages = get_system_prompt(messages=messages)
prompt = prompt_factory( prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini" model=model, messages=messages, custom_llm_provider="gemini"
) )
@ -162,11 +171,20 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}}, additional_args={
"complete_input_dict": {
"inference_params": inference_params,
"system_prompt": system_prompt,
}
},
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
_model = genai.GenerativeModel(f"models/{model}") _params = {"model_name": "models/{}".format(model)}
_system_instruction = supports_system_instruction()
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True: if stream == True:
if acompletion == True: if acompletion == True:
@ -213,11 +231,12 @@ def completion(
encoding=encoding, encoding=encoding,
) )
else: else:
response = _model.generate_content( params = {
contents=prompt, "contents": prompt,
generation_config=genai.types.GenerationConfig(**inference_params), "generation_config": genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings, "safety_settings": safety_settings,
) }
response = _model.generate_content(**params)
except Exception as e: except Exception as e:
raise GeminiError( raise GeminiError(
message=str(e), message=str(e),
@ -292,7 +311,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response
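A sketch of how the extracted system prompt is threaded into the SDK above: the kwarg is only attached when google-generativeai is at least 0.5.0, the gate checked by supports_system_instruction(). The google-generativeai package is assumed installed:

import google.generativeai as genai

def build_gemini_model(model: str, system_prompt: str):
    # Only pass system_instruction when the installed SDK supports it (>= 0.5.0)
    _params = {"model_name": "models/{}".format(model)}
    if supports_system_instruction() and len(system_prompt) > 0:
        _params["system_instruction"] = system_prompt
    return genai.GenerativeModel(**_params)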

View file

@ -152,9 +152,9 @@ def completion(
else: else:
try: try:
if len(completion_response["answer"]) > 0: if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["answer"]
] = completion_response["answer"] )
except Exception as e: except Exception as e:
raise MaritalkError( raise MaritalkError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
@ -174,7 +174,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -185,9 +185,9 @@ def completion(
else: else:
try: try:
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["generated_text"]
] = completion_response["generated_text"] )
except: except:
raise NLPCloudError( raise NLPCloudError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
@ -205,7 +205,7 @@ def completion(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaConfig: class OllamaConfig:
""" """
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters: The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaConfig:
repeat_penalty: Optional[float] = None repeat_penalty: Optional[float] = None
temperature: Optional[float] = None temperature: Optional[float] = None
stop: Optional[list] = ( stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
) )
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
num_predict: Optional[int] = None num_predict: Optional[int] = None
@ -228,8 +228,8 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -330,8 +330,8 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
] ]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"] model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -20,7 +20,7 @@ class OllamaError(Exception):
class OllamaChatConfig: class OllamaChatConfig:
""" """
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters: The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
@ -69,7 +69,7 @@ class OllamaChatConfig:
repeat_penalty: Optional[float] = None repeat_penalty: Optional[float] = None
temperature: Optional[float] = None temperature: Optional[float] = None
stop: Optional[list] = ( stop: Optional[list] = (
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
) )
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
num_predict: Optional[int] = None num_predict: Optional[int] = None
@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "frequency_penalty": if param == "frequency_penalty":
optional_params["repeat_penalty"] = param optional_params["repeat_penalty"] = value
if param == "stop": if param == "stop":
optional_params["stop"] = value optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object": if param == "response_format" and value["type"] == "json_object":
@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
api_base="http://localhost:11434", api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2", model="llama2",
messages=None, messages=None,
optional_params=None, optional_params=None,
@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True: if stream == True:
response = ollama_async_streaming( response = ollama_async_streaming(
url=url, url=url,
api_key=api_key,
data=data, data=data,
model_response=model_response, model_response=model_response,
encoding=encoding, encoding=encoding,
@ -244,6 +246,7 @@ def get_ollama_response(
else: else:
response = ollama_acompletion( response = ollama_acompletion(
url=url, url=url,
api_key=api_key,
data=data, data=data,
model_response=model_response, model_response=model_response,
encoding=encoding, encoding=encoding,
@ -252,12 +255,17 @@ def get_ollama_response(
) )
return response return response
elif stream == True: elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
response = requests.post( _request = {
url=f"{url}", "url": f"{url}",
json=data, "json": data,
) }
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
response = requests.post(**_request) # type: ignore
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(status_code=response.status_code, message=response.text)
@ -307,10 +315,16 @@ def get_ollama_response(
return model_response return model_response
def ollama_completion_stream(url, data, logging_obj): def ollama_completion_stream(url, api_key, data, logging_obj):
with httpx.stream( _request = {
url=url, json=data, method="POST", timeout=litellm.request_timeout "url": f"{url}",
) as response: "json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
with httpx.stream(**_request) as response:
try: try:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError( raise OllamaError(
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try: try:
client = httpx.AsyncClient() client = httpx.AsyncClient()
async with client.stream( _request = {
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout "url": f"{url}",
) as response: "json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
async with client.stream(**_request) as response:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError( raise OllamaError(
status_code=response.status_code, message=response.text status_code=response.status_code, message=response.text
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async def ollama_acompletion( async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
): ):
data["stream"] = False data["stream"] = False
try: try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data) _request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
resp = await session.post(**_request)
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()

View file

@ -99,9 +99,9 @@ def completion(
) )
else: else:
try: try:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = (
"content" completion_response["choices"][0]["message"]["content"]
] = completion_response["choices"][0]["message"]["content"] )
except: except:
raise OobaboogaError( raise OobaboogaError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
@ -115,7 +115,7 @@ def completion(
completion_tokens=completion_response["usage"]["completion_tokens"], completion_tokens=completion_response["usage"]["completion_tokens"],
total_tokens=completion_response["usage"]["total_tokens"], total_tokens=completion_response["usage"]["total_tokens"],
) )
model_response.usage = usage setattr(model_response, "usage", usage)
return model_response return model_response

View file

@ -10,6 +10,7 @@ from litellm.utils import (
convert_to_model_response_object, convert_to_model_response_object,
Usage, Usage,
TranscriptionResponse, TranscriptionResponse,
TextCompletionResponse,
) )
from typing import Callable, Optional from typing import Callable, Optional
import aiohttp, requests import aiohttp, requests
@ -200,6 +201,43 @@ class OpenAITextCompletionConfig:
and v is not None and v is not None
} }
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = (
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
)
return model_response_object
except Exception as e:
raise e
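A usage sketch for the converter above; the field values are illustrative and mirror the OpenAI text-completions response shape that TextCompletionResponse wraps elsewhere in this diff:

from litellm.utils import ModelResponse, TextCompletionResponse

config = OpenAITextCompletionConfig()
chat_style = config.convert_to_chat_model_response_object(
    response_object=TextCompletionResponse(
        **{
            "id": "cmpl-123",
            "model": "gpt-3.5-turbo-instruct",
            "choices": [{"text": "Hello!", "index": 0, "finish_reason": "stop", "logprobs": None}],
            "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
        }
    ),
    model_response_object=ModelResponse(),
)
print(chat_style.choices[0].message.content)  # "Hello!"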
class OpenAIChatCompletion(BaseLLM): class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
@ -785,10 +823,10 @@ class OpenAIChatCompletion(BaseLLM):
optional_params: dict, optional_params: dict,
model_response: TranscriptionResponse, model_response: TranscriptionResponse,
timeout: float, timeout: float,
max_retries: int,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None,
logging_obj=None, logging_obj=None,
atranscription: bool = False, atranscription: bool = False,
): ):
@@ -962,40 +1000,6 @@ class OpenAITextCompletion(BaseLLM):
             headers["Authorization"] = f"Bearer {api_key}"
         return headers
 
-    def convert_to_model_response_object(
-        self,
-        response_object: Optional[dict] = None,
-        model_response_object: Optional[ModelResponse] = None,
-    ):
-        try:
-            ## RESPONSE OBJECT
-            if response_object is None or model_response_object is None:
-                raise ValueError("Error in response object format")
-            choice_list = []
-            for idx, choice in enumerate(response_object["choices"]):
-                message = Message(content=choice["text"], role="assistant")
-                choice = Choices(
-                    finish_reason=choice["finish_reason"], index=idx, message=message
-                )
-                choice_list.append(choice)
-            model_response_object.choices = choice_list
-
-            if "usage" in response_object:
-                model_response_object.usage = response_object["usage"]
-
-            if "id" in response_object:
-                model_response_object.id = response_object["id"]
-
-            if "model" in response_object:
-                model_response_object.model = response_object["model"]
-
-            model_response_object._hidden_params["original_response"] = (
-                response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
-            )
-            return model_response_object
-        except Exception as e:
-            raise e
-
     def completion(
         self,
         model_response: ModelResponse,
@@ -1010,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
         headers: Optional[dict] = None,
     ):
         super().completion()
@@ -1020,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")
 
-            api_base = f"{api_base}/completions"
-
             if (
                 len(messages) > 0
                 and "content" in messages[0]
@@ -1029,12 +1033,12 @@ class OpenAITextCompletion(BaseLLM):
             ):
                 prompt = messages[0]["content"]
             else:
-                prompt = " ".join([message["content"] for message in messages])  # type: ignore
+                prompt = [message["content"] for message in messages]  # type: ignore
 
             # don't send max retries to the api, if set
-            optional_params.pop("max_retries", None)
             data = {"model": model, "prompt": prompt, **optional_params}
+            max_retries = data.pop("max_retries", 2)
 
             ## LOGGING
             logging_obj.pre_call(
                 input=messages,
@@ -1050,38 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
                     return self.async_streaming(
                         logging_obj=logging_obj,
                         api_base=api_base,
+                        api_key=api_key,
                         data=data,
                         headers=headers,
                         model_response=model_response,
                         model=model,
                         timeout=timeout,
+                        max_retries=max_retries,
+                        client=client,
+                        organization=organization,
                     )
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
             elif optional_params.get("stream", False):
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    api_key=api_key,
                     data=data,
                     headers=headers,
                     model_response=model_response,
                     model=model,
                     timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    client=client,
+                    organization=organization,
                 )
             else:
-                response = httpx.post(
-                    url=f"{api_base}", json=data, headers=headers, timeout=timeout
-                )
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
+                if client is None:
+                    openai_client = OpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        http_client=litellm.client_session,
+                        timeout=timeout,
+                        max_retries=max_retries,  # type: ignore
+                        organization=organization,
+                    )
+                else:
+                    openai_client = client
+
+                response = openai_client.completions.create(**data)  # type: ignore
+
+                response_json = response.model_dump()
 
                 ## LOGGING
                 logging_obj.post_call(
                     input=prompt,
                     api_key=api_key,
-                    original_response=response,
+                    original_response=response_json,
                     additional_args={
                         "headers": headers,
                         "api_base": api_base,
@@ -1089,10 +1108,7 @@ class OpenAITextCompletion(BaseLLM):
                 )
 
                 ## RESPONSE OBJECT
-                return self.convert_to_model_response_object(
-                    response_object=response.json(),
-                    model_response_object=model_response,
-                )
+                return TextCompletionResponse(**response_json)
         except Exception as e:
             raise e
@@ -1107,101 +1123,112 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         timeout: float,
+        max_retries=None,
+        organization: Optional[str] = None,
+        client=None,
     ):
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            try:
-                response = await client.post(
-                    api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout,
-                )
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                ## LOGGING
-                logging_obj.post_call(
-                    input=prompt,
-                    api_key=api_key,
-                    original_response=response,
-                    additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
-                    },
-                )
-
-                ## RESPONSE OBJECT
-                return self.convert_to_model_response_object(
-                    response_object=response_json, model_response_object=model_response
-                )
-            except Exception as e:
-                raise e
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                openai_aclient = client
+
+            response = await openai_aclient.completions.create(**data)
+            response_json = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+            ## RESPONSE OBJECT
+            response_obj = TextCompletionResponse(**response_json)
+            response_obj._hidden_params.original_response = json.dumps(response_json)
+            return response_obj
+        except Exception as e:
+            raise e
 
     def streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
     ):
-        with httpx.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
-
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response.iter_lines(),
-                model=model,
-                custom_llm_provider="text-completion-openai",
-                logging_obj=logging_obj,
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
             )
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        else:
+            openai_client = client
+        response = openai_client.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield chunk
 
     async def async_streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        organization=None,
     ):
-        client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            try:
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=response.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="text-completion-openai",
-                    logging_obj=logging_obj,
-                )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
-            except Exception as e:
-                raise e
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
+        response = await openai_client.completions.create(**data)
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
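
Note: the rework above replaces raw `httpx` calls in `OpenAITextCompletion` with the OpenAI SDK client, constructing a client only when the caller does not pass one in. A hedged sketch of that pattern in isolation; the function name, defaults, and example values below are illustrative, not part of this commit:

# Sketch of the client-reuse pattern used by the reworked OpenAITextCompletion:
# build an OpenAI SDK client only if none was supplied, then call the SDK
# instead of posting to /completions with httpx.
from typing import Optional

from openai import OpenAI


def text_completion_sketch(
    data: dict,
    api_key: str,
    api_base: Optional[str] = None,
    timeout: float = 600.0,
    max_retries: int = 2,
    client: Optional[OpenAI] = None,
) -> dict:
    if client is None:
        # No pre-built client passed in: create one with the routing parameters.
        openai_client = OpenAI(
            api_key=api_key,
            base_url=api_base,
            timeout=timeout,
            max_retries=max_retries,
        )
    else:
        # Reuse the caller's client (e.g. one shared across proxy requests).
        openai_client = client

    response = openai_client.completions.create(**data)
    return response.model_dump()  # plain dict, as used above to build TextCompletionResponse


# Hypothetical usage:
# result = text_completion_sketch(
#     data={"model": "gpt-3.5-turbo-instruct", "prompt": "Say hi", "max_tokens": 16},
#     api_key="sk-...",
# )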

View file

@@ -191,7 +191,7 @@ def completion(
         completion_tokens=completion_tokens,
         total_tokens=prompt_tokens + completion_tokens,
     )
-    model_response.usage = usage
+    setattr(model_response, "usage", usage)
     return model_response
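
Note: the new `OpenAITextCompletionConfig.convert_to_chat_model_response_object` helper introduced in the `openai.py` hunks above could be exercised roughly as below. The import path and the dict-shaped payload are assumptions for illustration; a plain dict is passed even though the parameter is annotated as `TextCompletionResponse`, since the method only indexes into it:

# Hedged sketch (not from this commit): mapping a text-completion payload into a
# chat-style ModelResponse via the helper added above.
from litellm import ModelResponse
from litellm.llms.openai import OpenAITextCompletionConfig  # assumed module path

text_completion_payload = {
    "id": "cmpl-123",
    "model": "gpt-3.5-turbo-instruct",
    "choices": [{"text": "Hello!", "index": 0, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 2, "completion_tokens": 2, "total_tokens": 4},
}

chat_style = OpenAITextCompletionConfig().convert_to_chat_model_response_object(
    response_object=text_completion_payload,  # type: ignore[arg-type]
    model_response_object=ModelResponse(),
)
print(chat_style.choices[0].message.content)  # -> "Hello!"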

Some files were not shown because too many files have changed in this diff.