Merge branch 'BerriAI:main' into fix-anthropic-messages-api
|
@ -8,6 +8,11 @@ jobs:
|
|||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Show git commit hash
|
||||
command: |
|
||||
echo "Git commit hash: $CIRCLE_SHA1"
|
||||
|
||||
- run:
|
||||
name: Check if litellm dir was updated or if pyproject.toml was modified
|
||||
command: |
|
||||
|
@ -31,16 +36,17 @@ jobs:
|
|||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install pyarrow
|
||||
pip install "boto3>=1.28.57"
|
||||
pip install "aioboto3>=12.3.0"
|
||||
pip install "boto3==1.34.34"
|
||||
pip install "aioboto3==12.3.0"
|
||||
pip install langchain
|
||||
pip install lunary==0.2.5
|
||||
pip install "langfuse>=2.0.0"
|
||||
pip install "langfuse==2.7.3"
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.0.69
|
||||
pip install openai
|
||||
pip install prisma
|
||||
pip install "httpx==0.24.1"
|
||||
pip install fastapi
|
||||
pip install "gunicorn==21.2.0"
|
||||
pip install "anyio==3.7.1"
|
||||
pip install "aiodynamo==23.10.1"
|
||||
|
@ -51,6 +57,7 @@ jobs:
|
|||
pip install "pytest-mock==3.12.0"
|
||||
pip install python-multipart
|
||||
pip install google-cloud-aiplatform
|
||||
pip install prometheus-client==0.20.0
|
||||
- save_cache:
|
||||
paths:
|
||||
- ./venv
|
||||
|
@ -73,7 +80,7 @@ jobs:
|
|||
name: Linting Testing
|
||||
command: |
|
||||
cd litellm
|
||||
python -m pip install types-requests types-setuptools types-redis
|
||||
python -m pip install types-requests types-setuptools types-redis types-PyYAML
|
||||
if ! python -m mypy . --ignore-missing-imports; then
|
||||
echo "mypy detected errors"
|
||||
exit 1
|
||||
|
@ -123,6 +130,7 @@ jobs:
|
|||
build_and_test:
|
||||
machine:
|
||||
image: ubuntu-2204:2023.10.1
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project
|
||||
steps:
|
||||
- checkout
|
||||
|
@ -182,12 +190,19 @@ jobs:
|
|||
-p 4000:4000 \
|
||||
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
|
||||
-e AZURE_API_KEY=$AZURE_API_KEY \
|
||||
-e REDIS_HOST=$REDIS_HOST \
|
||||
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||
-e REDIS_PORT=$REDIS_PORT \
|
||||
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
|
||||
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
|
||||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
|
||||
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
|
||||
-e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \
|
||||
-e LANGFUSE_PROJECT2_SECRET=$LANGFUSE_PROJECT2_SECRET \
|
||||
--name my-app \
|
||||
-v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
|
||||
my-app:latest \
|
||||
|
@ -292,7 +307,7 @@ jobs:
|
|||
-H "Accept: application/vnd.github.v3+json" \
|
||||
-H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
|
||||
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\"}}"
|
||||
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"v${VERSION}\", \"commit_hash\":\"$CIRCLE_SHA1\"}}"
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
|
|
|
@ -3,11 +3,10 @@ openai
|
|||
python-dotenv
|
||||
tiktoken
|
||||
importlib_metadata
|
||||
baseten
|
||||
cohere
|
||||
redis
|
||||
anthropic
|
||||
orjson
|
||||
pydantic
|
||||
pydantic==1.10.14
|
||||
google-cloud-aiplatform==1.43.0
|
||||
redisvl==0.0.7 # semantic caching
|
|
@ -1,5 +1,5 @@
|
|||
/docs
|
||||
/cookbook
|
||||
/.circleci
|
||||
/.github
|
||||
/tests
|
||||
docs
|
||||
cookbook
|
||||
.circleci
|
||||
.github
|
||||
tests
|
||||
|
|
24
.github/workflows/ghcr_deploy.yml
vendored
|
@ -5,6 +5,13 @@ on:
|
|||
inputs:
|
||||
tag:
|
||||
description: "The tag version you want to build"
|
||||
release_type:
|
||||
description: "The release type you want to build. Can be 'latest', 'stable', 'dev'"
|
||||
type: string
|
||||
default: "latest"
|
||||
commit_hash:
|
||||
description: "Commit hash"
|
||||
required: true
|
||||
|
||||
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
|
||||
env:
|
||||
|
@ -85,9 +92,9 @@ jobs:
|
|||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@4976231911ebf5f32aad765192d35f942aa48cb8
|
||||
with:
|
||||
context: .
|
||||
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-latest # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.release_type }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
|
@ -121,10 +128,10 @@ jobs:
|
|||
- name: Build and push Database Docker image
|
||||
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
|
||||
with:
|
||||
context: .
|
||||
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
|
||||
file: Dockerfile.database
|
||||
push: true
|
||||
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-latest
|
||||
tags: ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-database.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
labels: ${{ steps.meta-database.outputs.labels }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
|
@ -158,11 +165,10 @@ jobs:
|
|||
- name: Build and push Database Docker image
|
||||
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
|
||||
with:
|
||||
context: .
|
||||
context: https://github.com/BerriAI/litellm.git#${{ github.event.inputs.commit_hash}}
|
||||
file: ./litellm-js/spend-logs/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-latest
|
||||
labels: ${{ steps.meta-spend-logs.outputs.labels }}
|
||||
tags: ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}, ${{ steps.meta-spend-logs.outputs.tags }}-${{ github.event.inputs.release_type }}
|
||||
platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
|
||||
|
||||
build-and-push-helm-chart:
|
||||
|
@ -236,10 +242,13 @@ jobs:
|
|||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
script: |
|
||||
const commitHash = "${{ github.event.inputs.commit_hash}}";
|
||||
console.log("Commit Hash:", commitHash); // Add this line for debugging
|
||||
try {
|
||||
const response = await github.rest.repos.createRelease({
|
||||
draft: false,
|
||||
generate_release_notes: true,
|
||||
target_commitish: commitHash,
|
||||
name: process.env.RELEASE_TAG,
|
||||
owner: context.repo.owner,
|
||||
prerelease: false,
|
||||
|
@ -288,4 +297,3 @@ jobs:
|
|||
}
|
||||
]
|
||||
}' $WEBHOOK_URL
|
||||
|
||||
|
|
3
.github/workflows/interpret_load_test.py
vendored
|
@ -77,6 +77,9 @@ if __name__ == "__main__":
|
|||
new_release_body = (
|
||||
existing_release_body
|
||||
+ "\n\n"
|
||||
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
|
||||
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
||||
+ "\n\n"
|
||||
+ "## Load Test LiteLLM Proxy Results"
|
||||
+ "\n\n"
|
||||
+ markdown_table
|
||||
|
|
2
.github/workflows/locustfile.py
vendored
|
@ -10,7 +10,7 @@ class MyUser(HttpUser):
|
|||
def chat_completion(self):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer sk-gUvTeN9g0sgHBMf9HeCaqA",
|
||||
"Authorization": f"Bearer sk-S2-EZTUUDY0EmM6-Fy0Fyw",
|
||||
# Include any additional headers you may need for authentication, etc.
|
||||
}
|
||||
|
||||
|
|
6
.gitignore
vendored
|
@ -45,3 +45,9 @@ deploy/charts/litellm/charts/*
|
|||
deploy/charts/*.tgz
|
||||
litellm/proxy/vertex_key.json
|
||||
**/.vim/
|
||||
/node_modules
|
||||
kub.yaml
|
||||
loadtest_kub.yaml
|
||||
litellm/proxy/_new_secret_config.yaml
|
||||
litellm/proxy/_new_secret_config.yaml
|
||||
litellm/proxy/_super_secret_config.yaml
|
||||
|
|
|
@ -70,5 +70,4 @@ EXPOSE 4000/tcp
|
|||
ENTRYPOINT ["litellm"]
|
||||
|
||||
# Append "--detailed_debug" to the end of CMD to view detailed debug logs
|
||||
# CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||
CMD ["--port", "4000"]
|
||||
|
|
14
README.md
|
@ -5,7 +5,7 @@
|
|||
<p align="center">Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, etc.]
|
||||
<br>
|
||||
</p>
|
||||
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
|
||||
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">OpenAI Proxy Server</a> | <a href="https://docs.litellm.ai/docs/hosted" target="_blank"> Hosted Proxy (Preview)</a> | <a href="https://docs.litellm.ai/docs/enterprise"target="_blank">Enterprise Tier</a></h4>
|
||||
<h4 align="center">
|
||||
<a href="https://pypi.org/project/litellm/" target="_blank">
|
||||
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
|
||||
|
@ -32,9 +32,9 @@ LiteLLM manages:
|
|||
- Set Budgets & Rate limits per project, api key, model [OpenAI Proxy Server](https://docs.litellm.ai/docs/simple_proxy)
|
||||
|
||||
[**Jump to OpenAI Proxy Docs**](https://github.com/BerriAI/litellm?tab=readme-ov-file#openai-proxy---docs) <br>
|
||||
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-provider-docs)
|
||||
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs)
|
||||
|
||||
🚨 **Stable Release:** v1.34.1
|
||||
🚨 **Stable Release:** Use docker images with: `main-stable` tag. These run through 12 hr load tests (1k req./min).
|
||||
|
||||
Support for more providers. Missing a provider or LLM Platform, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
|
||||
|
||||
|
@ -128,7 +128,9 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
|
|||
|
||||
# OpenAI Proxy - ([Docs](https://docs.litellm.ai/docs/simple_proxy))
|
||||
|
||||
Set Budgets & Rate limits across multiple projects
|
||||
Track spend + Load Balance across multiple projects
|
||||
|
||||
[Hosted Proxy (Preview)](https://docs.litellm.ai/docs/hosted)
|
||||
|
||||
The proxy provides:
|
||||
|
||||
|
@ -205,7 +207,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
|||
| [aws - bedrock](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [google - vertex_ai [Gemini]](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | | ✅ | | |
|
||||
| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [cloudflare AI Workers](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [cohere](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
|
@ -220,7 +222,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
|||
| [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [petals](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
|
||||
|
|
204
cookbook/Proxy_Batch_Users.ipynb
vendored
Normal file
|
@ -0,0 +1,204 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "680oRk1af-xJ"
|
||||
},
|
||||
"source": [
|
||||
"# Environment Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "X7TgJFn8f88p"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import csv\n",
|
||||
"from typing import Optional\n",
|
||||
"import httpx, json\n",
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"proxy_base_url = \"http://0.0.0.0:4000\" # 👈 SET TO PROXY URL\n",
|
||||
"master_key = \"sk-1234\" # 👈 SET TO PROXY MASTER KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rauw8EOhgBz5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## GLOBAL HTTP CLIENT ## - faster http calls\n",
|
||||
"class HTTPHandler:\n",
|
||||
" def __init__(self, concurrent_limit=1000):\n",
|
||||
" # Create a client with a connection pool\n",
|
||||
" self.client = httpx.AsyncClient(\n",
|
||||
" limits=httpx.Limits(\n",
|
||||
" max_connections=concurrent_limit,\n",
|
||||
" max_keepalive_connections=concurrent_limit,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" async def close(self):\n",
|
||||
" # Close the client when you're done with it\n",
|
||||
" await self.client.aclose()\n",
|
||||
"\n",
|
||||
" async def get(\n",
|
||||
" self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None\n",
|
||||
" ):\n",
|
||||
" response = await self.client.get(url, params=params, headers=headers)\n",
|
||||
" return response\n",
|
||||
"\n",
|
||||
" async def post(\n",
|
||||
" self,\n",
|
||||
" url: str,\n",
|
||||
" data: Optional[dict] = None,\n",
|
||||
" params: Optional[dict] = None,\n",
|
||||
" headers: Optional[dict] = None,\n",
|
||||
" ):\n",
|
||||
" try:\n",
|
||||
" response = await self.client.post(\n",
|
||||
" url, data=data, params=params, headers=headers\n",
|
||||
" )\n",
|
||||
" return response\n",
|
||||
" except Exception as e:\n",
|
||||
" raise e\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7LXN8zaLgOie"
|
||||
},
|
||||
"source": [
|
||||
"# Import Sheet\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Format: | ID | Name | Max Budget |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "oiED0usegPGf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"async def import_sheet():\n",
|
||||
" tasks = []\n",
|
||||
" http_client = HTTPHandler()\n",
|
||||
" with open('my-batch-sheet.csv', 'r') as file:\n",
|
||||
" csv_reader = csv.DictReader(file)\n",
|
||||
" for row in csv_reader:\n",
|
||||
" task = create_user(client=http_client, user_id=row['ID'], max_budget=row['Max Budget'], user_name=row['Name'])\n",
|
||||
" tasks.append(task)\n",
|
||||
" # print(f\"ID: {row['ID']}, Name: {row['Name']}, Max Budget: {row['Max Budget']}\")\n",
|
||||
"\n",
|
||||
" keys = await asyncio.gather(*tasks)\n",
|
||||
"\n",
|
||||
" with open('my-batch-sheet_new.csv', 'w', newline='') as new_file:\n",
|
||||
" fieldnames = ['ID', 'Name', 'Max Budget', 'keys']\n",
|
||||
" csv_writer = csv.DictWriter(new_file, fieldnames=fieldnames)\n",
|
||||
" csv_writer.writeheader()\n",
|
||||
"\n",
|
||||
" with open('my-batch-sheet.csv', 'r') as file:\n",
|
||||
" csv_reader = csv.DictReader(file)\n",
|
||||
" for i, row in enumerate(csv_reader):\n",
|
||||
" row['keys'] = keys[i] # Add the 'keys' value from the corresponding task result\n",
|
||||
" csv_writer.writerow(row)\n",
|
||||
"\n",
|
||||
" await http_client.close()\n",
|
||||
"\n",
|
||||
"asyncio.run(import_sheet())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E7M0Li_UgJeZ"
|
||||
},
|
||||
"source": [
|
||||
"# Create Users + Keys\n",
|
||||
"\n",
|
||||
"- Creates a user\n",
|
||||
"- Creates a key with max budget"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "NZudRFujf7j-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"async def create_key_with_alias(client: HTTPHandler, user_id: str, max_budget: float):\n",
|
||||
" global proxy_base_url\n",
|
||||
" if not proxy_base_url.endswith(\"/\"):\n",
|
||||
" proxy_base_url += \"/\"\n",
|
||||
" url = proxy_base_url + \"key/generate\"\n",
|
||||
"\n",
|
||||
" # call /key/generate\n",
|
||||
" print(\"CALLING /KEY/GENERATE\")\n",
|
||||
" response = await client.post(\n",
|
||||
" url=url,\n",
|
||||
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
|
||||
" data=json.dumps({\n",
|
||||
" \"user_id\": user_id,\n",
|
||||
" \"key_alias\": f\"{user_id}-key\",\n",
|
||||
" \"max_budget\": max_budget # 👈 KEY CHANGE: SETS MAX BUDGET PER KEY\n",
|
||||
" })\n",
|
||||
" )\n",
|
||||
" print(f\"response: {response.text}\")\n",
|
||||
" return response.json()[\"key\"]\n",
|
||||
"\n",
|
||||
"async def create_user(client: HTTPHandler, user_id: str, max_budget: float, user_name: str):\n",
|
||||
" \"\"\"\n",
|
||||
" - call /user/new\n",
|
||||
" - create key for user\n",
|
||||
" \"\"\"\n",
|
||||
" global proxy_base_url\n",
|
||||
" if not proxy_base_url.endswith(\"/\"):\n",
|
||||
" proxy_base_url += \"/\"\n",
|
||||
" url = proxy_base_url + \"user/new\"\n",
|
||||
"\n",
|
||||
" # call /user/new\n",
|
||||
" await client.post(\n",
|
||||
" url=url,\n",
|
||||
" headers={\"Authorization\": f\"Bearer {master_key}\"},\n",
|
||||
" data=json.dumps({\n",
|
||||
" \"user_id\": user_id,\n",
|
||||
" \"user_alias\": user_name,\n",
|
||||
" \"auto_create_key\": False,\n",
|
||||
" # \"max_budget\": max_budget # 👈 [OPTIONAL] Sets max budget per user (if you want to set a max budget across keys)\n",
|
||||
" })\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # create key for user\n",
|
||||
" return await create_key_with_alias(client=client, user_id=user_id, max_budget=max_budget)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
|
@ -87,6 +87,7 @@
|
|||
| command-light | cohere | 0.00003 |
|
||||
| command-medium-beta | cohere | 0.00003 |
|
||||
| command-xlarge-beta | cohere | 0.00003 |
|
||||
| command-r-plus| cohere | 0.000018 |
|
||||
| j2-ultra | ai21 | 0.00003 |
|
||||
| ai21.j2-ultra-v1 | bedrock | 0.0000376 |
|
||||
| gpt-4-1106-preview | openai | 0.00004 |
|
||||
|
|
73
cookbook/misc/config.yaml
Normal file
|
@ -0,0 +1,73 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
||||
- model_name: gpt-3.5-turbo-large
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
rpm: 480
|
||||
timeout: 300
|
||||
stream_timeout: 60
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
||||
rpm: 480
|
||||
timeout: 300
|
||||
stream_timeout: 60
|
||||
- model_name: sagemaker-completion-model
|
||||
litellm_params:
|
||||
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
|
||||
input_cost_per_second: 0.000420
|
||||
- model_name: text-embedding-ada-002
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
model_info:
|
||||
mode: embedding
|
||||
base_model: text-embedding-ada-002
|
||||
- model_name: dall-e-2
|
||||
litellm_params:
|
||||
model: azure/
|
||||
api_version: 2023-06-01-preview
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
- model_name: openai-dall-e-3
|
||||
litellm_params:
|
||||
model: dall-e-3
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
# max_budget: 100
|
||||
# budget_duration: 30d
|
||||
num_retries: 5
|
||||
request_timeout: 600
|
||||
telemetry: False
|
||||
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
|
||||
store_model_in_db: True
|
||||
proxy_budget_rescheduler_min_time: 60
|
||||
proxy_budget_rescheduler_max_time: 64
|
||||
proxy_batch_write_at: 1
|
||||
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
||||
|
||||
# environment_variables:
|
||||
# settings for using redis caching
|
||||
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
|
||||
# REDIS_PORT: "16337"
|
||||
# REDIS_PASSWORD:
|
92
cookbook/misc/migrate_proxy_config.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
"""
|
||||
LiteLLM Migration Script!
|
||||
|
||||
Takes a config.yaml and calls /model/new
|
||||
|
||||
Inputs:
|
||||
- File path to config.yaml
|
||||
- Proxy base url to your hosted proxy
|
||||
|
||||
Step 1: Reads your config.yaml
|
||||
Step 2: reads `model_list` and loops through all models
|
||||
Step 3: calls `<proxy-base-url>/model/new` for each model
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import requests
|
||||
|
||||
_in_memory_os_variables = {}
|
||||
|
||||
|
||||
def migrate_models(config_file, proxy_base_url):
|
||||
# Step 1: Read the config.yaml file
|
||||
with open(config_file, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# Step 2: Read the model_list and loop through all models
|
||||
model_list = config.get("model_list", [])
|
||||
print("model_list: ", model_list)
|
||||
for model in model_list:
|
||||
|
||||
model_name = model.get("model_name")
|
||||
print("\nAdding model: ", model_name)
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
api_base = litellm_params.get("api_base", "")
|
||||
print("api_base on config.yaml: ", api_base)
|
||||
|
||||
litellm_model_name = litellm_params.get("model", "") or ""
|
||||
if "vertex_ai/" in litellm_model_name:
|
||||
print(f"\033[91m\nSkipping Vertex AI model\033[0m", model)
|
||||
continue
|
||||
|
||||
for param, value in litellm_params.items():
|
||||
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||
# check if value is in _in_memory_os_variables
|
||||
if value in _in_memory_os_variables:
|
||||
new_value = _in_memory_os_variables[value]
|
||||
print(
|
||||
"\033[92mAlready entered value for \033[0m",
|
||||
value,
|
||||
"\033[92musing \033[0m",
|
||||
new_value,
|
||||
)
|
||||
else:
|
||||
new_value = input(f"Enter value for {value}: ")
|
||||
_in_memory_os_variables[value] = new_value
|
||||
litellm_params[param] = new_value
|
||||
|
||||
print("\nlitellm_params: ", litellm_params)
|
||||
# Confirm before sending POST request
|
||||
confirm = input(
|
||||
"\033[92mDo you want to send the POST request with the above parameters? (y/n): \033[0m"
|
||||
)
|
||||
if confirm.lower() != "y":
|
||||
print("Aborting POST request.")
|
||||
exit()
|
||||
|
||||
# Step 3: Call <proxy-base-url>/model/new for each model
|
||||
url = f"{proxy_base_url}/model/new"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {master_key}",
|
||||
}
|
||||
data = {"model_name": model_name, "litellm_params": litellm_params}
|
||||
print("POSTING data to proxy url", url)
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
if response.status_code != 200:
|
||||
print(f"Error: {response.status_code} - {response.text}")
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
|
||||
# Print the response for each model
|
||||
print(
|
||||
f"Response for model '{model_name}': Status Code:{response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Usage
|
||||
config_file = "config.yaml"
|
||||
proxy_base_url = "http://0.0.0.0:4000"
|
||||
master_key = "sk-1234"
|
||||
print(f"config_file: {config_file}")
|
||||
print(f"proxy_base_url: {proxy_base_url}")
|
||||
migrate_models(config_file, proxy_base_url)
|
|
@ -15,15 +15,16 @@ spec:
|
|||
containers:
|
||||
- name: litellm-container
|
||||
image: ghcr.io/berriai/litellm:main-latest
|
||||
imagePullPolicy: Always
|
||||
env:
|
||||
- name: AZURE_API_KEY
|
||||
value: "d6f****"
|
||||
- name: AZURE_API_BASE
|
||||
value: "https://openai
|
||||
value: "https://openai"
|
||||
- name: LITELLM_MASTER_KEY
|
||||
value: "sk-1234"
|
||||
- name: DATABASE_URL
|
||||
value: "postgresql://ishaan:*********""
|
||||
value: "postgresql://ishaan*********"
|
||||
args:
|
||||
- "--config"
|
||||
- "/app/proxy_config.yaml" # Update the path to mount the config file
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
version: "3.9"
|
||||
services:
|
||||
litellm:
|
||||
build:
|
||||
context: .
|
||||
args:
|
||||
target: runtime
|
||||
image: ghcr.io/berriai/litellm:main-latest
|
||||
volumes:
|
||||
- ./proxy_server_config.yaml:/app/proxy_server_config.yaml # mount your litellm config.yaml
|
||||
ports:
|
||||
- "4000:4000"
|
||||
environment:
|
||||
- AZURE_API_KEY=sk-123
|
||||
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
||||
volumes:
|
||||
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
|
||||
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
|
||||
command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
|
||||
|
||||
# ...rest of your docker-compose config if any
|
|
@ -72,7 +72,7 @@ Here's the code for how we format all providers. Let us know how we can improve
|
|||
| Anthropic | `claude-instant-1`, `claude-instant-1.2`, `claude-2` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/anthropic.py#L84)
|
||||
| OpenAI Text Completion | `text-davinci-003`, `text-curie-001`, `text-babbage-001`, `text-ada-001`, `babbage-002`, `davinci-002`, | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L442)
|
||||
| Replicate | all model names starting with `replicate/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/replicate.py#L180)
|
||||
| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
|
||||
| Cohere | `command-nightly`, `command`, `command-light`, `command-medium-beta`, `command-xlarge-beta`, `command-r-plus` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/cohere.py#L115)
|
||||
| Huggingface | all model names starting with `huggingface/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/huggingface_restapi.py#L186)
|
||||
| OpenRouter | all model names starting with `openrouter/` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/main.py#L611)
|
||||
| AI21 | `j2-mid`, `j2-light`, `j2-ultra` | [Code](https://github.com/BerriAI/litellm/blob/721564c63999a43f96ee9167d0530759d51f8d45/litellm/llms/ai21.py#L107)
|
||||
|
|
45
docs/my-website/docs/completion/vision.md
Normal file
|
@ -0,0 +1,45 @@
|
|||
# Using Vision Models
|
||||
|
||||
## Quick Start
|
||||
Example passing images to a model
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||
|
||||
# openai call
|
||||
response = completion(
|
||||
model = "gpt-4-vision-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What’s in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
## Checking if a model supports `vision`
|
||||
|
||||
Use `litellm.supports_vision(model="")` -> returns `True` if model supports `vision` and `False` if not
|
||||
|
||||
```python
|
||||
assert litellm.supports_vision(model="gpt-4-vision-preview") == True
|
||||
assert litellm.supports_vision(model="gemini-1.0-pro-visionn") == True
|
||||
assert litellm.supports_vision(model="gpt-3.5-turbo") == False
|
||||
```
|
||||
|
|
@ -339,6 +339,8 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
|||
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
||||
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
|
||||
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||
|
||||
## Voyage AI Embedding Models
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Enterprise
|
||||
For companies that need better security, user management and professional support
|
||||
For companies that need SSO, user management and professional support for LiteLLM Proxy
|
||||
|
||||
:::info
|
||||
|
||||
|
@ -8,12 +8,13 @@ For companies that need better security, user management and professional suppor
|
|||
:::
|
||||
|
||||
This covers:
|
||||
- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
|
||||
- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
|
||||
- ✅ **Feature Prioritization**
|
||||
- ✅ **Custom Integrations**
|
||||
- ✅ **Professional Support - Dedicated discord + slack**
|
||||
- ✅ **Custom SLAs**
|
||||
- ✅ **Secure access with Single Sign-On**
|
||||
- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
|
||||
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
|
||||
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
|
49
docs/my-website/docs/hosted.md
Normal file
|
@ -0,0 +1,49 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
|
||||
# Hosted LiteLLM Proxy
|
||||
|
||||
LiteLLM maintains the proxy, so you can focus on your core products.
|
||||
|
||||
## [**Get Onboarded**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
This is in alpha. Schedule a call with us, and we'll give you a hosted proxy within 30 minutes.
|
||||
|
||||
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
### **Status**: Alpha
|
||||
|
||||
Our proxy is already used in production by customers.
|
||||
|
||||
See our status page for [**live reliability**](https://status.litellm.ai/)
|
||||
|
||||
### **Benefits**
|
||||
- **No Maintenance, No Infra**: We'll maintain the proxy, and spin up any additional infrastructure (e.g.: separate server for spend logs) to make sure you can load balance + track spend across multiple LLM projects.
|
||||
- **Reliable**: Our hosted proxy is tested on 1k requests per second, making it reliable for high load.
|
||||
- **Secure**: LiteLLM is currently undergoing SOC-2 compliance, to make sure your data is as secure as possible.
|
||||
|
||||
### Pricing
|
||||
|
||||
Pricing is based on usage. We can figure out a price that works for your team, on the call.
|
||||
|
||||
[**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
## **Screenshots**
|
||||
|
||||
### 1. Create keys
|
||||
|
||||
<Image img={require('../img/litellm_hosted_ui_create_key.png')} />
|
||||
|
||||
### 2. Add Models
|
||||
|
||||
<Image img={require('../img/litellm_hosted_ui_add_models.png')}/>
|
||||
|
||||
### 3. Track spend
|
||||
|
||||
<Image img={require('../img/litellm_hosted_usage_dashboard.png')} />
|
||||
|
||||
|
||||
### 4. Configure load balancing
|
||||
|
||||
<Image img={require('../img/litellm_hosted_ui_router.png')} />
|
||||
|
||||
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
|
@ -8,6 +8,7 @@ liteLLM supports:
|
|||
|
||||
- [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
|
||||
- [Lunary](https://lunary.ai/docs)
|
||||
- [Langfuse](https://langfuse.com/docs)
|
||||
- [Helicone](https://docs.helicone.ai/introduction)
|
||||
- [Traceloop](https://traceloop.com/docs)
|
||||
- [Athina](https://docs.athina.ai/)
|
||||
|
@ -22,8 +23,8 @@ from litellm import completion
|
|||
|
||||
# set callbacks
|
||||
litellm.input_callback=["sentry"] # for sentry breadcrumbing - logs the input being sent to the api
|
||||
litellm.success_callback=["posthog", "helicone", "lunary", "athina"]
|
||||
litellm.failure_callback=["sentry", "lunary"]
|
||||
litellm.success_callback=["posthog", "helicone", "langfuse", "lunary", "athina"]
|
||||
litellm.failure_callback=["sentry", "lunary", "langfuse"]
|
||||
|
||||
## set env variables
|
||||
os.environ['SENTRY_DSN'], os.environ['SENTRY_API_TRACE_RATE']= ""
|
||||
|
@ -32,6 +33,9 @@ os.environ["HELICONE_API_KEY"] = ""
|
|||
os.environ["TRACELOOP_API_KEY"] = ""
|
||||
os.environ["LUNARY_PUBLIC_KEY"] = ""
|
||||
os.environ["ATHINA_API_KEY"] = ""
|
||||
os.environ["LANGFUSE_PUBLIC_KEY"] = ""
|
||||
os.environ["LANGFUSE_SECRET_KEY"] = ""
|
||||
os.environ["LANGFUSE_HOST"] = ""
|
||||
|
||||
response = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
```
|
||||
|
|
68
docs/my-website/docs/observability/greenscale_integration.md
Normal file
|
@ -0,0 +1,68 @@
|
|||
# Greenscale Tutorial
|
||||
|
||||
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that provides you granular key insights into your GenAI spending and responsible usage. Greenscale only captures metadata to minimize the exposure risk of personally identifiable information (PII).
|
||||
|
||||
## Getting Started
|
||||
|
||||
Use Greenscale to log requests across all LLM Providers
|
||||
|
||||
liteLLM provides `callbacks`, making it easy for you to log data depending on the status of your responses.
|
||||
|
||||
## Using Callbacks
|
||||
|
||||
First, email `hello@greenscale.ai` to get an API_KEY.
|
||||
|
||||
Use just 1 line of code, to instantly log your responses **across all providers** with Greenscale:
|
||||
|
||||
```python
|
||||
litellm.success_callback = ["greenscale"]
|
||||
```
|
||||
|
||||
### Complete code
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
## set env variables
|
||||
os.environ['GREENSCALE_API_KEY'] = 'your-greenscale-api-key'
|
||||
os.environ['GREENSCALE_ENDPOINT'] = 'greenscale-endpoint'
|
||||
os.environ["OPENAI_API_KEY"]= ""
|
||||
|
||||
# set callback
|
||||
litellm.success_callback = ["greenscale"]
|
||||
|
||||
#openai call
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}]
|
||||
metadata={
|
||||
"greenscale_project": "acme-project",
|
||||
"greenscale_application": "acme-application"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Additional information in metadata
|
||||
|
||||
You can send any additional information to Greenscale by using the `metadata` field in completion and `greenscale_` prefix. This can be useful for sending metadata about the request, such as the project and application name, customer_id, enviornment, or any other information you want to track usage. `greenscale_project` and `greenscale_application` are required fields.
|
||||
|
||||
```python
|
||||
#openai call with additional metadata
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hi 👋 - i'm openai"}
|
||||
],
|
||||
metadata={
|
||||
"greenscale_project": "acme-project",
|
||||
"greenscale_application": "acme-application",
|
||||
"greenscale_customer_id": "customer-123"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Support & Talk with Greenscale Team
|
||||
|
||||
- [Schedule Demo 👋](https://calendly.com/nandesh/greenscale)
|
||||
- [Website 💻](https://greenscale.ai)
|
||||
- Our email ✉️ `hello@greenscale.ai`
|
|
@ -57,7 +57,7 @@ os.environ["LANGSMITH_API_KEY"] = ""
|
|||
os.environ['OPENAI_API_KEY']=""
|
||||
|
||||
# set langfuse as a callback, litellm will send the data to langfuse
|
||||
litellm.success_callback = ["langfuse"]
|
||||
litellm.success_callback = ["langsmith"]
|
||||
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
|
|
|
@ -177,11 +177,7 @@ print(response)
|
|||
|
||||
:::info
|
||||
|
||||
Claude returns it's output as an XML Tree. [Here is how we translate it](https://github.com/BerriAI/litellm/blob/49642a5b00a53b1babc1a753426a8afcac85dbbe/litellm/llms/prompt_templates/factory.py#L734).
|
||||
|
||||
You can see the raw response via `response._hidden_params["original_response"]`.
|
||||
|
||||
Claude hallucinates, e.g. returning the list param `value` as `<value>\n<item>apple</item>\n<item>banana</item>\n</value>` or `<value>\n<list>\n<item>apple</item>\n<item>banana</item>\n</list>\n</value>`.
|
||||
LiteLLM now uses Anthropic's 'tool' param 🎉 (v1.34.29+)
|
||||
:::
|
||||
|
||||
```python
|
||||
|
@ -228,6 +224,91 @@ assert isinstance(
|
|||
```
|
||||
|
||||
|
||||
### Parallel Function Calling
|
||||
|
||||
Here's how to pass the result of a function call back to an anthropic model:
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "sk-ant.."
|
||||
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
### 1ST FUNCTION CALL ###
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
]
|
||||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
|
||||
messages.append(
|
||||
response.choices[0].message.model_dump()
|
||||
) # Add assistant tool invokes
|
||||
tool_result = (
|
||||
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
|
||||
)
|
||||
# Add user submitted tool results in the OpenAI format
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"role": "tool",
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
### 2ND FUNCTION CALL ###
|
||||
# In the second response, Claude should deduce answer from tool results
|
||||
second_response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
print(second_response)
|
||||
except Exception as e:
|
||||
print(f"An error occurred - {str(e)}")
|
||||
```
|
||||
|
||||
s/o @[Shekhar Patnaik](https://www.linkedin.com/in/patnaikshekhar) for requesting this!
|
||||
|
||||
## Usage - Vision
|
||||
|
||||
```python
|
||||
|
|
|
@ -1,55 +1,215 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Azure AI Studio
|
||||
|
||||
## Using Mistral models deployed on Azure AI Studio
|
||||
**Ensure the following:**
|
||||
1. The API Base passed ends in the `/v1/` prefix
|
||||
example:
|
||||
```python
|
||||
api_base = "https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/"
|
||||
```
|
||||
|
||||
### Sample Usage - setting env vars
|
||||
2. The `model` passed is listed in [supported models](#supported-models). You **DO NOT** Need to pass your deployment name to litellm. Example `model=azure/Mistral-large-nmefg`
|
||||
|
||||
Set `MISTRAL_AZURE_API_KEY` and `MISTRAL_AZURE_API_BASE` in your env
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
MISTRAL_AZURE_API_KEY = "zE************""
|
||||
MISTRAL_AZURE_API_BASE = "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1"
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
response = litellm.completion(
|
||||
model="azure/command-r-plus",
|
||||
api_base="<your-deployment-base>/v1/"
|
||||
api_key="eskk******"
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
## Sample Usage - LiteLLM Proxy
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: mistral
|
||||
litellm_params:
|
||||
model: azure/mistral-large-latest
|
||||
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/v1/
|
||||
api_key: JGbKodRcTp****
|
||||
- model_name: command-r-plus
|
||||
litellm_params:
|
||||
model: azure/command-r-plus
|
||||
api_key: os.environ/AZURE_COHERE_API_KEY
|
||||
api_base: os.environ/AZURE_COHERE_API_BASE
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Send Request to LiteLLM Proxy Server
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
|
||||
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="mistral",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "mistral",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Function Calling
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
# set env
|
||||
os.environ["AZURE_MISTRAL_API_KEY"] = "your-api-key"
|
||||
os.environ["AZURE_MISTRAL_API_BASE"] = "your-api-base"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||
|
||||
response = completion(
|
||||
model="mistral/Mistral-large-dfgfj",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
model="azure/mistral-large-latest",
|
||||
api_base=os.getenv("AZURE_MISTRAL_API_BASE")
|
||||
api_key=os.getenv("AZURE_MISTRAL_API_KEY")
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Sample Usage - passing `api_base` and `api_key` to `litellm.completion`
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="mistral/Mistral-large-dfgfj",
|
||||
api_base="https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com",
|
||||
api_key = "JGbKodRcTp****"
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
print(response)
|
||||
|
||||
```
|
||||
|
||||
### [LiteLLM Proxy] Using Mistral Models
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $YOUR_API_KEY" \
|
||||
-d '{
|
||||
"model": "mistral",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What'\''s the weather like in Boston today?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "auto"
|
||||
}'
|
||||
|
||||
Set this on your litellm proxy config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: mistral
|
||||
litellm_params:
|
||||
model: mistral/Mistral-large-dfgfj
|
||||
api_base: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com
|
||||
api_key: JGbKodRcTp****
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
|
||||
| Cohere ommand-r | `completion(model="azure/command-r", messages)` |
|
||||
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
|
||||
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ for chunk in response:
|
|||
|------------|----------------|
|
||||
| command-r | `completion('command-r', messages)` |
|
||||
| command-light | `completion('command-light', messages)` |
|
||||
| command-r-plus | `completion('command-r-plus', messages)` |
|
||||
| command-medium | `completion('command-medium', messages)` |
|
||||
| command-medium-beta | `completion('command-medium-beta', messages)` |
|
||||
| command-xlarge-nightly | `completion('command-xlarge-nightly', messages)` |
|
||||
|
|
|
@ -23,7 +23,7 @@ In certain use-cases you may need to make calls to the models and pass [safety s
|
|||
```python
|
||||
response = completion(
|
||||
model="gemini/gemini-pro",
|
||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}],
|
||||
safety_settings=[
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
|
@ -95,9 +95,8 @@ print(content)
|
|||
```
|
||||
|
||||
## Chat Models
|
||||
| Model Name | Function Call | Required OS Variables |
|
||||
|------------------|--------------------------------------|-------------------------|
|
||||
| gemini-pro | `completion('gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| gemini-1.5-pro | `completion('gemini/gemini-1.5-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| gemini-pro-vision | `completion('gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| gemini-1.5-pro-vision | `completion('gemini/gemini-1.5-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| Model Name | Function Call | Required OS Variables |
|
||||
|-----------------------|--------------------------------------------------------|--------------------------------|
|
||||
| gemini-pro | `completion('gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| gemini-1.5-pro-latest | `completion('gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
| gemini-pro-vision | `completion('gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |
|
||||
|
|
|
@ -48,6 +48,109 @@ We support ALL Groq models, just set `groq/` as a prefix when sending completion
|
|||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| llama3-8b-8192 | `completion(model="groq/llama3-8b-8192", messages)` |
|
||||
| llama3-70b-8192 | `completion(model="groq/llama3-70b-8192", messages)` |
|
||||
| llama2-70b-4096 | `completion(model="groq/llama2-70b-4096", messages)` |
|
||||
| mixtral-8x7b-32768 | `completion(model="groq/mixtral-8x7b-32768", messages)` |
|
||||
| gemma-7b-it | `completion(model="groq/gemma-7b-it", messages)` |
|
||||
|
||||
## Groq - Tool / Function Calling Example
|
||||
|
||||
```python
|
||||
# Example dummy function hard coded to return the current weather
|
||||
import json
|
||||
def get_current_weather(location, unit="fahrenheit"):
|
||||
"""Get the current weather in a given location"""
|
||||
if "tokyo" in location.lower():
|
||||
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
|
||||
elif "san francisco" in location.lower():
|
||||
return json.dumps(
|
||||
{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
|
||||
)
|
||||
elif "paris" in location.lower():
|
||||
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
|
||||
else:
|
||||
return json.dumps({"location": location, "temperature": "unknown"})
|
||||
|
||||
|
||||
|
||||
|
||||
# Step 1: send the conversation and available functions to the model
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a function calling LLM that uses the data extracted from get_current_weather to answer questions about the weather in San Francisco.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in San Francisco?",
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
response = litellm.completion(
|
||||
model="groq/llama2-70b-4096",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto", # auto is default, but we'll be explicit
|
||||
)
|
||||
print("Response\n", response)
|
||||
response_message = response.choices[0].message
|
||||
tool_calls = response_message.tool_calls
|
||||
|
||||
|
||||
# Step 2: check if the model wanted to call a function
|
||||
if tool_calls:
|
||||
# Step 3: call the function
|
||||
# Note: the JSON response may not always be valid; be sure to handle errors
|
||||
available_functions = {
|
||||
"get_current_weather": get_current_weather,
|
||||
}
|
||||
messages.append(
|
||||
response_message
|
||||
) # extend conversation with assistant's reply
|
||||
print("Response message\n", response_message)
|
||||
# Step 4: send the info for each function call and function response to the model
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
function_response = function_to_call(
|
||||
location=function_args.get("location"),
|
||||
unit=function_args.get("unit"),
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
print(f"messages: {messages}")
|
||||
second_response = litellm.completion(
|
||||
model="groq/llama2-70b-4096", messages=messages
|
||||
) # get a new response from the model where it can see the function response
|
||||
print("second response\n", second_response)
|
||||
```
|
|
@ -50,8 +50,53 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported.
|
|||
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
|
||||
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
|
||||
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
|
||||
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
|
||||
|
||||
|
||||
## Function Calling
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
# set env
|
||||
os.environ["MISTRAL_API_KEY"] = "your-api-key"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||
|
||||
response = completion(
|
||||
model="mistral/mistral-large-latest",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
```
|
||||
|
||||
## Sample Usage - Embedding
|
||||
```python
|
||||
from litellm import embedding
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Ollama
|
||||
LiteLLM supports all models from [Ollama](https://github.com/jmorganca/ollama)
|
||||
LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama)
|
||||
|
||||
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_Ollama.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
|
@ -97,7 +97,7 @@ response = completion(
|
|||
print(response)
|
||||
```
|
||||
## Ollama Models
|
||||
Ollama supported models: https://github.com/jmorganca/ollama
|
||||
Ollama supported models: https://github.com/ollama/ollama
|
||||
|
||||
| Model Name | Function Call |
|
||||
|----------------------|-----------------------------------------------------------------------------------
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# OpenAI
|
||||
LiteLLM supports OpenAI Chat + Text completion and embedding calls.
|
||||
LiteLLM supports OpenAI Chat + Embedding calls.
|
||||
|
||||
### Required API Keys
|
||||
|
||||
|
@ -22,6 +25,132 @@ response = completion(
|
|||
)
|
||||
```
|
||||
|
||||
### Usage - LiteLLM Proxy Server
|
||||
|
||||
Here's how to call OpenAI models with the LiteLLM Proxy Server
|
||||
|
||||
### 1. Save key in your environment
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=""
|
||||
```
|
||||
|
||||
### 2. Start the proxy
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="config" label="config.yaml">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: gpt-3.5-turbo-instruct
|
||||
litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
|
||||
|
||||
Use this to add all openai models with one API Key. **WARNING: This will not do any load balancing**
|
||||
This means requests to `gpt-4`, `gpt-3.5-turbo` , `gpt-4-turbo-preview` will all go through this route
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "*" # all requests where model not in your config go to this deployment
|
||||
litellm_params:
|
||||
model: openai/* # set `openai/` to use the openai route
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="cli" label="CLI">
|
||||
|
||||
```bash
|
||||
$ litellm --model gpt-3.5-turbo
|
||||
|
||||
# Server running on http://0.0.0.0:4000
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
### 3. Test it
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
]
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai" label="OpenAI v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Optional Keys - OpenAI Organization, OpenAI API Base
|
||||
|
||||
```python
|
||||
|
@ -34,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
|
|||
|
||||
| Model Name | Function Call |
|
||||
|-----------------------|-----------------------------------------------------------------|
|
||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||
| gpt-4-1106-preview | `response = completion(model="gpt-4-1106-preview", messages=messages)` |
|
||||
| gpt-3.5-turbo-1106 | `response = completion(model="gpt-3.5-turbo-1106", messages=messages)` |
|
||||
|
@ -55,6 +186,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
|
|||
## OpenAI Vision Models
|
||||
| Model Name | Function Call |
|
||||
|-----------------------|-----------------------------------------------------------------|
|
||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
|
||||
|
||||
#### Usage
|
||||
|
@ -88,19 +220,6 @@ response = completion(
|
|||
|
||||
```
|
||||
|
||||
## OpenAI Text Completion Models / Instruct Models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------|----------------------------------------------------|
|
||||
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
|
||||
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-091", messages=messages)` |
|
||||
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
|
||||
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
|
||||
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
|
||||
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
|
||||
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
|
||||
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |
|
||||
|
||||
## Advanced
|
||||
|
||||
### Parallel Function calling
|
||||
|
|
|
@ -5,7 +5,9 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
To call models hosted behind an openai proxy, make 2 changes:
|
||||
|
||||
1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint.
|
||||
1. For `/chat/completions`: Put `openai/` in front of your model name, so litellm knows you're trying to call an openai `/chat/completions` endpoint.
|
||||
|
||||
2. For `/completions`: Put `text-completion-openai/` in front of your model name, so litellm knows you're trying to call an openai `/completions` endpoint.
|
||||
|
||||
2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.
|
||||
|
||||
|
|
|
@ -1,7 +1,16 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Replicate
|
||||
|
||||
LiteLLM supports all models on Replicate
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
### API KEYS
|
||||
```python
|
||||
import os
|
||||
|
@ -16,14 +25,175 @@ import os
|
|||
## set ENV variables
|
||||
os.environ["REPLICATE_API_KEY"] = "replicate key"
|
||||
|
||||
# replicate llama-2 call
|
||||
# replicate llama-3 call
|
||||
response = completion(
|
||||
model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
|
||||
model="replicate/meta/meta-llama-3-8b-instruct",
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
### Example - Calling Replicate Deployments
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: replicate/meta/meta-llama-3-8b-instruct
|
||||
api_key: os.environ/REPLICATE_API_KEY
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Send Request to LiteLLM Proxy Server
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
|
||||
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="llama-3",
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Be a good human!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What do you know about earth?"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "llama-3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Be a good human!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What do you know about earth?"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Expected Replicate Call
|
||||
|
||||
This is the call litellm will make to replicate, from the above example:
|
||||
|
||||
```bash
|
||||
|
||||
POST Request Sent from LiteLLM:
|
||||
curl -X POST \
|
||||
https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct \
|
||||
-H 'Authorization: Token your-api-key' -H 'Content-Type: application/json' \
|
||||
-d '{'version': 'meta/meta-llama-3-8b-instruct', 'input': {'prompt': '<|start_header_id|>system<|end_header_id|>\n\nBe a good human!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat do you know about earth?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'}}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
## Advanced Usage - Prompt Formatting
|
||||
|
||||
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
|
||||
|
||||
To apply a custom prompt template:
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
import os
|
||||
os.environ["REPLICATE_API_KEY"] = ""
|
||||
|
||||
# Create your own custom prompt template
|
||||
litellm.register_prompt_template(
|
||||
model="togethercomputer/LLaMA-2-7B-32K",
|
||||
initial_prompt_value="You are a good assistant" # [OPTIONAL]
|
||||
roles={
|
||||
"system": {
|
||||
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
|
||||
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "[INST] ", # [OPTIONAL]
|
||||
"post_message": " [/INST]" # [OPTIONAL]
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "\n" # [OPTIONAL]
|
||||
"post_message": "\n" # [OPTIONAL]
|
||||
}
|
||||
}
|
||||
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
|
||||
)
|
||||
|
||||
def test_replicate_custom_model():
|
||||
model = "replicate/togethercomputer/LLaMA-2-7B-32K"
|
||||
response = completion(model=model, messages=messages)
|
||||
print(response['choices'][0]['message']['content'])
|
||||
return response
|
||||
|
||||
test_replicate_custom_model()
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```yaml
|
||||
# Model-specific parameters
|
||||
model_list:
|
||||
- model_name: mistral-7b # model alias
|
||||
litellm_params: # actual params for litellm.completion()
|
||||
model: "replicate/mistralai/Mistral-7B-Instruct-v0.1"
|
||||
api_key: os.environ/REPLICATE_API_KEY
|
||||
initial_prompt_value: "\n"
|
||||
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
|
||||
final_prompt_value: "\n"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
max_tokens: 4096
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
## Advanced Usage - Calling Replicate Deployments
|
||||
Calling a [deployed replicate LLM](https://replicate.com/deployments)
|
||||
Add the `replicate/deployments/` prefix to your model, so litellm will call the `deployments` endpoint. This will call `ishaan-jaff/ishaan-mistral` deployment on replicate
|
||||
|
||||
|
@ -40,7 +210,7 @@ Replicate responses can take 3-5 mins due to replicate cold boots, if you're try
|
|||
|
||||
:::
|
||||
|
||||
### Replicate Models
|
||||
## Replicate Models
|
||||
liteLLM supports all replicate LLMs
|
||||
|
||||
For replicate models ensure to add a `replicate/` prefix to the `model` arg. liteLLM detects it using this arg.
|
||||
|
@ -49,15 +219,15 @@ Below are examples on how to call replicate LLMs using liteLLM
|
|||
|
||||
Model Name | Function Call | Required OS Variables |
|
||||
-----------------------------|----------------------------------------------------------------|--------------------------------------|
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate deployment | `completion(model='replicate/deployments/ishaan-jaff/ishaan-mistral', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
|
||||
|
||||
### Passing additional params - max_tokens, temperature
|
||||
## Passing additional params - max_tokens, temperature
|
||||
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
|
||||
|
||||
```python
|
||||
|
@ -73,11 +243,22 @@ response = completion(
|
|||
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||
max_tokens=20,
|
||||
temperature=0.5
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
### Passings Replicate specific params
|
||||
**proxy**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: replicate/meta/meta-llama-3-8b-instruct
|
||||
api_key: os.environ/REPLICATE_API_KEY
|
||||
max_tokens: 20
|
||||
temperature: 0.5
|
||||
```
|
||||
|
||||
## Passings Replicate specific params
|
||||
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Replicate by passing them to `litellm.completion`
|
||||
|
||||
Example `seed`, `min_tokens` are Replicate specific param
|
||||
|
@ -98,3 +279,15 @@ response = completion(
|
|||
top_k=20,
|
||||
)
|
||||
```
|
||||
|
||||
**proxy**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: replicate/meta/meta-llama-3-8b-instruct
|
||||
api_key: os.environ/REPLICATE_API_KEY
|
||||
min_tokens: 2
|
||||
top_k: 20
|
||||
```
|
||||
|
|
163
docs/my-website/docs/providers/text_completion_openai.md
Normal file
|
@ -0,0 +1,163 @@
|
|||
# OpenAI (Text Completion)
|
||||
|
||||
LiteLLM supports OpenAI text completion models
|
||||
|
||||
### Required API Keys
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||
```
|
||||
|
||||
### Usage
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||
|
||||
# openai call
|
||||
response = completion(
|
||||
model = "gpt-3.5-turbo-instruct",
|
||||
messages=[{ "content": "Hello, how are you?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
### Usage - LiteLLM Proxy Server
|
||||
|
||||
Here's how to call OpenAI models with the LiteLLM Proxy Server
|
||||
|
||||
### 1. Save key in your environment
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=""
|
||||
```
|
||||
|
||||
### 2. Start the proxy
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="config" label="config.yaml">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: gpt-3.5-turbo-instruct
|
||||
litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="config-*" label="config.yaml - proxy all OpenAI models">
|
||||
|
||||
Use this to add all openai models with one API Key. **WARNING: This will not do any load balancing**
|
||||
This means requests to `gpt-4`, `gpt-3.5-turbo` , `gpt-4-turbo-preview` will all go through this route
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "*" # all requests where model not in your config go to this deployment
|
||||
litellm_params:
|
||||
model: openai/* # set `openai/` to use the openai route
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="cli" label="CLI">
|
||||
|
||||
```bash
|
||||
$ litellm --model gpt-3.5-turbo-instruct
|
||||
|
||||
# Server running on http://0.0.0.0:4000
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
### 3. Test it
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
]
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai" label="OpenAI v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo-instruct", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
|
||||
model = "gpt-3.5-turbo-instruct",
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## OpenAI Text Completion Models / Instruct Models
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------|----------------------------------------------------|
|
||||
| gpt-3.5-turbo-instruct | `response = completion(model="gpt-3.5-turbo-instruct", messages=messages)` |
|
||||
| gpt-3.5-turbo-instruct-0914 | `response = completion(model="gpt-3.5-turbo-instruct-0914", messages=messages)` |
|
||||
| text-davinci-003 | `response = completion(model="text-davinci-003", messages=messages)` |
|
||||
| ada-001 | `response = completion(model="ada-001", messages=messages)` |
|
||||
| curie-001 | `response = completion(model="curie-001", messages=messages)` |
|
||||
| babbage-001 | `response = completion(model="babbage-001", messages=messages)` |
|
||||
| babbage-002 | `response = completion(model="babbage-002", messages=messages)` |
|
||||
| davinci-002 | `response = completion(model="davinci-002", messages=messages)` |
|
|
@ -1,18 +1,25 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# VertexAI - Google [Gemini, Model Garden]
|
||||
# VertexAI [Anthropic, Gemini, Model Garden]
|
||||
|
||||
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
## Pre-requisites
|
||||
* `pip install google-cloud-aiplatform`
|
||||
* `pip install google-cloud-aiplatform` (pre-installed on proxy docker image)
|
||||
* Authentication:
|
||||
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
|
||||
* Alternatively you can set `application_default_credentials.json`
|
||||
* Alternatively you can set `GOOGLE_APPLICATION_CREDENTIALS`
|
||||
|
||||
Here's how: [**Jump to Code**](#extra)
|
||||
|
||||
- Create a service account on GCP
|
||||
- Export the credentials as a json
|
||||
- load the json and json.dump the json as a string
|
||||
- store the json string in your environment as `GOOGLE_APPLICATION_CREDENTIALS`
|
||||
|
||||
## Sample Usage
|
||||
```python
|
||||
|
@ -123,6 +130,100 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
|
|||
|
||||
</Tabs>
|
||||
|
||||
## Specifying Safety Settings
|
||||
In certain use-cases you may need to make calls to the models and pass [safety settigns](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simple pass the `safety_settings` argument to `completion` or `acompletion`. For example:
|
||||
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
response = completion(
|
||||
model="gemini/gemini-pro",
|
||||
messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
|
||||
safety_settings=[
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
]
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
**Option 1: Set in config**
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-experimental
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-experimental
|
||||
vertex_project: litellm-epic
|
||||
vertex_location: us-central1
|
||||
safety_settings:
|
||||
- category: HARM_CATEGORY_HARASSMENT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_HATE_SPEECH
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||
threshold: BLOCK_NONE
|
||||
- category: HARM_CATEGORY_DANGEROUS_CONTENT
|
||||
threshold: BLOCK_NONE
|
||||
```
|
||||
|
||||
**Option 2: Set on call**
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="gemini-experimental",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you write exploits?",
|
||||
}
|
||||
],
|
||||
max_tokens=8192,
|
||||
stream=False,
|
||||
temperature=0.0,
|
||||
|
||||
extra_body={
|
||||
"safety_settings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Set Vertex Project & Vertex Location
|
||||
All calls using Vertex AI require the following parameters:
|
||||
* Your Project ID
|
||||
|
@ -149,6 +250,85 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
|
|||
# set directly on module
|
||||
litellm.vertex_location = "us-central1 # Your Location
|
||||
```
|
||||
## Anthropic
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| claude-3-opus@20240229 | `completion('vertex_ai/claude-3-opus@20240229', messages)` |
|
||||
| claude-3-sonnet@20240229 | `completion('vertex_ai/claude-3-sonnet@20240229', messages)` |
|
||||
| claude-3-haiku@20240307 | `completion('vertex_ai/claude-3-haiku@20240307', messages)` |
|
||||
|
||||
### Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
|
||||
|
||||
model = "claude-3-sonnet@20240229"
|
||||
|
||||
vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
|
||||
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]
|
||||
|
||||
response = completion(
|
||||
model="vertex_ai/" + model,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=0.7,
|
||||
vertex_ai_project=vertex_ai_project,
|
||||
vertex_ai_location=vertex_ai_location,
|
||||
)
|
||||
print("\nModel Response", response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-vertex
|
||||
litellm_params:
|
||||
model: vertex_ai/claude-3-sonnet@20240229
|
||||
vertex_ai_project: "my-test-project"
|
||||
vertex_ai_location: "us-east-1"
|
||||
- model_name: anthropic-vertex
|
||||
litellm_params:
|
||||
model: vertex_ai/claude-3-sonnet@20240229
|
||||
vertex_ai_project: "my-test-project"
|
||||
vertex_ai_location: "us-west-1"
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "anthropic-vertex", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Model Garden
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
|
@ -175,18 +355,15 @@ response = completion(
|
|||
|------------------|--------------------------------------|
|
||||
| gemini-pro | `completion('gemini-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
||||
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
||||
|
||||
## Gemini Pro Vision
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`|
|
||||
|
||||
## Gemini 1.5 Pro (and Vision)
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-1.5-pro-vision | `completion('gemini-pro-vision', messages)`, `completion('vertex_ai/gemini-pro-vision', messages)`|
|
||||
| gemini-1.5-pro | `completion('gemini-1.5-pro', messages)`, `completion('vertex_ai/gemini-pro', messages)` |
|
||||
|
||||
|
||||
|
||||
|
@ -298,3 +475,75 @@ print(response)
|
|||
| code-bison@001 | `completion('code-bison@001', messages)` |
|
||||
| code-gecko@001 | `completion('code-gecko@001', messages)` |
|
||||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||
|
||||
|
||||
## Extra
|
||||
|
||||
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
||||
Here's the code for storing your service account credentials as `GOOGLE_APPLICATION_CREDENTIALS` environment variable:
|
||||
|
||||
|
||||
```python
|
||||
def load_vertex_ai_credentials():
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
vertex_key_path = filepath + "/vertex_key.json"
|
||||
|
||||
# Read the existing content of the file or create an empty dictionary
|
||||
try:
|
||||
with open(vertex_key_path, "r") as file:
|
||||
# Read the file content
|
||||
print("Read vertexai file path")
|
||||
content = file.read()
|
||||
|
||||
# If the file is empty or not valid JSON, create an empty dictionary
|
||||
if not content or not content.strip():
|
||||
service_account_key_data = {}
|
||||
else:
|
||||
# Attempt to load the existing JSON content
|
||||
file.seek(0)
|
||||
service_account_key_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
# If the file doesn't exist, create an empty dictionary
|
||||
service_account_key_data = {}
|
||||
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
|
||||
# Write the updated content to the temporary file
|
||||
json.dump(service_account_key_data, temp_file, indent=2)
|
||||
|
||||
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
|
||||
```
|
||||
|
||||
|
||||
### Using GCP Service Account
|
||||
|
||||
1. Figure out the Service Account bound to the Google Cloud Run service
|
||||
|
||||
<Image img={require('../../img/gcp_acc_1.png')} />
|
||||
|
||||
2. Get the FULL EMAIL address of the corresponding Service Account
|
||||
|
||||
3. Next, go to IAM & Admin > Manage Resources , select your top-level project that houses your Google Cloud Run Service
|
||||
|
||||
Click `Add Principal`
|
||||
|
||||
<Image img={require('../../img/gcp_acc_2.png')}/>
|
||||
|
||||
4. Specify the Service Account as the principal and Vertex AI User as the role
|
||||
|
||||
<Image img={require('../../img/gcp_acc_3.png')}/>
|
||||
|
||||
Once that's done, when you deploy the new container in the Google Cloud Run service, LiteLLM will have automatic access to all Vertex AI endpoints.
|
||||
|
||||
|
||||
s/o @[Darien Kindlund](https://www.linkedin.com/in/kindlund/) for this tutorial
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -25,8 +25,11 @@ All models listed here https://docs.voyageai.com/embeddings/#models-and-specific
|
|||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| voyage-2 | `embedding(model="voyage/voyage-2", input)` |
|
||||
| voyage-large-2 | `embedding(model="voyage/voyage-large-2", input)` |
|
||||
| voyage-law-2 | `embedding(model="voyage/voyage-law-2", input)` |
|
||||
| voyage-code-2 | `embedding(model="voyage/voyage-code-2", input)` |
|
||||
| voyage-lite-02-instruct | `embedding(model="voyage/voyage-lite-02-instruct", input)` |
|
||||
| voyage-01 | `embedding(model="voyage/voyage-01", input)` |
|
||||
| voyage-lite-01 | `embedding(model="voyage/voyage-lite-01", input)` |
|
||||
| voyage-lite-01-instruct | `embedding(model="voyage/voyage-lite-01-instruct", input)` |
|
||||
|
||||
|
||||
|
|
|
@ -61,6 +61,22 @@ litellm_settings:
|
|||
ttl: 600 # will be cached on redis for 600s
|
||||
```
|
||||
|
||||
|
||||
## SSL
|
||||
|
||||
just set `REDIS_SSL="True"` in your .env, and LiteLLM will pick this up.
|
||||
|
||||
```env
|
||||
REDIS_SSL="True"
|
||||
```
|
||||
|
||||
For quick testing, you can also use REDIS_URL, eg.:
|
||||
|
||||
```
|
||||
REDIS_URL="rediss://.."
|
||||
```
|
||||
|
||||
but we **don't** recommend using REDIS_URL in prod. We've noticed a performance difference between using it vs. redis_host, port, etc.
|
||||
#### Step 2: Add Redis Credentials to .env
|
||||
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
|
||||
|
||||
|
@ -265,32 +281,6 @@ litellm_settings:
|
|||
supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
|
||||
```
|
||||
|
||||
|
||||
### Turn on `batch_redis_requests`
|
||||
|
||||
**What it does?**
|
||||
When a request is made:
|
||||
|
||||
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory, if no - get the last 100 cached requests for this key and store it
|
||||
|
||||
- New requests are stored with this `litellm:..` as the namespace
|
||||
|
||||
**Why?**
|
||||
Reduce number of redis GET requests. This improved latency by 46% in prod load tests.
|
||||
|
||||
**Usage**
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
cache: true
|
||||
cache_params:
|
||||
type: redis
|
||||
... # remaining redis args (host, port, etc.)
|
||||
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
|
||||
```
|
||||
|
||||
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
|
||||
|
||||
### Turn on / off caching per request.
|
||||
|
||||
The proxy support 3 cache-controls:
|
||||
|
@ -384,6 +374,87 @@ chat_completion = client.chat.completions.create(
|
|||
)
|
||||
```
|
||||
|
||||
### Deleting Cache Keys - `/cache/delete`
|
||||
In order to delete a cache key, send a request to `/cache/delete` with the `keys` you want to delete
|
||||
|
||||
Example
|
||||
```shell
|
||||
curl -X POST "http://0.0.0.0:4000/cache/delete" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"keys": ["586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d", "key2"]}'
|
||||
```
|
||||
|
||||
```shell
|
||||
# {"status":"success"}
|
||||
```
|
||||
|
||||
#### Viewing Cache Keys from responses
|
||||
You can view the cache_key in the response headers, on cache hits the cache key is sent as the `x-litellm-cache-key` response headers
|
||||
```shell
|
||||
curl -i --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"user": "ishan",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is litellm"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
Response from litellm proxy
|
||||
```json
|
||||
date: Thu, 04 Apr 2024 17:37:21 GMT
|
||||
content-type: application/json
|
||||
x-litellm-cache-key: 586bf3f3c1bf5aecb55bd9996494d3bbc69eb58397163add6d49537762a7548d
|
||||
|
||||
{
|
||||
"id": "chatcmpl-9ALJTzsBlXR9zTxPvzfFFtFbFtG6T",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "I'm sorr.."
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1712252235,
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Turn on `batch_redis_requests`
|
||||
|
||||
**What it does?**
|
||||
When a request is made:
|
||||
|
||||
- Check if a key starting with `litellm:<hashed_api_key>:<call_type>:` exists in-memory, if no - get the last 100 cached requests for this key and store it
|
||||
|
||||
- New requests are stored with this `litellm:..` as the namespace
|
||||
|
||||
**Why?**
|
||||
Reduce number of redis GET requests. This improved latency by 46% in prod load tests.
|
||||
|
||||
**Usage**
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
cache: true
|
||||
cache_params:
|
||||
type: redis
|
||||
... # remaining redis args (host, port, etc.)
|
||||
callbacks: ["batch_redis_requests"] # 👈 KEY CHANGE!
|
||||
```
|
||||
|
||||
[**SEE CODE**](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/batch_redis_get.py)
|
||||
|
||||
## Supported `cache_params` on proxy config.yaml
|
||||
|
||||
```yaml
|
||||
|
|
|
@ -600,6 +600,7 @@ general_settings:
|
|||
"general_settings": {
|
||||
"completion_model": "string",
|
||||
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
||||
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
|
||||
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
||||
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
||||
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param
|
||||
|
|
9
docs/my-website/docs/proxy/demo.md
Normal file
|
@ -0,0 +1,9 @@
|
|||
# 🎉 Demo App
|
||||
|
||||
Here is a demo of the proxy. To log in pass in:
|
||||
|
||||
- Username: admin
|
||||
- Password: sk-1234
|
||||
|
||||
|
||||
[Demo UI](https://demo.litellm.ai/ui)
|
|
@ -231,13 +231,16 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
|
|||
| Docs | When to Use |
|
||||
| --- | --- |
|
||||
| [Quick Start](#quick-start) | call 100+ LLMs + Load Balancing |
|
||||
| [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend |
|
||||
| [Deploy with Database](#deploy-with-database) | + use Virtual Keys + Track Spend (Note: When deploying with a database providing a `DATABASE_URL` and `LITELLM_MASTER_KEY` are required in your env ) |
|
||||
| [LiteLLM container + Redis](#litellm-container--redis) | + load balance across multiple litellm containers |
|
||||
| [LiteLLM Database container + PostgresDB + Redis](#litellm-database-container--postgresdb--redis) | + use Virtual Keys + Track Spend + load balance across multiple litellm containers |
|
||||
|
||||
## Deploy with Database
|
||||
### Docker, Kubernetes, Helm Chart
|
||||
|
||||
Requirements:
|
||||
- Need a postgres database (e.g. [Supabase](https://supabase.com/), [Neon](https://neon.tech/), etc) Set `DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname>` in your env
|
||||
- Set a `LITELLM_MASTER_KEY`, this is your Proxy Admin key - you can use this to create other keys (🚨 must start with `sk-`)
|
||||
|
||||
<Tabs>
|
||||
|
||||
|
@ -246,12 +249,14 @@ Your OpenAI proxy server is now running on `http://127.0.0.1:4000`.
|
|||
We maintain a [seperate Dockerfile](https://github.com/BerriAI/litellm/pkgs/container/litellm-database) for reducing build time when running LiteLLM proxy with a connected Postgres Database
|
||||
|
||||
```shell
|
||||
docker pull docker pull ghcr.io/berriai/litellm-database:main-latest
|
||||
docker pull ghcr.io/berriai/litellm-database:main-latest
|
||||
```
|
||||
|
||||
```shell
|
||||
docker run \
|
||||
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
|
||||
-e LITELLM_MASTER_KEY=sk-1234 \
|
||||
-e DATABASE_URL=postgresql://<user>:<password>@<host>:<port>/<dbname> \
|
||||
-e AZURE_API_KEY=d6*********** \
|
||||
-e AZURE_API_BASE=https://openai-***********/ \
|
||||
-p 4000:4000 \
|
||||
|
@ -666,8 +671,8 @@ services:
|
|||
litellm:
|
||||
build:
|
||||
context: .
|
||||
args:
|
||||
target: runtime
|
||||
args:
|
||||
target: runtime
|
||||
image: ghcr.io/berriai/litellm:main-latest
|
||||
ports:
|
||||
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# ✨ Enterprise Features - Content Mod
|
||||
# ✨ Enterprise Features - Content Mod, SSO
|
||||
|
||||
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
|
||||
|
||||
|
@ -12,16 +12,18 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
|
|||
:::
|
||||
|
||||
Features:
|
||||
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
|
||||
- ✅ Content Moderation with LLM Guard
|
||||
- ✅ Content Moderation with LlamaGuard
|
||||
- ✅ Content Moderation with Google Text Moderations
|
||||
- ✅ Reject calls from Blocked User list
|
||||
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
|
||||
- ✅ Don't log/store specific requests (eg confidential LLM requests)
|
||||
- ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (eg confidential LLM requests)
|
||||
- ✅ Tracking Spend for Custom Tags
|
||||
|
||||
|
||||
|
||||
|
||||
## Content Moderation
|
||||
### Content Moderation with LLM Guard
|
||||
|
||||
|
@ -74,7 +76,7 @@ curl --location 'http://localhost:4000/key/generate' \
|
|||
# Returns {..'key': 'my-new-key'}
|
||||
```
|
||||
|
||||
**2. Test it!**
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
|
@ -87,6 +89,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
|||
}'
|
||||
```
|
||||
|
||||
#### Turn on/off per request
|
||||
|
||||
**1. Update config**
|
||||
```yaml
|
||||
litellm_settings:
|
||||
callbacks: ["llmguard_moderations"]
|
||||
llm_guard_mode: "request-specific"
|
||||
```
|
||||
|
||||
**2. Create new key**
|
||||
|
||||
```bash
|
||||
curl --location 'http://localhost:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"models": ["fake-openai-endpoint"],
|
||||
}'
|
||||
|
||||
# Returns {..'key': 'my-new-key'}
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
],
|
||||
extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
|
||||
"metadata": {
|
||||
"permissions": {
|
||||
"enable_llm_guard_check": True # 👈 KEY CHANGE
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
|
||||
--data '{"model": "fake-openai-endpoint", "messages": [
|
||||
{"role": "system", "content": "Be helpful"},
|
||||
{"role": "user", "content": "What do you know?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Content Moderation with LlamaGuard
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Load Balancing - Config Setup
|
||||
# Multiple Instances
|
||||
Load balance multiple instances of the same model
|
||||
|
||||
The proxy will handle routing requests (using LiteLLM's Router). **Set `rpm` in the config if you want maximize throughput**
|
||||
|
@ -10,75 +10,6 @@ For more details on routing strategies / params, see [Routing](../routing.md)
|
|||
|
||||
:::
|
||||
|
||||
## Quick Start - Load Balancing
|
||||
### Step 1 - Set deployments on config
|
||||
|
||||
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-large
|
||||
api_base: https://openai-france-1234.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 1440
|
||||
```
|
||||
|
||||
### Step 2: Start Proxy with config
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### Step 3: Use proxy - Call a model group [Load Balancing]
|
||||
Curl Command
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
### Usage - Call a specific model deployment
|
||||
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
|
||||
|
||||
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "azure/gpt-turbo-small-ca",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
## Load Balancing using multiple litellm instances (Kubernetes, Auto Scaling)
|
||||
|
||||
LiteLLM Proxy supports sharing rpm/tpm shared across multiple litellm instances, pass `redis_host`, `redis_password` and `redis_port` to enable this. (LiteLLM will use Redis to track rpm/tpm usage )
|
||||
|
|
|
@ -9,9 +9,9 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
|
|||
|
||||
- [Async Custom Callbacks](#custom-callback-class-async)
|
||||
- [Async Custom Callback APIs](#custom-callback-apis-async)
|
||||
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
|
||||
- [Logging to Langfuse](#logging-proxy-inputoutput---langfuse)
|
||||
- [Logging to s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
|
||||
- [Logging to DataDog](#logging-proxy-inputoutput---datadog)
|
||||
- [Logging to DynamoDB](#logging-proxy-inputoutput---dynamodb)
|
||||
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
||||
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
||||
|
@ -539,6 +539,36 @@ print(response)
|
|||
</Tabs>
|
||||
|
||||
|
||||
### Team based Logging to Langfuse
|
||||
|
||||
**Example:**
|
||||
|
||||
This config would send langfuse logs to 2 different langfuse projects, based on the team id
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
default_team_settings:
|
||||
- team_id: my-secret-project
|
||||
success_callback: ["langfuse"]
|
||||
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
|
||||
langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
|
||||
- team_id: ishaans-secret-project
|
||||
success_callback: ["langfuse"]
|
||||
langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_2 # Project 2
|
||||
langfuse_secret: os.environ/LANGFUSE_SECRET_2 # Project 2
|
||||
```
|
||||
|
||||
Now, when you [generate keys](./virtual_keys.md) for this team-id
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"team_id": "ishaans-secret-project"}'
|
||||
```
|
||||
|
||||
All requests made with these keys will log data to their team-specific logging.
|
||||
|
||||
## Logging Proxy Input/Output - DataDog
|
||||
We will use the `--config` to set `litellm.success_callback = ["datadog"]` this will log all successfull LLM calls to DataDog
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ Expected Performance in Production
|
|||
| `/chat/completions` Requests/hour | `126K` |
|
||||
|
||||
|
||||
## 1. Switch of Debug Logging
|
||||
## 1. Switch off Debug Logging
|
||||
|
||||
Remove `set_verbose: True` from your config.yaml
|
||||
```yaml
|
||||
|
@ -40,7 +40,7 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
|||
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||
```
|
||||
|
||||
## 2. Batch write spend updates every 60s
|
||||
## 3. Batch write spend updates every 60s
|
||||
|
||||
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
|
||||
|
||||
|
@ -49,11 +49,35 @@ In production, we recommend using a longer interval period of 60s. This reduces
|
|||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
```
|
||||
|
||||
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
|
||||
|
||||
## 3. Move spend logs to separate server
|
||||
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
|
||||
|
||||
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
||||
|
||||
Recommended to do this for prod:
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2
|
||||
# redis_url: "os.environ/REDIS_URL"
|
||||
redis_host: os.environ/REDIS_HOST
|
||||
redis_port: os.environ/REDIS_PORT
|
||||
redis_password: os.environ/REDIS_PASSWORD
|
||||
```
|
||||
|
||||
## 5. Switch off resetting budgets
|
||||
|
||||
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_reset_budget: true
|
||||
```
|
||||
|
||||
## 6. Move spend logs to separate server (BETA)
|
||||
|
||||
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server.
|
||||
|
||||
|
@ -141,24 +165,6 @@ A t2.micro should be sufficient to handle 1k logs / minute on this server.
|
|||
|
||||
This consumes at max 120MB, and <0.1 vCPU.
|
||||
|
||||
## 4. Switch off resetting budgets
|
||||
|
||||
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: true
|
||||
disable_reset_budget: true
|
||||
```
|
||||
|
||||
## 5. Switch of `litellm.telemetry`
|
||||
|
||||
Switch of all telemetry tracking done by litellm
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
telemetry: False
|
||||
```
|
||||
|
||||
## Machine Specifications to Deploy LiteLLM
|
||||
|
||||
| Service | Spec | CPUs | Memory | Architecture | Version|
|
||||
|
|
|
@ -14,6 +14,7 @@ model_list:
|
|||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
success_callback: ["prometheus"]
|
||||
failure_callback: ["prometheus"]
|
||||
```
|
||||
|
||||
Start the proxy
|
||||
|
@ -48,6 +49,26 @@ http://localhost:4000/metrics
|
|||
|
||||
| Metric Name | Description |
|
||||
|----------------------|--------------------------------------|
|
||||
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model"` |
|
||||
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model"` |
|
||||
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model"` |
|
||||
| `litellm_requests_metric` | Number of requests made, per `"user", "key", "model", "team", "end-user"` |
|
||||
| `litellm_spend_metric` | Total Spend, per `"user", "key", "model", "team", "end-user"` |
|
||||
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
|
||||
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
|
||||
|
||||
## Monitor System Health
|
||||
|
||||
To monitor the health of litellm adjacent services (redis / postgres), do:
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
service_callback: ["prometheus_system"]
|
||||
```
|
||||
|
||||
| Metric Name | Description |
|
||||
|----------------------|--------------------------------------|
|
||||
| `litellm_redis_latency` | histogram latency for redis calls |
|
||||
| `litellm_redis_fails` | Number of failed redis calls |
|
||||
| `litellm_self_latency` | Histogram latency for successful litellm api call |
|
|
@ -348,6 +348,29 @@ query_result = embeddings.embed_query(text)
|
|||
|
||||
print(f"TITAN EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="litellm" label="LiteLLM SDK">
|
||||
|
||||
This is **not recommended**. There is duplicate logic as the proxy also uses the sdk, which might lead to unexpected errors.
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
|
||||
response = completion(
|
||||
model="openai/gpt-3.5-turbo",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
],
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
@ -438,7 +461,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
|
|||
),
|
||||
```
|
||||
|
||||
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="aider" label="Aider">
|
||||
|
@ -551,4 +574,3 @@ No Logs
|
|||
```shell
|
||||
export LITELLM_LOG=None
|
||||
```
|
||||
|
||||
|
|
|
@ -2,7 +2,9 @@ import Image from '@theme/IdealImage';
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Fallbacks, Retries, Timeouts, Cooldowns
|
||||
# 🔥 Fallbacks, Retries, Timeouts, Load Balancing
|
||||
|
||||
Retry call with multiple instances of the same model.
|
||||
|
||||
If a call fails after num_retries, fall back to another model group.
|
||||
|
||||
|
@ -10,6 +12,77 @@ If the error is a context window exceeded error, fall back to a larger model gro
|
|||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
|
||||
|
||||
## Quick Start - Load Balancing
|
||||
### Step 1 - Set deployments on config
|
||||
|
||||
**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-large
|
||||
api_base: https://openai-france-1234.openai.azure.com/
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 1440
|
||||
```
|
||||
|
||||
### Step 2: Start Proxy with config
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### Step 3: Use proxy - Call a model group [Load Balancing]
|
||||
Curl Command
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
### Usage - Call a specific model deployment
|
||||
If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model`
|
||||
|
||||
In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "azure/gpt-turbo-small-ca",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
## Fallbacks + Retries + Timeouts + Cooldowns
|
||||
|
||||
**Set via config**
|
||||
```yaml
|
||||
model_list:
|
||||
|
@ -63,7 +136,158 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
'
|
||||
```
|
||||
|
||||
## Custom Timeouts, Stream Timeouts - Per Model
|
||||
### Test it!
|
||||
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"model": "zephyr-beta", # 👈 MODEL NAME to fallback from
|
||||
"messages": [
|
||||
{"role": "user", "content": "what color is red"}
|
||||
],
|
||||
"mock_testing_fallbacks": true
|
||||
}'
|
||||
```
|
||||
|
||||
## Advanced - Context Window Fallbacks
|
||||
|
||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/c9e6b05cfb20dfb17272218e2555d6b496c47f6f/litellm/router.py#L2163)
|
||||
|
||||
**1. Setup config**
|
||||
|
||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with azure/.
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="same-group" label="Same Group">
|
||||
|
||||
Filter older instances of a model (e.g. gpt-3.5-turbo) with smaller context windows
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
|
||||
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
text = "What is the meaning of 42?" * 5000
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [
|
||||
{"role": "system", "content": text},
|
||||
{"role": "user", "content": "Who was Alexander?"},
|
||||
],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
|
||||
|
||||
Fallback to larger models if current model is too small.
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo-small
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
|
||||
|
||||
- model_name: gpt-3.5-turbo-large
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: claude-opus
|
||||
litellm_params:
|
||||
model: claude-3-opus-20240229
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
context_window_fallbacks: [{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}]
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
text = "What is the meaning of 42?" * 5000
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [
|
||||
{"role": "system", "content": text},
|
||||
{"role": "user", "content": "Who was Alexander?"},
|
||||
],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
|
||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||
```yaml
|
||||
model_list:
|
||||
|
@ -92,7 +316,7 @@ $ litellm --config /path/to/config.yaml
|
|||
```
|
||||
|
||||
|
||||
## Setting Dynamic Timeouts - Per Request
|
||||
## Advanced - Setting Dynamic Timeouts - Per Request
|
||||
|
||||
LiteLLM Proxy supports setting a `timeout` per request
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
|
|||
curl -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-D '{"team_id": "ishaans-secret-project"}'
|
||||
-d '{"team_id": "ishaans-secret-project"}'
|
||||
```
|
||||
|
||||
All requests made with these keys will log data to their team-specific logging.
|
||||
|
|
|
@ -9,6 +9,7 @@ Use JWT's to auth admins / projects into the proxy.
|
|||
|
||||
This is a new feature, and subject to changes based on feedback.
|
||||
|
||||
*UPDATE*: This will be moving to the [enterprise tier](./enterprise.md), once it's out of beta (~by end of April).
|
||||
:::
|
||||
|
||||
## Usage
|
||||
|
@ -107,6 +108,34 @@ general_settings:
|
|||
litellm_jwtauth:
|
||||
admin_jwt_scope: "litellm-proxy-admin"
|
||||
```
|
||||
|
||||
## Advanced - Spend Tracking (User / Team / Org)
|
||||
|
||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
admin_jwt_scope: "litellm-proxy-admin"
|
||||
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
|
||||
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
|
||||
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
|
||||
```
|
||||
|
||||
Expected JWT:
|
||||
|
||||
```
|
||||
{
|
||||
"client_id": "my-unique-team",
|
||||
"sub": "my-unique-user",
|
||||
"org_id": "my-unique-org"
|
||||
}
|
||||
```
|
||||
|
||||
Now litellm will automatically update the spend for the user/team/org in the db for each call.
|
||||
|
||||
### JWT Scopes
|
||||
|
||||
Here's what scopes on JWT-Auth tokens look like
|
||||
|
@ -149,7 +178,7 @@ general_settings:
|
|||
enable_jwt_auth: True
|
||||
litellm_jwtauth:
|
||||
...
|
||||
team_jwt_scope: "litellm-team" # 👈 Set JWT Scope string
|
||||
team_id_jwt_field: "litellm-team" # 👈 Set field in the JWT token that stores the team ID
|
||||
team_allowed_routes: ["/v1/chat/completions"] # 👈 Set accepted routes
|
||||
```
|
||||
|
||||
|
|
|
@ -56,6 +56,9 @@ On accessing the LiteLLM UI, you will be prompted to enter your username, passwo
|
|||
|
||||
## ✨ Enterprise Features
|
||||
|
||||
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
|
||||
|
||||
|
||||
### Setup SSO/Auth for UI
|
||||
|
||||
#### Step 1: Set upperbounds for keys
|
||||
|
|
|
@ -121,6 +121,9 @@ from langchain.prompts.chat import (
|
|||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "anything"
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
|
|
|
@ -435,7 +435,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
|
|||
),
|
||||
```
|
||||
|
||||
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
</TabItem>
|
||||
<TabItem value="aider" label="Aider">
|
||||
|
||||
|
@ -815,4 +815,3 @@ Thread Stats Avg Stdev Max +/- Stdev
|
|||
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
|
||||
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
|
||||
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
|
||||
|
||||
|
|
|
@ -95,12 +95,129 @@ print(response)
|
|||
- `router.image_generation()` - completion calls in OpenAI `/v1/images/generations` endpoint format
|
||||
- `router.aimage_generation()` - async image generation calls
|
||||
|
||||
### Advanced
|
||||
## Advanced - Routing Strategies
|
||||
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
|
||||
|
||||
Router provides 4 strategies for routing your calls across multiple deployments:
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="usage-based-v2" label="Rate-Limit Aware v2 (ASYNC)">
|
||||
|
||||
**🎉 NEW** This is an async implementation of usage-based-routing.
|
||||
|
||||
**Filters out deployment if tpm/rpm limit exceeded** - If you pass in the deployment's tpm/rpm limits.
|
||||
|
||||
Routes to **deployment with lowest TPM usage** for that minute.
|
||||
|
||||
In production, we use Redis to track usage (TPM/RPM) across multiple deployments. This implementation uses **async redis calls** (redis.incr and redis.mget).
|
||||
|
||||
For Azure, your RPM = TPM/6.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="sdk">
|
||||
|
||||
```python
|
||||
from litellm import Router
|
||||
|
||||
|
||||
model_list = [{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 10000,
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"tpm": 100000,
|
||||
"rpm": 1000,
|
||||
}]
|
||||
router = Router(model_list=model_list,
|
||||
redis_host=os.environ["REDIS_HOST"],
|
||||
redis_password=os.environ["REDIS_PASSWORD"],
|
||||
redis_port=os.environ["REDIS_PORT"],
|
||||
routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
|
||||
enable_pre_call_check=True, # enables router rate limits for concurrent calls
|
||||
)
|
||||
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="proxy">
|
||||
|
||||
**1. Set strategy in config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo # model alias
|
||||
litellm_params: # params for litellm completion/embedding call
|
||||
model: azure/chatgpt-v-2 # actual model name
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: os.environ/AZURE_API_VERSION
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
tpm: 100000
|
||||
rpm: 10000
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params: # params for litellm completion/embedding call
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.getenv(OPENAI_API_KEY)
|
||||
tpm: 100000
|
||||
rpm: 1000
|
||||
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
|
||||
redis_host: <your-redis-host>
|
||||
redis_password: <your-redis-password>
|
||||
redis_port: <your-redis-port>
|
||||
enable_pre_call_check: true
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://localhost:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="latency-based" label="Latency-Based">
|
||||
|
||||
|
||||
|
@ -117,7 +234,10 @@ import asyncio
|
|||
model_list = [{ ... }]
|
||||
|
||||
# init router
|
||||
router = Router(model_list=model_list, routing_strategy="latency-based-routing") # 👈 set routing strategy
|
||||
router = Router(model_list=model_list,
|
||||
routing_strategy="latency-based-routing",# 👈 set routing strategy
|
||||
enable_pre_call_check=True, # enables router rate limits for concurrent calls
|
||||
)
|
||||
|
||||
## CALL 1+2
|
||||
tasks = []
|
||||
|
@ -159,7 +279,7 @@ router_settings:
|
|||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="simple-shuffle" label="(Default) Weighted Pick">
|
||||
<TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">
|
||||
|
||||
**Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)**
|
||||
|
||||
|
@ -257,8 +377,9 @@ router = Router(model_list=model_list,
|
|||
redis_host=os.environ["REDIS_HOST"],
|
||||
redis_password=os.environ["REDIS_PASSWORD"],
|
||||
redis_port=os.environ["REDIS_PORT"],
|
||||
routing_strategy="usage-based-routing")
|
||||
|
||||
routing_strategy="usage-based-routing"
|
||||
enable_pre_call_check=True, # enables router rate limits for concurrent calls
|
||||
)
|
||||
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
@ -555,7 +676,11 @@ router = Router(model_list: Optional[list] = None,
|
|||
|
||||
## Pre-Call Checks (Context Window)
|
||||
|
||||
Enable pre-call checks to filter out deployments with context window limit < messages for a call.
|
||||
Enable pre-call checks to filter out:
|
||||
1. deployments with context window limit < messages for a call.
|
||||
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
|
||||
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
|
||||
])`)
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
@ -567,10 +692,14 @@ from litellm import Router
|
|||
router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set to True
|
||||
```
|
||||
|
||||
**2. (Azure-only) Set base model**
|
||||
|
||||
**2. Set Model List**
|
||||
|
||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="same-group" label="Same Group">
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
|
@ -582,7 +711,7 @@ model_list = [
|
|||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 SET BASE MODEL
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -593,8 +722,51 @@ model_list = [
|
|||
},
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-small", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-large", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-opus",
|
||||
"litellm_params": { call
|
||||
"model": "claude-3-opus-20240229",
|
||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
|
@ -646,60 +818,9 @@ print(f"response: {response}")
|
|||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
**1. Setup config**
|
||||
|
||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with azure/.
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
|
||||
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
text = "What is the meaning of 42?" * 5000
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [
|
||||
{"role": "system", "content": text},
|
||||
{"role": "user", "content": "Who was Alexander?"},
|
||||
],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
:::info
|
||||
Go [here](./proxy/reliability.md#advanced---context-window-fallbacks) for how to do this on the proxy
|
||||
:::
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
|
|
@ -310,7 +310,7 @@ In the [config.py](https://continue.dev/docs/reference/Models/openai) set this a
|
|||
),
|
||||
```
|
||||
|
||||
Credits [@vividfog](https://github.com/jmorganca/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
Credits [@vividfog](https://github.com/ollama/ollama/issues/305#issuecomment-1751848077) for this tutorial.
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="aider" label="Aider">
|
||||
|
@ -1351,5 +1351,3 @@ LiteLLM proxy adds **0.00325 seconds** latency as compared to using the Raw Open
|
|||
```shell
|
||||
litellm --telemetry False
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -105,6 +105,12 @@ const config = {
|
|||
label: 'Enterprise',
|
||||
to: "docs/enterprise"
|
||||
},
|
||||
{
|
||||
sidebarId: 'tutorialSidebar',
|
||||
position: 'left',
|
||||
label: '🚀 Hosted',
|
||||
to: "docs/hosted"
|
||||
},
|
||||
{
|
||||
href: 'https://github.com/BerriAI/litellm',
|
||||
label: 'GitHub',
|
||||
|
|
BIN
docs/my-website/img/gcp_acc_1.png
Normal file
After Width: | Height: | Size: 91 KiB |
BIN
docs/my-website/img/gcp_acc_2.png
Normal file
After Width: | Height: | Size: 298 KiB |
BIN
docs/my-website/img/gcp_acc_3.png
Normal file
After Width: | Height: | Size: 208 KiB |
BIN
docs/my-website/img/litellm_hosted_ui_add_models.png
Normal file
After Width: | Height: | Size: 398 KiB |
BIN
docs/my-website/img/litellm_hosted_ui_create_key.png
Normal file
After Width: | Height: | Size: 496 KiB |
BIN
docs/my-website/img/litellm_hosted_ui_router.png
Normal file
After Width: | Height: | Size: 348 KiB |
BIN
docs/my-website/img/litellm_hosted_usage_dashboard.png
Normal file
After Width: | Height: | Size: 460 KiB |
|
@ -31,24 +31,26 @@ const sidebars = {
|
|||
"proxy/quick_start",
|
||||
"proxy/deploy",
|
||||
"proxy/prod",
|
||||
"proxy/configs",
|
||||
{
|
||||
type: "link",
|
||||
label: "📖 All Endpoints",
|
||||
label: "📖 All Endpoints (Swagger)",
|
||||
href: "https://litellm-api.up.railway.app/",
|
||||
},
|
||||
"proxy/enterprise",
|
||||
"proxy/user_keys",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/demo",
|
||||
"proxy/configs",
|
||||
"proxy/reliability",
|
||||
"proxy/users",
|
||||
"proxy/user_keys",
|
||||
"proxy/enterprise",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/team_based_routing",
|
||||
"proxy/ui",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/token_auth",
|
||||
{
|
||||
type: "category",
|
||||
label: "🔥 Load Balancing",
|
||||
items: ["proxy/load_balancing", "proxy/reliability"],
|
||||
label: "Extra Load Balancing",
|
||||
items: ["proxy/load_balancing"],
|
||||
},
|
||||
"proxy/model_management",
|
||||
"proxy/health",
|
||||
|
@ -61,7 +63,7 @@ const sidebars = {
|
|||
label: "Logging, Alerting",
|
||||
items: ["proxy/logging", "proxy/alerting", "proxy/streaming_logging"],
|
||||
},
|
||||
"proxy/grafana_metrics",
|
||||
"proxy/prometheus",
|
||||
"proxy/call_hooks",
|
||||
"proxy/rules",
|
||||
"proxy/cli",
|
||||
|
@ -84,6 +86,7 @@ const sidebars = {
|
|||
"completion/stream",
|
||||
"completion/message_trimming",
|
||||
"completion/function_call",
|
||||
"completion/vision",
|
||||
"completion/model_alias",
|
||||
"completion/batching",
|
||||
"completion/mock_requests",
|
||||
|
@ -113,6 +116,7 @@ const sidebars = {
|
|||
},
|
||||
items: [
|
||||
"providers/openai",
|
||||
"providers/text_completion_openai",
|
||||
"providers/openai_compatible",
|
||||
"providers/azure",
|
||||
"providers/azure_ai",
|
||||
|
@ -162,7 +166,6 @@ const sidebars = {
|
|||
"debugging/local_debugging",
|
||||
"observability/callbacks",
|
||||
"observability/custom_callback",
|
||||
"observability/lunary_integration",
|
||||
"observability/langfuse_integration",
|
||||
"observability/sentry",
|
||||
"observability/promptlayer_integration",
|
||||
|
@ -171,6 +174,8 @@ const sidebars = {
|
|||
"observability/slack_integration",
|
||||
"observability/traceloop_integration",
|
||||
"observability/athina_integration",
|
||||
"observability/lunary_integration",
|
||||
"observability/athina_integration",
|
||||
"observability/helicone_integration",
|
||||
"observability/supabase_integration",
|
||||
`observability/telemetry`,
|
||||
|
|
|
@ -16,7 +16,7 @@ However, we also expose 3 public helper functions to calculate token usage acros
|
|||
```python
|
||||
from litellm import token_counter
|
||||
|
||||
messages = [{"user": "role", "content": "Hey, how's it going"}]
|
||||
messages = [{"role": "user", "content": "Hey, how's it going"}]
|
||||
print(token_counter(model="gpt-3.5-turbo", messages=messages))
|
||||
```
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
|
||||
if self.llm_guard_mode == "key-specific":
|
||||
# check if llm guard enabled for specific keys only
|
||||
self.print_verbose(
|
||||
|
@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
return True
|
||||
elif self.llm_guard_mode == "all":
|
||||
return True
|
||||
elif self.llm_guard_mode == "request-specific":
|
||||
self.print_verbose(f"received metadata: {data.get('metadata', {})}")
|
||||
metadata = data.get("metadata", {})
|
||||
permissions = metadata.get("permissions", {})
|
||||
if (
|
||||
"enable_llm_guard_check" in permissions
|
||||
and permissions["enable_llm_guard_check"] == True
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def async_moderation_hook(
|
||||
|
@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
|
||||
)
|
||||
|
||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
|
||||
if _proceed == False:
|
||||
return
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Enterprise Proxy Util Endpoints
|
||||
from litellm._logging import verbose_logger
|
||||
import collections
|
||||
|
||||
|
||||
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||
|
@ -17,6 +18,48 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
|||
return response
|
||||
|
||||
|
||||
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||
response = await prisma_client.db.query_raw(
|
||||
"""
|
||||
SELECT
|
||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||
DATE(s."startTime") AS spend_date,
|
||||
COUNT(*) AS log_count,
|
||||
SUM(spend) AS total_spend
|
||||
FROM "LiteLLM_SpendLogs" s
|
||||
WHERE s."startTime" >= current_date - interval '30 days'
|
||||
GROUP BY individual_request_tag, spend_date
|
||||
ORDER BY spend_date;
|
||||
"""
|
||||
)
|
||||
|
||||
# print("tags - spend")
|
||||
# print(response)
|
||||
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
|
||||
total_spend_per_tag = collections.defaultdict(float)
|
||||
total_requests_per_tag = collections.defaultdict(int)
|
||||
for row in response:
|
||||
tag_name = row["individual_request_tag"]
|
||||
tag_spend = row["total_spend"]
|
||||
|
||||
total_spend_per_tag[tag_name] += tag_spend
|
||||
total_requests_per_tag[tag_name] += row["log_count"]
|
||||
|
||||
sorted_tags = sorted(total_spend_per_tag.items(), key=lambda x: x[1], reverse=True)
|
||||
# convert to ui format
|
||||
ui_tags = []
|
||||
for tag in sorted_tags:
|
||||
ui_tags.append(
|
||||
{
|
||||
"name": tag[0],
|
||||
"value": tag[1],
|
||||
"log_count": total_requests_per_tag[tag[0]],
|
||||
}
|
||||
)
|
||||
|
||||
return {"top_10_tags": ui_tags}
|
||||
|
||||
|
||||
async def view_spend_logs_from_clickhouse(
|
||||
api_key=None, user_id=None, request_id=None, start_date=None, end_date=None
|
||||
):
|
||||
|
|
8
litellm-js/spend-logs/package-lock.json
generated
|
@ -6,7 +6,7 @@
|
|||
"": {
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.9.0",
|
||||
"hono": "^4.1.5"
|
||||
"hono": "^4.2.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.17",
|
||||
|
@ -463,9 +463,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.1.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.1.5.tgz",
|
||||
"integrity": "sha512-3ChJiIoeCxvkt6vnkxJagplrt1YZg3NyNob7ssVeK2PUqEINp4q1F94HzFnvY9QE8asVmbW5kkTDlyWylfg2vg==",
|
||||
"version": "4.2.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
|
||||
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
},
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.9.0",
|
||||
"hono": "^4.1.5"
|
||||
"hono": "^4.2.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.17",
|
||||
|
|
|
@ -3,7 +3,11 @@ import threading, requests, os
|
|||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
|
||||
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
|
||||
from litellm.proxy._types import (
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
LiteLLM_UpperboundKeyGenerateParams,
|
||||
)
|
||||
import httpx
|
||||
import dotenv
|
||||
|
||||
|
@ -12,10 +16,24 @@ dotenv.load_dotenv()
|
|||
if set_verbose == True:
|
||||
_turn_on_debug()
|
||||
#############################################
|
||||
### Callbacks /Logging / Success / Failure Handlers ###
|
||||
input_callback: List[Union[str, Callable]] = []
|
||||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
service_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_langfuse_default_tags: Optional[
|
||||
List[
|
||||
Literal[
|
||||
"user_api_key_alias",
|
||||
"user_api_key_user_id",
|
||||
"user_api_key_user_email",
|
||||
"user_api_key_team_alias",
|
||||
"semantic-similarity",
|
||||
"proxy_base_url",
|
||||
]
|
||||
]
|
||||
] = None
|
||||
_async_input_callback: List[Callable] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
|
@ -27,6 +45,8 @@ _async_failure_callback: List[Callable] = (
|
|||
) # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
## end of callbacks #############
|
||||
|
||||
email: Optional[str] = (
|
||||
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
|
@ -46,6 +66,7 @@ replicate_key: Optional[str] = None
|
|||
cohere_key: Optional[str] = None
|
||||
maritalk_key: Optional[str] = None
|
||||
ai21_key: Optional[str] = None
|
||||
ollama_key: Optional[str] = None
|
||||
openrouter_key: Optional[str] = None
|
||||
huggingface_key: Optional[str] = None
|
||||
vertex_project: Optional[str] = None
|
||||
|
@ -56,6 +77,7 @@ baseten_key: Optional[str] = None
|
|||
aleph_alpha_key: Optional[str] = None
|
||||
nlp_cloud_key: Optional[str] = None
|
||||
use_client: bool = False
|
||||
ssl_verify: bool = True
|
||||
disable_streaming_logging: bool = False
|
||||
### GUARDRAILS ###
|
||||
llamaguard_model_name: Optional[str] = None
|
||||
|
@ -64,7 +86,7 @@ google_moderation_confidence_threshold: Optional[float] = None
|
|||
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||
blocked_user_list: Optional[Union[str, List]] = None
|
||||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
llm_guard_mode: Literal["all", "key-specific"] = "all"
|
||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
|
@ -76,6 +98,8 @@ caching_with_models: bool = (
|
|||
cache: Optional[Cache] = (
|
||||
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
)
|
||||
default_in_memory_ttl: Optional[float] = None
|
||||
default_redis_ttl: Optional[float] = None
|
||||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
|
@ -170,7 +194,7 @@ dynamodb_table_name: Optional[str] = None
|
|||
s3_callback_params: Optional[Dict] = None
|
||||
generic_logger_headers: Optional[Dict] = None
|
||||
default_key_generate_params: Optional[Dict] = None
|
||||
upperbound_key_generate_params: Optional[Dict] = None
|
||||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||
default_user_params: Optional[Dict] = None
|
||||
default_team_settings: Optional[List] = None
|
||||
max_user_budget: Optional[float] = None
|
||||
|
@ -258,6 +282,7 @@ open_ai_chat_completion_models: List = []
|
|||
open_ai_text_completion_models: List = []
|
||||
cohere_models: List = []
|
||||
cohere_chat_models: List = []
|
||||
mistral_chat_models: List = []
|
||||
anthropic_models: List = []
|
||||
openrouter_models: List = []
|
||||
vertex_language_models: List = []
|
||||
|
@ -267,6 +292,7 @@ vertex_code_chat_models: List = []
|
|||
vertex_text_models: List = []
|
||||
vertex_code_text_models: List = []
|
||||
vertex_embedding_models: List = []
|
||||
vertex_anthropic_models: List = []
|
||||
ai21_models: List = []
|
||||
nlp_cloud_models: List = []
|
||||
aleph_alpha_models: List = []
|
||||
|
@ -282,6 +308,8 @@ for key, value in model_cost.items():
|
|||
cohere_models.append(key)
|
||||
elif value.get("litellm_provider") == "cohere_chat":
|
||||
cohere_chat_models.append(key)
|
||||
elif value.get("litellm_provider") == "mistral":
|
||||
mistral_chat_models.append(key)
|
||||
elif value.get("litellm_provider") == "anthropic":
|
||||
anthropic_models.append(key)
|
||||
elif value.get("litellm_provider") == "openrouter":
|
||||
|
@ -300,6 +328,9 @@ for key, value in model_cost.items():
|
|||
vertex_code_chat_models.append(key)
|
||||
elif value.get("litellm_provider") == "vertex_ai-embedding-models":
|
||||
vertex_embedding_models.append(key)
|
||||
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
|
||||
key = key.replace("vertex_ai/", "")
|
||||
vertex_anthropic_models.append(key)
|
||||
elif value.get("litellm_provider") == "ai21":
|
||||
ai21_models.append(key)
|
||||
elif value.get("litellm_provider") == "nlp_cloud":
|
||||
|
@ -346,7 +377,7 @@ replicate_models: List = [
|
|||
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
|
||||
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
|
||||
# Flan T-5
|
||||
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f"
|
||||
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
|
||||
# Others
|
||||
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
|
||||
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
||||
|
@ -447,6 +478,7 @@ model_list = (
|
|||
+ deepinfra_models
|
||||
+ perplexity_models
|
||||
+ maritalk_models
|
||||
+ vertex_language_models
|
||||
)
|
||||
|
||||
provider_list: List = [
|
||||
|
@ -568,6 +600,7 @@ from .utils import (
|
|||
completion_cost,
|
||||
supports_function_calling,
|
||||
supports_parallel_function_calling,
|
||||
supports_vision,
|
||||
get_litellm_params,
|
||||
Logging,
|
||||
acreate,
|
||||
|
@ -585,6 +618,7 @@ from .utils import (
|
|||
_should_retry,
|
||||
get_secret,
|
||||
get_supported_openai_params,
|
||||
get_api_base,
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
|
@ -600,6 +634,7 @@ from .llms.nlp_cloud import NLPCloudConfig
|
|||
from .llms.aleph_alpha import AlephAlphaConfig
|
||||
from .llms.petals import PetalsConfig
|
||||
from .llms.vertex_ai import VertexAIConfig
|
||||
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
||||
from .llms.sagemaker import SagemakerConfig
|
||||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
|
|
|
@ -32,6 +32,25 @@ def _get_redis_kwargs():
|
|||
return available_args
|
||||
|
||||
|
||||
def _get_redis_url_kwargs(client=None):
|
||||
if client is None:
|
||||
client = redis.Redis.from_url
|
||||
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {
|
||||
"self",
|
||||
"connection_pool",
|
||||
"retry",
|
||||
}
|
||||
|
||||
include_args = ["url"]
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_env_kwarg_mapping():
|
||||
PREFIX = "REDIS_"
|
||||
|
||||
|
@ -91,27 +110,39 @@ def _get_redis_client_logic(**env_overrides):
|
|||
redis_kwargs.pop("password", None)
|
||||
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis_kwargs
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop(
|
||||
"connection_pool", None
|
||||
) # redis.from_url doesn't support setting your own connection pool
|
||||
return redis.Redis.from_url(**redis_kwargs)
|
||||
args = _get_redis_url_kwargs()
|
||||
url_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
url_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
return redis.Redis.from_url(**url_kwargs)
|
||||
return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
def get_redis_async_client(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop(
|
||||
"connection_pool", None
|
||||
) # redis.from_url doesn't support setting your own connection pool
|
||||
return async_redis.Redis.from_url(**redis_kwargs)
|
||||
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
||||
url_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
url_kwargs[arg] = redis_kwargs[arg]
|
||||
else:
|
||||
litellm.print_verbose(
|
||||
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
|
||||
arg
|
||||
)
|
||||
)
|
||||
return async_redis.Redis.from_url(**url_kwargs)
|
||||
|
||||
return async_redis.Redis(
|
||||
socket_timeout=5,
|
||||
**redis_kwargs,
|
||||
|
@ -124,4 +155,9 @@ def get_redis_connection_pool(**env_overrides):
|
|||
return async_redis.BlockingConnectionPool.from_url(
|
||||
timeout=5, url=redis_kwargs["url"]
|
||||
)
|
||||
connection_class = async_redis.Connection
|
||||
if "ssl" in redis_kwargs and redis_kwargs["ssl"] is not None:
|
||||
connection_class = async_redis.SSLConnection
|
||||
redis_kwargs.pop("ssl", None)
|
||||
redis_kwargs["connection_class"] = connection_class
|
||||
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
|
||||
|
|
130
litellm/_service_logger.py
Normal file
|
@ -0,0 +1,130 @@
|
|||
import litellm, traceback
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from .types.services import ServiceTypes, ServiceLoggerPayload
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
from .integrations.custom_logger import CustomLogger
|
||||
from datetime import timedelta
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ServiceLogging(CustomLogger):
|
||||
"""
|
||||
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
|
||||
"""
|
||||
|
||||
def __init__(self, mock_testing: bool = False) -> None:
|
||||
self.mock_testing = mock_testing
|
||||
self.mock_testing_sync_success_hook = 0
|
||||
self.mock_testing_async_success_hook = 0
|
||||
self.mock_testing_sync_failure_hook = 0
|
||||
self.mock_testing_async_failure_hook = 0
|
||||
if "prometheus_system" in litellm.service_callback:
|
||||
self.prometheusServicesLogger = PrometheusServicesLogger()
|
||||
|
||||
def service_success_hook(
|
||||
self, service: ServiceTypes, duration: float, call_type: str
|
||||
):
|
||||
"""
|
||||
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_sync_success_hook += 1
|
||||
|
||||
def service_failure_hook(
|
||||
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
|
||||
):
|
||||
"""
|
||||
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_sync_failure_hook += 1
|
||||
|
||||
async def async_service_success_hook(
|
||||
self, service: ServiceTypes, duration: float, call_type: str
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is successful
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_async_success_hook += 1
|
||||
|
||||
payload = ServiceLoggerPayload(
|
||||
is_error=False,
|
||||
error=None,
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
)
|
||||
for callback in litellm.service_callback:
|
||||
if callback == "prometheus_system":
|
||||
await self.prometheusServicesLogger.async_service_success_hook(
|
||||
payload=payload
|
||||
)
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
service: ServiceTypes,
|
||||
duration: float,
|
||||
error: Union[str, Exception],
|
||||
call_type: str,
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is unsuccessful
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_async_failure_hook += 1
|
||||
|
||||
error_message = ""
|
||||
if isinstance(error, Exception):
|
||||
error_message = str(error)
|
||||
elif isinstance(error, str):
|
||||
error_message = error
|
||||
|
||||
payload = ServiceLoggerPayload(
|
||||
is_error=True,
|
||||
error=error_message,
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
)
|
||||
for callback in litellm.service_callback:
|
||||
if callback == "prometheus_system":
|
||||
if self.prometheusServicesLogger is None:
|
||||
self.prometheusServicesLogger = self.prometheusServicesLogger()
|
||||
await self.prometheusServicesLogger.async_service_failure_hook(
|
||||
payload=payload
|
||||
)
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
"""
|
||||
Hook to track failed litellm-service calls
|
||||
"""
|
||||
return await super().async_post_call_failure_hook(
|
||||
original_exception, user_api_key_dict
|
||||
)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Hook to track latency for litellm proxy llm api calls
|
||||
"""
|
||||
try:
|
||||
_duration = end_time - start_time
|
||||
if isinstance(_duration, timedelta):
|
||||
_duration = _duration.total_seconds()
|
||||
elif isinstance(_duration, float):
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
"Duration={} is not a float or timedelta object. type={}".format(
|
||||
_duration, type(_duration)
|
||||
)
|
||||
) # invalid _duration value
|
||||
await self.async_service_success_hook(
|
||||
service=ServiceTypes.LITELLM,
|
||||
duration=_duration,
|
||||
call_type=kwargs["call_type"],
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
|
@ -13,6 +13,7 @@ import json, traceback, ast, hashlib
|
|||
from typing import Optional, Literal, List, Union, Any, BinaryIO
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -81,9 +82,37 @@ class InMemoryCache(BaseCache):
|
|||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.cache_dict.clear()
|
||||
self.ttl_dict.clear()
|
||||
|
@ -109,6 +138,8 @@ class RedisCache(BaseCache):
|
|||
**kwargs,
|
||||
):
|
||||
from ._redis import get_redis_client, get_redis_connection_pool
|
||||
from litellm._service_logger import ServiceLogging
|
||||
import redis
|
||||
|
||||
redis_kwargs = {}
|
||||
if host is not None:
|
||||
|
@ -118,10 +149,19 @@ class RedisCache(BaseCache):
|
|||
if password is not None:
|
||||
redis_kwargs["password"] = password
|
||||
|
||||
### HEALTH MONITORING OBJECT ###
|
||||
if kwargs.get("service_logger_obj", None) is not None and isinstance(
|
||||
kwargs["service_logger_obj"], ServiceLogging
|
||||
):
|
||||
self.service_logger_obj = kwargs.pop("service_logger_obj")
|
||||
else:
|
||||
self.service_logger_obj = ServiceLogging()
|
||||
|
||||
redis_kwargs.update(kwargs)
|
||||
self.redis_client = get_redis_client(**redis_kwargs)
|
||||
self.redis_kwargs = redis_kwargs
|
||||
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
|
||||
|
||||
# redis namespaces
|
||||
self.namespace = namespace
|
||||
# for high traffic, we store the redis results in memory and then batch write to redis
|
||||
|
@ -133,6 +173,16 @@ class RedisCache(BaseCache):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
### ASYNC HEALTH PING ###
|
||||
try:
|
||||
# asyncio.get_running_loop().create_task(self.ping())
|
||||
result = asyncio.get_running_loop().create_task(self.ping())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
### SYNC HEALTH PING ###
|
||||
self.redis_client.ping()
|
||||
|
||||
def init_async_client(self):
|
||||
from ._redis import get_redis_async_client
|
||||
|
||||
|
@ -163,18 +213,101 @@ class RedisCache(BaseCache):
|
|||
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
_redis_client = self.redis_client
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = _redis_client.incr(name=key, amount=value)
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="increment_cache",
|
||||
)
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="increment_cache",
|
||||
)
|
||||
)
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
|
||||
keys = []
|
||||
_redis_client = self.init_async_client()
|
||||
async with _redis_client as redis_client:
|
||||
async for key in redis_client.scan_iter(match=pattern + "*", count=count):
|
||||
keys.append(key)
|
||||
if len(keys) >= count:
|
||||
break
|
||||
return keys
|
||||
start_time = time.time()
|
||||
try:
|
||||
keys = []
|
||||
_redis_client = self.init_async_client()
|
||||
async with _redis_client as redis_client:
|
||||
async for key in redis_client.scan_iter(
|
||||
match=pattern + "*", count=count
|
||||
):
|
||||
keys.append(key)
|
||||
if len(keys) >= count:
|
||||
break
|
||||
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_scan_iter",
|
||||
)
|
||||
) # DO NOT SLOW DOWN CALL B/C OF THIS
|
||||
return keys
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_scan_iter",
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
try:
|
||||
_redis_client = self.init_async_client()
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS, duration=_duration, error=e
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
ttl = kwargs.get("ttl", None)
|
||||
|
@ -186,7 +319,26 @@ class RedisCache(BaseCache):
|
|||
print_verbose(
|
||||
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_set_cache",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_set_cache",
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
||||
|
@ -200,6 +352,11 @@ class RedisCache(BaseCache):
|
|||
Use Redis Pipelines for bulk write operations
|
||||
"""
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
|
||||
print_verbose(
|
||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||
)
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
|
@ -219,8 +376,30 @@ class RedisCache(BaseCache):
|
|||
|
||||
print_verbose(f"pipeline results: {results}")
|
||||
# Optionally, you could process 'results' to make sure that all set operations were successful.
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_set_cache_pipeline",
|
||||
)
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_set_cache_pipeline",
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
|
@ -235,7 +414,44 @@ class RedisCache(BaseCache):
|
|||
key = self.check_and_fix_namespace(key=key)
|
||||
self.redis_batch_writing_buffer.append((key, value))
|
||||
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
||||
await self.flush_cache_buffer()
|
||||
await self.flush_cache_buffer() # logging done in here
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
result = await redis_client.incr(name=key, amount=value)
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_increment",
|
||||
)
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_increment",
|
||||
)
|
||||
)
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
value,
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def flush_cache_buffer(self):
|
||||
print_verbose(
|
||||
|
@ -274,40 +490,17 @@ class RedisCache(BaseCache):
|
|||
traceback.print_exc()
|
||||
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
_redis_client = self.init_async_client()
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
try:
|
||||
print_verbose(f"Get Async Redis Cache: key: {key}")
|
||||
cached_response = await redis_client.get(key)
|
||||
print_verbose(
|
||||
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
|
||||
)
|
||||
response = self._get_cache_logic(cached_response=cached_response)
|
||||
return response
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
print_verbose(
|
||||
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
|
||||
)
|
||||
|
||||
async def async_get_cache_pipeline(self, key_list) -> dict:
|
||||
def batch_get_cache(self, key_list) -> dict:
|
||||
"""
|
||||
Use Redis for bulk read operations
|
||||
"""
|
||||
_redis_client = await self.init_async_client()
|
||||
key_value_dict = {}
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
# Queue the get operations in the pipeline for all keys.
|
||||
for cache_key in key_list:
|
||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||
pipe.get(cache_key) # Queue GET command in pipeline
|
||||
|
||||
# Execute the pipeline and await the results.
|
||||
results = await pipe.execute()
|
||||
_keys = []
|
||||
for cache_key in key_list:
|
||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||
_keys.append(cache_key)
|
||||
results = self.redis_client.mget(keys=_keys)
|
||||
|
||||
# Associate the results back with their keys.
|
||||
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
||||
|
@ -323,21 +516,185 @@ class RedisCache(BaseCache):
|
|||
print_verbose(f"Error occurred in pipeline read - {str(e)}")
|
||||
return key_value_dict
|
||||
|
||||
async def ping(self):
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
_redis_client = self.init_async_client()
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
try:
|
||||
print_verbose(f"Get Async Redis Cache: key: {key}")
|
||||
cached_response = await redis_client.get(key)
|
||||
print_verbose(
|
||||
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
|
||||
)
|
||||
response = self._get_cache_logic(cached_response=cached_response)
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_get_cache",
|
||||
)
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_get_cache",
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
print_verbose(
|
||||
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
|
||||
)
|
||||
|
||||
async def async_batch_get_cache(self, key_list) -> dict:
|
||||
"""
|
||||
Use Redis for bulk read operations
|
||||
"""
|
||||
_redis_client = await self.init_async_client()
|
||||
key_value_dict = {}
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
_keys = []
|
||||
for cache_key in key_list:
|
||||
cache_key = self.check_and_fix_namespace(key=cache_key)
|
||||
_keys.append(cache_key)
|
||||
results = await redis_client.mget(keys=_keys)
|
||||
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_batch_get_cache",
|
||||
)
|
||||
)
|
||||
|
||||
# Associate the results back with their keys.
|
||||
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
||||
key_value_dict = dict(zip(key_list, results))
|
||||
|
||||
decoded_results = {}
|
||||
for k, v in key_value_dict.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode("utf-8")
|
||||
v = self._get_cache_logic(v)
|
||||
decoded_results[k] = v
|
||||
|
||||
return decoded_results
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_batch_get_cache",
|
||||
)
|
||||
)
|
||||
print_verbose(f"Error occurred in pipeline read - {str(e)}")
|
||||
return key_value_dict
|
||||
|
||||
def sync_ping(self) -> bool:
|
||||
"""
|
||||
Tests if the sync redis client is correctly setup.
|
||||
"""
|
||||
print_verbose(f"Pinging Sync Redis Cache")
|
||||
start_time = time.time()
|
||||
try:
|
||||
response = self.redis_client.ping()
|
||||
print_verbose(f"Redis Cache PING: {response}")
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
self.service_logger_obj.service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="sync_ping",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
self.service_logger_obj.service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="sync_ping",
|
||||
)
|
||||
print_verbose(
|
||||
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def ping(self) -> bool:
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
print_verbose(f"Pinging Async Redis Cache")
|
||||
try:
|
||||
response = await redis_client.ping()
|
||||
print_verbose(f"Redis Cache PING: {response}")
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_ping",
|
||||
)
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_ping",
|
||||
)
|
||||
)
|
||||
print_verbose(
|
||||
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
_redis_client = self.init_async_client()
|
||||
# keys is a list, unpack it so it gets passed as individual elements to delete
|
||||
async with _redis_client as redis_client:
|
||||
await redis_client.delete(*keys)
|
||||
|
||||
def client_list(self):
|
||||
client_list = self.redis_client.client_list()
|
||||
return client_list
|
||||
|
||||
def info(self):
|
||||
info = self.redis_client.info()
|
||||
return info
|
||||
|
||||
def flush_cache(self):
|
||||
self.redis_client.flushall()
|
||||
|
||||
|
@ -828,8 +1185,10 @@ class DualCache(BaseCache):
|
|||
# If redis_cache is not provided, use the default RedisCache
|
||||
self.redis_cache = redis_cache
|
||||
|
||||
self.default_in_memory_ttl = default_in_memory_ttl
|
||||
self.default_redis_ttl = default_redis_ttl
|
||||
self.default_in_memory_ttl = (
|
||||
default_in_memory_ttl or litellm.default_in_memory_ttl
|
||||
)
|
||||
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
|
||||
|
||||
def set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
# Update both Redis and in-memory cache
|
||||
|
@ -846,6 +1205,30 @@ class DualCache(BaseCache):
|
|||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
||||
def increment_cache(
|
||||
self, key, value: int, local_only: bool = False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - int - the value you want to increment by
|
||||
|
||||
Returns - int - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: int = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
result = self.redis_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
|
@ -872,6 +1255,39 @@ class DualCache(BaseCache):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
|
||||
try:
|
||||
result = [None for _ in range(len(keys))]
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
|
||||
|
||||
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if None in result and self.redis_cache is not None and local_only == False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
"""
|
||||
sublist_keys = [
|
||||
key for key, value in zip(keys, result) if value is None
|
||||
]
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
for key in redis_result:
|
||||
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
|
||||
|
||||
for key, value in redis_result.items():
|
||||
result[keys.index(key)] = value
|
||||
|
||||
print_verbose(f"async batch get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
|
@ -905,7 +1321,50 @@ class DualCache(BaseCache):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_batch_get_cache(
|
||||
self, keys: list, local_only: bool = False, **kwargs
|
||||
):
|
||||
try:
|
||||
result = [None for _ in range(len(keys))]
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
|
||||
keys, **kwargs
|
||||
)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
if None in result and self.redis_cache is not None and local_only == False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
"""
|
||||
sublist_keys = [
|
||||
key for key, value in zip(keys, result) if value is None
|
||||
]
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||
sublist_keys, **kwargs
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
for key, value in redis_result.items():
|
||||
if value is not None:
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, redis_result[key], **kwargs
|
||||
)
|
||||
for key, value in redis_result.items():
|
||||
index = keys.index(key)
|
||||
result[index] = value
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
print_verbose(
|
||||
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
@ -916,6 +1375,32 @@ class DualCache(BaseCache):
|
|||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_increment_cache(
|
||||
self, key, value: int, local_only: bool = False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - int - the value you want to increment by
|
||||
|
||||
Returns - int - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: int = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = await self.in_memory_cache.async_increment(
|
||||
key, value, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
result = await self.redis_cache.async_increment(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def flush_cache(self):
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.flush_cache()
|
||||
|
@ -939,6 +1424,8 @@ class Cache:
|
|||
password: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
supported_call_types: Optional[
|
||||
List[
|
||||
|
@ -1038,6 +1525,14 @@ class Cache:
|
|||
self.redis_flush_size = redis_flush_size
|
||||
self.ttl = ttl
|
||||
|
||||
if self.type == "local" and default_in_memory_ttl is not None:
|
||||
self.ttl = default_in_memory_ttl
|
||||
|
||||
if (
|
||||
self.type == "redis" or self.type == "redis-semantic"
|
||||
) and default_in_redis_ttl is not None:
|
||||
self.ttl = default_in_redis_ttl
|
||||
|
||||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
|
@ -1379,6 +1874,11 @@ class Cache:
|
|||
return await self.cache.ping()
|
||||
return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
if hasattr(self.cache, "delete_cache_keys"):
|
||||
return await self.cache.delete_cache_keys(keys)
|
||||
return None
|
||||
|
||||
async def disconnect(self):
|
||||
if hasattr(self.cache, "disconnect"):
|
||||
await self.cache.disconnect()
|
||||
|
|
|
@ -82,14 +82,18 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
|||
|
||||
class Timeout(APITimeoutError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 408
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(
|
||||
request=request
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
self.status_code = 408
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
# custom function to convert to str
|
||||
def __str__(self):
|
||||
return str(self.message)
|
||||
|
||||
|
||||
class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
||||
|
|
|
@ -6,7 +6,7 @@ import requests
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Union, Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### PRE-CALL CHECKS - router/proxy only ####
|
||||
"""
|
||||
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
|
||||
"""
|
||||
|
||||
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
|
||||
pass
|
||||
|
||||
def pre_call_check(self, deployment: dict) -> Optional[dict]:
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
"""
|
||||
Control the modify incoming / outgoung data before calling the model
|
||||
|
|
51
litellm/integrations/greenscale.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import requests
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class GreenscaleLogger:
|
||||
def __init__(self):
|
||||
import os
|
||||
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
||||
self.headers = {
|
||||
"api-key": self.greenscale_api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
try:
|
||||
response_json = response_obj.model_dump() if response_obj else {}
|
||||
data = {
|
||||
"modelId": kwargs.get("model"),
|
||||
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
||||
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
|
||||
}
|
||||
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
if type(end_time) == datetime and type(start_time) == datetime:
|
||||
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
|
||||
|
||||
|
||||
# Add additional metadata keys to tags
|
||||
tags = []
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
for key, value in metadata.items():
|
||||
if key.startswith("greenscale"):
|
||||
if key == "greenscale_project":
|
||||
data["project"] = value
|
||||
elif key == "greenscale_application":
|
||||
data["application"] = value
|
||||
else:
|
||||
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
|
||||
|
||||
data["tags"] = tags
|
||||
|
||||
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
||||
if response.status_code != 200:
|
||||
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
|
||||
else:
|
||||
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
||||
except Exception as e:
|
||||
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
||||
pass
|
|
@ -17,7 +17,7 @@ class LangFuseLogger:
|
|||
from langfuse import Langfuse
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
|
||||
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
|
||||
)
|
||||
# Instance variables
|
||||
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
|
||||
|
@ -34,6 +34,14 @@ class LangFuseLogger:
|
|||
flush_interval=1, # flush interval in seconds
|
||||
)
|
||||
|
||||
# set the current langfuse project id in the environ
|
||||
# this is used by Alerting to link to the correct project
|
||||
try:
|
||||
project_id = self.Langfuse.client.projects.get().data[0].id
|
||||
os.environ["LANGFUSE_PROJECT_ID"] = project_id
|
||||
except:
|
||||
project_id = None
|
||||
|
||||
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
|
||||
self.upstream_langfuse_secret_key = os.getenv(
|
||||
"UPSTREAM_LANGFUSE_SECRET_KEY"
|
||||
|
@ -76,6 +84,7 @@ class LangFuseLogger:
|
|||
print_verbose(
|
||||
f"Langfuse Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
|
@ -118,6 +127,11 @@ class LangFuseLogger:
|
|||
):
|
||||
input = prompt
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
input = prompt
|
||||
output = response_obj.choices[0].text
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
|
@ -128,6 +142,7 @@ class LangFuseLogger:
|
|||
self._log_langfuse_v2(
|
||||
user_id,
|
||||
metadata,
|
||||
litellm_params,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
|
@ -156,7 +171,7 @@ class LangFuseLogger:
|
|||
verbose_logger.info(f"Langfuse Layer Logging - logging success")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
async def _async_log_event(
|
||||
|
@ -185,7 +200,7 @@ class LangFuseLogger:
|
|||
):
|
||||
from langfuse.model import CreateTrace, CreateGeneration
|
||||
|
||||
print(
|
||||
verbose_logger.warning(
|
||||
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
|
||||
)
|
||||
|
||||
|
@ -219,6 +234,7 @@ class LangFuseLogger:
|
|||
self,
|
||||
user_id,
|
||||
metadata,
|
||||
litellm_params,
|
||||
output,
|
||||
start_time,
|
||||
end_time,
|
||||
|
@ -273,13 +289,13 @@ class LangFuseLogger:
|
|||
clean_metadata = {}
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
# generate langfuse tags
|
||||
if key in [
|
||||
"user_api_key",
|
||||
"user_api_key_user_id",
|
||||
"user_api_key_team_id",
|
||||
"semantic-similarity",
|
||||
]:
|
||||
|
||||
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
|
||||
if (
|
||||
litellm._langfuse_default_tags is not None
|
||||
and isinstance(litellm._langfuse_default_tags, list)
|
||||
and key in litellm._langfuse_default_tags
|
||||
):
|
||||
tags.append(f"{key}:{value}")
|
||||
|
||||
# clean litellm metadata before logging
|
||||
|
@ -293,13 +309,55 @@ class LangFuseLogger:
|
|||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
if (
|
||||
litellm._langfuse_default_tags is not None
|
||||
and isinstance(litellm._langfuse_default_tags, list)
|
||||
and "proxy_base_url" in litellm._langfuse_default_tags
|
||||
):
|
||||
proxy_base_url = os.environ.get("PROXY_BASE_URL", None)
|
||||
if proxy_base_url is not None:
|
||||
tags.append(f"proxy_base_url:{proxy_base_url}")
|
||||
|
||||
api_base = litellm_params.get("api_base", None)
|
||||
if api_base:
|
||||
clean_metadata["api_base"] = api_base
|
||||
|
||||
vertex_location = kwargs.get("vertex_location", None)
|
||||
if vertex_location:
|
||||
clean_metadata["vertex_location"] = vertex_location
|
||||
|
||||
aws_region_name = kwargs.get("aws_region_name", None)
|
||||
if aws_region_name:
|
||||
clean_metadata["aws_region_name"] = aws_region_name
|
||||
|
||||
if supports_tags:
|
||||
if "cache_hit" in kwargs:
|
||||
if kwargs["cache_hit"] is None:
|
||||
kwargs["cache_hit"] = False
|
||||
tags.append(f"cache_hit:{kwargs['cache_hit']}")
|
||||
clean_metadata["cache_hit"] = kwargs["cache_hit"]
|
||||
trace_params.update({"tags": tags})
|
||||
|
||||
proxy_server_request = litellm_params.get("proxy_server_request", None)
|
||||
if proxy_server_request:
|
||||
method = proxy_server_request.get("method", None)
|
||||
url = proxy_server_request.get("url", None)
|
||||
headers = proxy_server_request.get("headers", None)
|
||||
clean_headers = {}
|
||||
if headers:
|
||||
for key, value in headers.items():
|
||||
# these headers can leak our API keys and/or JWT tokens
|
||||
if key.lower() not in ["authorization", "cookie", "referer"]:
|
||||
clean_headers[key] = value
|
||||
|
||||
clean_metadata["request"] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": clean_headers,
|
||||
}
|
||||
|
||||
print_verbose(f"trace_params: {trace_params}")
|
||||
|
||||
trace = self.Langfuse.trace(**trace_params)
|
||||
|
||||
generation_id = None
|
||||
|
@ -316,13 +374,21 @@ class LangFuseLogger:
|
|||
# just log `litellm-{call_type}` as the generation name
|
||||
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
||||
if response_obj is not None and "system_fingerprint" in response_obj:
|
||||
system_fingerprint = response_obj.get("system_fingerprint", None)
|
||||
else:
|
||||
system_fingerprint = None
|
||||
|
||||
if system_fingerprint is not None:
|
||||
optional_params["system_fingerprint"] = system_fingerprint
|
||||
|
||||
generation_params = {
|
||||
"name": generation_name,
|
||||
"id": metadata.get("generation_id", generation_id),
|
||||
"startTime": start_time,
|
||||
"endTime": end_time,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"model": kwargs["model"],
|
||||
"modelParameters": optional_params,
|
||||
"model_parameters": optional_params,
|
||||
"input": input,
|
||||
"output": output,
|
||||
"usage": usage,
|
||||
|
@ -334,13 +400,15 @@ class LangFuseLogger:
|
|||
generation_params["prompt"] = metadata.get("prompt", None)
|
||||
|
||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||
generation_params["statusMessage"] = output
|
||||
generation_params["status_message"] = output
|
||||
|
||||
if supports_completion_start_time:
|
||||
generation_params["completion_start_time"] = kwargs.get(
|
||||
"completion_start_time", None
|
||||
)
|
||||
|
||||
print_verbose(f"generation_params: {generation_params}")
|
||||
|
||||
trace.generation(**generation_params)
|
||||
except Exception as e:
|
||||
print(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
|
|
|
@ -7,6 +7,19 @@ from datetime import datetime
|
|||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import asyncio
|
||||
import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
non_serializable_types = (
|
||||
types.CoroutineType,
|
||||
types.FunctionType,
|
||||
types.GeneratorType,
|
||||
BaseModel,
|
||||
)
|
||||
return not isinstance(value, non_serializable_types)
|
||||
|
||||
|
||||
class LangsmithLogger:
|
||||
|
@ -21,7 +34,9 @@ class LangsmithLogger:
|
|||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
# Method definition
|
||||
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
|
||||
metadata = kwargs.get('litellm_params', {}).get("metadata", {}) or {} # if metadata is None
|
||||
metadata = (
|
||||
kwargs.get("litellm_params", {}).get("metadata", {}) or {}
|
||||
) # if metadata is None
|
||||
|
||||
# set project name and run_name for langsmith logging
|
||||
# users can pass project_name and run name to litellm.completion()
|
||||
|
@ -51,24 +66,46 @@ class LangsmithLogger:
|
|||
new_kwargs = {}
|
||||
for key in kwargs:
|
||||
value = kwargs[key]
|
||||
if key == "start_time" or key == "end_time":
|
||||
if key == "start_time" or key == "end_time" or value is None:
|
||||
pass
|
||||
elif type(value) != dict:
|
||||
elif type(value) == datetime.datetime:
|
||||
new_kwargs[key] = value.isoformat()
|
||||
elif type(value) != dict and is_serializable(value=value):
|
||||
new_kwargs[key] = value
|
||||
|
||||
requests.post(
|
||||
print(f"type of response: {type(response_obj)}")
|
||||
for k, v in new_kwargs.items():
|
||||
print(f"key={k}, type of arg: {type(v)}, value={v}")
|
||||
|
||||
if isinstance(response_obj, BaseModel):
|
||||
try:
|
||||
response_obj = response_obj.model_dump()
|
||||
except:
|
||||
response_obj = response_obj.dict() # type: ignore
|
||||
|
||||
print(f"response_obj: {response_obj}")
|
||||
|
||||
data = {
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": new_kwargs,
|
||||
"outputs": response_obj,
|
||||
"session_name": project_name,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
print(f"data: {data}")
|
||||
|
||||
response = requests.post(
|
||||
"https://api.smith.langchain.com/runs",
|
||||
json={
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": {**new_kwargs},
|
||||
"outputs": response_obj.json(),
|
||||
"session_name": project_name,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
json=data,
|
||||
headers={"x-api-key": self.langsmith_api_key},
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
print_verbose(f"Error: {response.status_code}")
|
||||
else:
|
||||
print_verbose("Run successfully created")
|
||||
print_verbose(
|
||||
f"Langsmith Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
|
|
|
@ -4,11 +4,13 @@ from datetime import datetime, timezone
|
|||
import traceback
|
||||
import dotenv
|
||||
import importlib
|
||||
from pkg_resources import parse_version
|
||||
import sys
|
||||
|
||||
import packaging
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
return {
|
||||
|
@ -16,6 +18,7 @@ def parse_usage(usage):
|
|||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
if input is None:
|
||||
return None
|
||||
|
@ -28,7 +31,6 @@ def parse_messages(input):
|
|||
if "message" in message:
|
||||
return clean_message(message["message"])
|
||||
|
||||
|
||||
serialized = {
|
||||
"role": message.get("role"),
|
||||
"content": message.get("content"),
|
||||
|
@ -56,10 +58,13 @@ class LunaryLogger:
|
|||
def __init__(self):
|
||||
try:
|
||||
import lunary
|
||||
|
||||
version = importlib.metadata.version("lunary")
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if parse_version(version) < parse_version("0.1.43"):
|
||||
print("Lunary version outdated. Required: > 0.1.43. Upgrade via 'pip install lunary --upgrade'")
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||
print(
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
raise ImportError
|
||||
|
||||
self.lunary_client = lunary
|
||||
|
@ -88,9 +93,7 @@ class LunaryLogger:
|
|||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
)
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
|
||||
tags = litellm_params.pop("tags", None) or []
|
||||
|
||||
|
@ -148,7 +151,7 @@ class LunaryLogger:
|
|||
runtime="litellm",
|
||||
error=error_obj,
|
||||
output=parse_messages(output),
|
||||
token_usage=usage
|
||||
token_usage=usage,
|
||||
)
|
||||
|
||||
except:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# used for /metrics endpoint on LiteLLM Proxy
|
||||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
# On success, log events to Prometheus
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
@ -19,27 +19,33 @@ class PrometheusLogger:
|
|||
**kwargs,
|
||||
):
|
||||
try:
|
||||
verbose_logger.debug(f"in init prometheus metrics")
|
||||
print(f"in init prometheus metrics")
|
||||
from prometheus_client import Counter
|
||||
|
||||
self.litellm_llm_api_failed_requests_metric = Counter(
|
||||
name="litellm_llm_api_failed_requests_metric",
|
||||
documentation="Total number of failed LLM API calls via litellm",
|
||||
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
|
||||
)
|
||||
|
||||
self.litellm_requests_metric = Counter(
|
||||
name="litellm_requests_metric",
|
||||
documentation="Total number of LLM calls to litellm",
|
||||
labelnames=["user", "key", "model"],
|
||||
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
|
||||
)
|
||||
|
||||
# Counter for spend
|
||||
self.litellm_spend_metric = Counter(
|
||||
"litellm_spend_metric",
|
||||
"Total spend on LLM requests",
|
||||
labelnames=["user", "key", "model"],
|
||||
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
|
||||
)
|
||||
|
||||
# Counter for total_output_tokens
|
||||
self.litellm_tokens_metric = Counter(
|
||||
"litellm_total_tokens",
|
||||
"Total number of input + output tokens from LLM requests",
|
||||
labelnames=["user", "key", "model"],
|
||||
labelnames=["end_user", "hashed_api_key", "model", "team", "user"],
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||
|
@ -61,24 +67,50 @@ class PrometheusLogger:
|
|||
|
||||
# unpack kwargs
|
||||
model = kwargs.get("model", "")
|
||||
response_cost = kwargs.get("response_cost", 0.0)
|
||||
response_cost = kwargs.get("response_cost", 0.0) or 0
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = litellm_params.get("metadata", {}).get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
user_api_key = litellm_params.get("metadata", {}).get("user_api_key", None)
|
||||
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
|
||||
user_api_team = litellm_params.get("metadata", {}).get(
|
||||
"user_api_key_team_id", None
|
||||
)
|
||||
if response_obj is not None:
|
||||
tokens_used = response_obj.get("usage", {}).get("total_tokens", 0)
|
||||
else:
|
||||
tokens_used = 0
|
||||
|
||||
print_verbose(
|
||||
f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
|
||||
)
|
||||
|
||||
self.litellm_requests_metric.labels(end_user_id, user_api_key, model).inc()
|
||||
self.litellm_spend_metric.labels(end_user_id, user_api_key, model).inc(
|
||||
response_cost
|
||||
)
|
||||
self.litellm_tokens_metric.labels(end_user_id, user_api_key, model).inc(
|
||||
tokens_used
|
||||
)
|
||||
if (
|
||||
user_api_key is not None
|
||||
and isinstance(user_api_key, str)
|
||||
and user_api_key.startswith("sk-")
|
||||
):
|
||||
from litellm.proxy.utils import hash_token
|
||||
|
||||
user_api_key = hash_token(user_api_key)
|
||||
|
||||
self.litellm_requests_metric.labels(
|
||||
end_user_id, user_api_key, model, user_api_team, user_id
|
||||
).inc()
|
||||
self.litellm_spend_metric.labels(
|
||||
end_user_id, user_api_key, model, user_api_team, user_id
|
||||
).inc(response_cost)
|
||||
self.litellm_tokens_metric.labels(
|
||||
end_user_id, user_api_key, model, user_api_team, user_id
|
||||
).inc(tokens_used)
|
||||
|
||||
### FAILURE INCREMENT ###
|
||||
if "exception" in kwargs:
|
||||
self.litellm_llm_api_failed_requests_metric.labels(
|
||||
end_user_id, user_api_key, model, user_api_team, user_id
|
||||
).inc()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
verbose_logger.debug(
|
||||
|
|
198
litellm/integrations/prometheus_services.py
Normal file
|
@ -0,0 +1,198 @@
|
|||
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
|
||||
#### What this does ####
|
||||
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
||||
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
|
||||
class PrometheusServicesLogger:
|
||||
# Class variables or attributes
|
||||
litellm_service_latency = None # Class-level attribute to store the Histogram
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mock_testing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
try:
|
||||
from prometheus_client import Counter, Histogram, REGISTRY
|
||||
except ImportError:
|
||||
raise Exception(
|
||||
"Missing prometheus_client. Run `pip install prometheus-client`"
|
||||
)
|
||||
|
||||
self.Histogram = Histogram
|
||||
self.Counter = Counter
|
||||
self.REGISTRY = REGISTRY
|
||||
|
||||
verbose_logger.debug(f"in init prometheus services metrics")
|
||||
|
||||
self.services = [item.value for item in ServiceTypes]
|
||||
|
||||
self.payload_to_prometheus_map = (
|
||||
{}
|
||||
) # store the prometheus histogram/counter we need to call for each field in payload
|
||||
|
||||
for service in self.services:
|
||||
histogram = self.create_histogram(service, type_of_request="latency")
|
||||
counter_failed_request = self.create_counter(
|
||||
service, type_of_request="failed_requests"
|
||||
)
|
||||
counter_total_requests = self.create_counter(
|
||||
service, type_of_request="total_requests"
|
||||
)
|
||||
self.payload_to_prometheus_map[service] = [
|
||||
histogram,
|
||||
counter_failed_request,
|
||||
counter_total_requests,
|
||||
]
|
||||
|
||||
self.prometheus_to_amount_map: dict = (
|
||||
{}
|
||||
) # the field / value in ServiceLoggerPayload the object needs to be incremented by
|
||||
|
||||
### MOCK TESTING ###
|
||||
self.mock_testing = mock_testing
|
||||
self.mock_testing_success_calls = 0
|
||||
self.mock_testing_failure_calls = 0
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||
raise e
|
||||
|
||||
def is_metric_registered(self, metric_name) -> bool:
|
||||
for metric in self.REGISTRY.collect():
|
||||
if metric_name == metric.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_metric(self, metric_name):
|
||||
for metric in self.REGISTRY.collect():
|
||||
for sample in metric.samples:
|
||||
if metric_name == sample.name:
|
||||
return metric
|
||||
return None
|
||||
|
||||
def create_histogram(self, service: str, type_of_request: str):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
return self.get_metric(metric_name)
|
||||
return self.Histogram(
|
||||
metric_name,
|
||||
"Latency for {} service".format(service),
|
||||
labelnames=[service],
|
||||
)
|
||||
|
||||
def create_counter(self, service: str, type_of_request: str):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
return self.get_metric(metric_name)
|
||||
return self.Counter(
|
||||
metric_name,
|
||||
"Total {} for {} service".format(type_of_request, service),
|
||||
labelnames=[service],
|
||||
)
|
||||
|
||||
def observe_histogram(
|
||||
self,
|
||||
histogram,
|
||||
labels: str,
|
||||
amount: float,
|
||||
):
|
||||
assert isinstance(histogram, self.Histogram)
|
||||
|
||||
histogram.labels(labels).observe(amount)
|
||||
|
||||
def increment_counter(
|
||||
self,
|
||||
counter,
|
||||
labels: str,
|
||||
amount: float,
|
||||
):
|
||||
assert isinstance(counter, self.Counter)
|
||||
|
||||
counter.labels(labels).inc(amount)
|
||||
|
||||
def service_success_hook(self, payload: ServiceLoggerPayload):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_success_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Histogram):
|
||||
self.observe_histogram(
|
||||
histogram=obj,
|
||||
labels=payload.service.value,
|
||||
amount=payload.duration,
|
||||
)
|
||||
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
def service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Counter):
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
async def async_service_success_hook(self, payload: ServiceLoggerPayload):
|
||||
"""
|
||||
Log successful call to prometheus
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_success_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Histogram):
|
||||
self.observe_histogram(
|
||||
histogram=obj,
|
||||
labels=payload.service.value,
|
||||
amount=payload.duration,
|
||||
)
|
||||
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||
print(f"received error payload: {payload.error}")
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Counter):
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG ERROR COUNT TO PROMETHEUS
|
||||
)
|
486
litellm/integrations/slack_alerting.py
Normal file
|
@ -0,0 +1,486 @@
|
|||
#### What this does ####
|
||||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import copy
|
||||
import traceback
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
from litellm.caching import DualCache
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
|
||||
class SlackAlerting:
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
alerting_threshold: float = 300,
|
||||
alerting: Optional[List] = [],
|
||||
alert_types: Optional[
|
||||
List[
|
||||
Literal[
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
]
|
||||
]
|
||||
] = [
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
],
|
||||
alert_to_webhook_url: Optional[
|
||||
Dict
|
||||
] = None, # if user wants to separate alerts to diff channels
|
||||
):
|
||||
self.alerting_threshold = alerting_threshold
|
||||
self.alerting = alerting
|
||||
self.alert_types = alert_types
|
||||
self.internal_usage_cache = DualCache()
|
||||
self.async_http_handler = AsyncHTTPHandler()
|
||||
self.alert_to_webhook_url = alert_to_webhook_url
|
||||
|
||||
pass
|
||||
|
||||
def update_values(
|
||||
self,
|
||||
alerting: Optional[List] = None,
|
||||
alerting_threshold: Optional[float] = None,
|
||||
alert_types: Optional[List] = None,
|
||||
alert_to_webhook_url: Optional[Dict] = None,
|
||||
):
|
||||
if alerting is not None:
|
||||
self.alerting = alerting
|
||||
if alerting_threshold is not None:
|
||||
self.alerting_threshold = alerting_threshold
|
||||
if alert_types is not None:
|
||||
self.alert_types = alert_types
|
||||
|
||||
if alert_to_webhook_url is not None:
|
||||
# update the dict
|
||||
if self.alert_to_webhook_url is None:
|
||||
self.alert_to_webhook_url = alert_to_webhook_url
|
||||
else:
|
||||
self.alert_to_webhook_url.update(alert_to_webhook_url)
|
||||
|
||||
async def deployment_in_cooldown(self):
|
||||
pass
|
||||
|
||||
async def deployment_removed_from_cooldown(self):
|
||||
pass
|
||||
|
||||
def _all_possible_alert_types(self):
|
||||
# used by the UI to show all supported alert types
|
||||
# Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select
|
||||
return [
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
]
|
||||
|
||||
def _add_langfuse_trace_id_to_alert(
|
||||
self,
|
||||
request_info: str,
|
||||
request_data: Optional[dict] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
):
|
||||
import uuid
|
||||
|
||||
# For now: do nothing as we're debugging why this is not working as expected
|
||||
return request_info
|
||||
|
||||
# if request_data is not None:
|
||||
# trace_id = request_data.get("metadata", {}).get(
|
||||
# "trace_id", None
|
||||
# ) # get langfuse trace id
|
||||
# if trace_id is None:
|
||||
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
|
||||
# request_data["metadata"]["trace_id"] = trace_id
|
||||
# elif kwargs is not None:
|
||||
# _litellm_params = kwargs.get("litellm_params", {})
|
||||
# trace_id = _litellm_params.get("metadata", {}).get(
|
||||
# "trace_id", None
|
||||
# ) # get langfuse trace id
|
||||
# if trace_id is None:
|
||||
# trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
|
||||
# _litellm_params["metadata"]["trace_id"] = trace_id
|
||||
|
||||
# _langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
|
||||
# _langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
|
||||
|
||||
# # langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
|
||||
|
||||
# _langfuse_url = (
|
||||
# f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
|
||||
# )
|
||||
# request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
|
||||
# return request_info
|
||||
|
||||
def _response_taking_too_long_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
start_time,
|
||||
end_time, # start/end time
|
||||
):
|
||||
try:
|
||||
time_difference = end_time - start_time
|
||||
# Convert the timedelta to float (in seconds)
|
||||
time_difference_float = time_difference.total_seconds()
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
model = kwargs.get("model", "")
|
||||
api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
|
||||
messages = kwargs.get("messages", None)
|
||||
# if messages does not exist fallback to "input"
|
||||
if messages is None:
|
||||
messages = kwargs.get("input", None)
|
||||
|
||||
# only use first 100 chars for alerting
|
||||
_messages = str(messages)[:100]
|
||||
|
||||
return time_difference_float, model, api_base, _messages
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _get_deployment_latencies_to_alert(self, metadata=None):
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
if "_latency_per_deployment" in metadata:
|
||||
# Translate model_id to -> api_base
|
||||
# _latency_per_deployment is a dictionary that looks like this:
|
||||
"""
|
||||
_latency_per_deployment: {
|
||||
api_base: 0.01336697916666667
|
||||
}
|
||||
"""
|
||||
_message_to_send = ""
|
||||
_deployment_latencies = metadata["_latency_per_deployment"]
|
||||
if len(_deployment_latencies) == 0:
|
||||
return None
|
||||
for api_base, latency in _deployment_latencies.items():
|
||||
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
|
||||
_message_to_send = "```" + _message_to_send + "```"
|
||||
return _message_to_send
|
||||
|
||||
async def response_taking_too_long_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response, # response from completion
|
||||
start_time,
|
||||
end_time, # start/end time
|
||||
):
|
||||
if self.alerting is None or self.alert_types is None:
|
||||
return
|
||||
|
||||
time_difference_float, model, api_base, messages = (
|
||||
self._response_taking_too_long_callback(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
||||
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
||||
if time_difference_float > self.alerting_threshold:
|
||||
if "langfuse" in litellm.success_callback:
|
||||
request_info = self._add_langfuse_trace_id_to_alert(
|
||||
request_info=request_info, kwargs=kwargs
|
||||
)
|
||||
# add deployment latencies to alert
|
||||
if (
|
||||
kwargs is not None
|
||||
and "litellm_params" in kwargs
|
||||
and "metadata" in kwargs["litellm_params"]
|
||||
):
|
||||
_metadata = kwargs["litellm_params"]["metadata"]
|
||||
|
||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||
metadata=_metadata
|
||||
)
|
||||
if _deployment_latency_map is not None:
|
||||
request_info += (
|
||||
f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
|
||||
)
|
||||
await self.send_alert(
|
||||
message=slow_message + request_info,
|
||||
level="Low",
|
||||
alert_type="llm_too_slow",
|
||||
)
|
||||
|
||||
async def log_failure_event(self, original_exception: Exception):
|
||||
pass
|
||||
|
||||
async def response_taking_too_long(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
type: Literal["hanging_request", "slow_response"] = "hanging_request",
|
||||
request_data: Optional[dict] = None,
|
||||
):
|
||||
if self.alerting is None or self.alert_types is None:
|
||||
return
|
||||
if request_data is not None:
|
||||
model = request_data.get("model", "")
|
||||
messages = request_data.get("messages", None)
|
||||
if messages is None:
|
||||
# if messages does not exist fallback to "input"
|
||||
messages = request_data.get("input", None)
|
||||
|
||||
# try casting messages to str and get the first 100 characters, else mark as None
|
||||
try:
|
||||
messages = str(messages)
|
||||
messages = messages[:100]
|
||||
except:
|
||||
messages = ""
|
||||
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
||||
if "langfuse" in litellm.success_callback:
|
||||
request_info = self._add_langfuse_trace_id_to_alert(
|
||||
request_info=request_info, request_data=request_data
|
||||
)
|
||||
else:
|
||||
request_info = ""
|
||||
|
||||
if type == "hanging_request":
|
||||
await asyncio.sleep(
|
||||
self.alerting_threshold
|
||||
) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
|
||||
if (
|
||||
request_data is not None
|
||||
and request_data.get("litellm_status", "") != "success"
|
||||
and request_data.get("litellm_status", "") != "fail"
|
||||
):
|
||||
if request_data.get("deployment", None) is not None and isinstance(
|
||||
request_data["deployment"], dict
|
||||
):
|
||||
_api_base = litellm.get_api_base(
|
||||
model=model,
|
||||
optional_params=request_data["deployment"].get(
|
||||
"litellm_params", {}
|
||||
),
|
||||
)
|
||||
|
||||
if _api_base is None:
|
||||
_api_base = ""
|
||||
|
||||
request_info += f"\nAPI Base: {_api_base}"
|
||||
elif request_data.get("metadata", None) is not None and isinstance(
|
||||
request_data["metadata"], dict
|
||||
):
|
||||
# In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
|
||||
# in that case we fallback to the api base set in the request metadata
|
||||
_metadata = request_data["metadata"]
|
||||
_api_base = _metadata.get("api_base", "")
|
||||
if _api_base is None:
|
||||
_api_base = ""
|
||||
request_info += f"\nAPI Base: `{_api_base}`"
|
||||
# only alert hanging responses if they have not been marked as success
|
||||
alerting_message = (
|
||||
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
|
||||
)
|
||||
|
||||
# add deployment latencies to alert
|
||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||
metadata=request_data.get("metadata", {})
|
||||
)
|
||||
if _deployment_latency_map is not None:
|
||||
request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
|
||||
|
||||
await self.send_alert(
|
||||
message=alerting_message + request_info,
|
||||
level="Medium",
|
||||
alert_type="llm_requests_hanging",
|
||||
)
|
||||
|
||||
async def budget_alerts(
|
||||
self,
|
||||
type: Literal[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"user_and_proxy_budget",
|
||||
"failed_budgets",
|
||||
"failed_tracking",
|
||||
"projected_limit_exceeded",
|
||||
],
|
||||
user_max_budget: float,
|
||||
user_current_spend: float,
|
||||
user_info=None,
|
||||
error_message="",
|
||||
):
|
||||
if self.alerting is None or self.alert_types is None:
|
||||
# do nothing if alerting is not switched on
|
||||
return
|
||||
if "budget_alerts" not in self.alert_types:
|
||||
return
|
||||
_id: str = "default_id" # used for caching
|
||||
if type == "user_and_proxy_budget":
|
||||
user_info = dict(user_info)
|
||||
user_id = user_info["user_id"]
|
||||
_id = user_id
|
||||
max_budget = user_info["max_budget"]
|
||||
spend = user_info["spend"]
|
||||
user_email = user_info["user_email"]
|
||||
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
|
||||
elif type == "token_budget":
|
||||
token_info = dict(user_info)
|
||||
token = token_info["token"]
|
||||
_id = token
|
||||
spend = token_info["spend"]
|
||||
max_budget = token_info["max_budget"]
|
||||
user_id = token_info["user_id"]
|
||||
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
|
||||
elif type == "failed_tracking":
|
||||
user_id = str(user_info)
|
||||
_id = user_id
|
||||
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
|
||||
message = "Failed Tracking Cost for" + user_info
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
return
|
||||
elif type == "projected_limit_exceeded" and user_info is not None:
|
||||
"""
|
||||
Input variables:
|
||||
user_info = {
|
||||
"key_alias": key_alias,
|
||||
"projected_spend": projected_spend,
|
||||
"projected_exceeded_date": projected_exceeded_date,
|
||||
}
|
||||
user_max_budget=soft_limit,
|
||||
user_current_spend=new_spend
|
||||
"""
|
||||
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
return
|
||||
else:
|
||||
user_info = str(user_info)
|
||||
|
||||
# percent of max_budget left to spend
|
||||
if user_max_budget > 0:
|
||||
percent_left = (user_max_budget - user_current_spend) / user_max_budget
|
||||
else:
|
||||
percent_left = 0
|
||||
verbose_proxy_logger.debug(
|
||||
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
|
||||
)
|
||||
|
||||
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
|
||||
# - Alert once within 28d period
|
||||
# - Cache this information
|
||||
# - Don't re-alert, if alert already sent
|
||||
_cache: DualCache = self.internal_usage_cache
|
||||
|
||||
# check if crossed budget
|
||||
if user_current_spend >= user_max_budget:
|
||||
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
|
||||
message = "Budget Crossed for" + user_info
|
||||
result = await _cache.async_get_cache(key=message)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
|
||||
return
|
||||
|
||||
# check if 5% of max budget is left
|
||||
if percent_left <= 0.05:
|
||||
message = "5% budget left for" + user_info
|
||||
cache_key = "alerting:{}".format(_id)
|
||||
result = await _cache.async_get_cache(key=cache_key)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="Medium", alert_type="budget_alerts"
|
||||
)
|
||||
|
||||
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
|
||||
|
||||
return
|
||||
|
||||
# check if 15% of max budget is left
|
||||
if percent_left <= 0.15:
|
||||
message = "15% budget left for" + user_info
|
||||
result = await _cache.async_get_cache(key=message)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="Low", alert_type="budget_alerts"
|
||||
)
|
||||
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
async def send_alert(
|
||||
self,
|
||||
message: str,
|
||||
level: Literal["Low", "Medium", "High"],
|
||||
alert_type: Literal[
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
||||
|
||||
- Responses taking too long
|
||||
- Requests are hanging
|
||||
- Calls are failing
|
||||
- DB Read/Writes are failing
|
||||
- Proxy Close to max budget
|
||||
- Key Close to max budget
|
||||
|
||||
Parameters:
|
||||
level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
|
||||
message: str - what is the alert about
|
||||
"""
|
||||
if self.alerting is None:
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
# Get the current timestamp
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
||||
formatted_message = (
|
||||
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
||||
)
|
||||
if _proxy_base_url is not None:
|
||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||
|
||||
# check if we find the slack webhook url in self.alert_to_webhook_url
|
||||
if (
|
||||
self.alert_to_webhook_url is not None
|
||||
and alert_type in self.alert_to_webhook_url
|
||||
):
|
||||
slack_webhook_url = self.alert_to_webhook_url[alert_type]
|
||||
else:
|
||||
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
||||
|
||||
if slack_webhook_url is None:
|
||||
raise Exception("Missing SLACK_WEBHOOK_URL from environment")
|
||||
payload = {"text": formatted_message}
|
||||
headers = {"Content-type": "application/json"}
|
||||
|
||||
response = await self.async_http_handler.post(
|
||||
url=slack_webhook_url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
if response.status_code == 200:
|
||||
pass
|
||||
else:
|
||||
print("Error sending slack alert. Error=", response.text) # noqa
|
|
@ -298,7 +298,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -2,18 +2,13 @@ import os, types
|
|||
import json
|
||||
from enum import Enum
|
||||
import requests, copy
|
||||
import time, uuid
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import (
|
||||
contains_tag,
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
construct_tool_use_system_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
)
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx
|
||||
|
||||
|
||||
|
@ -21,6 +16,8 @@ class AnthropicConstants(Enum):
|
|||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
|
||||
|
||||
|
||||
class AnthropicError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
@ -37,12 +34,14 @@ class AnthropicError(Exception):
|
|||
|
||||
class AnthropicConfig:
|
||||
"""
|
||||
Reference: https://docs.anthropic.com/claude/reference/complete_post
|
||||
Reference: https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default
|
||||
max_tokens: Optional[int] = (
|
||||
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
||||
)
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
|
@ -52,7 +51,9 @@ class AnthropicConfig:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = 256, # anthropic requires a default
|
||||
max_tokens: Optional[
|
||||
int
|
||||
] = 4096, # You can pass in a value yourself or use the default value 4096
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
|
@ -101,124 +102,23 @@ def validate_environment(api_key, user_headers):
|
|||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
_is_function_call = False
|
||||
json_schemas: dict = {}
|
||||
messages = copy.deepcopy(messages)
|
||||
optional_params = copy.deepcopy(optional_params)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_indices = []
|
||||
system_prompt = ""
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
system_prompt += message["content"]
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
if len(system_prompt) > 0:
|
||||
optional_params["system"] = system_prompt
|
||||
# Format rest of message according to anthropic guidelines
|
||||
try:
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AnthropicError(status_code=400, message=str(e))
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
for tool in optional_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||
"parameters", None
|
||||
)
|
||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||
tools=optional_params["tools"]
|
||||
)
|
||||
optional_params["system"] = (
|
||||
optional_params.get("system", "\n") + tool_calling_system_prompt
|
||||
) # add the anthropic tool calling prompt to the system prompt
|
||||
optional_params.pop("tools")
|
||||
|
||||
stream = optional_params.pop("stream", None)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
print_verbose(f"_is_function_call: {_is_function_call}")
|
||||
## COMPLETION CALL
|
||||
if (
|
||||
stream is not None and stream == True and _is_function_call == False
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
print_verbose(f"makes anthropic streaming POST request")
|
||||
data["stream"] = stream
|
||||
response = requests.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
return response.iter_lines()
|
||||
else:
|
||||
response = requests.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
class AnthropicChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model,
|
||||
response,
|
||||
model_response,
|
||||
_is_function_call,
|
||||
stream,
|
||||
logging_obj,
|
||||
api_key,
|
||||
data,
|
||||
messages,
|
||||
print_verbose,
|
||||
):
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
|
@ -245,46 +145,40 @@ def completion(
|
|||
status_code=response.status_code,
|
||||
)
|
||||
else:
|
||||
text_content = completion_response["content"][0].get("text", None)
|
||||
## TOOL CALLING - OUTPUT PARSE
|
||||
if text_content is not None and contains_tag("invoke", text_content):
|
||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||
function_arguments_str = extract_between_tags("invoke", text_content)[
|
||||
0
|
||||
].strip()
|
||||
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
|
||||
function_arguments = parse_xml_params(
|
||||
function_arguments_str,
|
||||
json_schema=json_schemas.get(
|
||||
function_name, None
|
||||
), # check if we have a json schema for this function name
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
for content in completion_response["content"]:
|
||||
if content["type"] == "text":
|
||||
text_content += content["text"]
|
||||
## TOOL CALLING
|
||||
elif content["type"] == "tool_use":
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"id": content["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
"name": content["name"],
|
||||
"arguments": json.dumps(content["input"]),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = (
|
||||
text_content # allow user to access raw anthropic tool calling response
|
||||
)
|
||||
else:
|
||||
model_response.choices[0].message.content = text_content # type: ignore
|
||||
)
|
||||
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=text_content or None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = completion_response[
|
||||
"content"
|
||||
] # allow user to access raw anthropic tool calling response
|
||||
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
|
||||
if _is_function_call == True and stream is not None and stream == True:
|
||||
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
if _is_function_call and stream:
|
||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[
|
||||
|
@ -318,7 +212,7 @@ def completion(
|
|||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
|
@ -337,11 +231,278 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
data["stream"] = True
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data), stream=True
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.aiter_lines()
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def acompletion_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
_is_function_call = False
|
||||
messages = copy.deepcopy(messages)
|
||||
optional_params = copy.deepcopy(optional_params)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_indices = []
|
||||
system_prompt = ""
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
system_prompt += message["content"]
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
if len(system_prompt) > 0:
|
||||
optional_params["system"] = system_prompt
|
||||
# Format rest of message according to anthropic guidelines
|
||||
try:
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AnthropicError(status_code=400, message=str(e))
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
headers["anthropic-beta"] = "tools-2024-04-04"
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in optional_params["tools"]:
|
||||
new_tool = tool["function"]
|
||||
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
|
||||
anthropic_tools.append(new_tool)
|
||||
|
||||
optional_params["tools"] = anthropic_tools
|
||||
|
||||
stream = optional_params.pop("stream", None)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
print_verbose(f"_is_function_call: {_is_function_call}")
|
||||
if acompletion == True:
|
||||
if (
|
||||
stream and not _is_function_call
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
print_verbose("makes async anthropic streaming POST request")
|
||||
data["stream"] = stream
|
||||
return self.acompletion_stream_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
_is_function_call=_is_function_call,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
return self.acompletion_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
_is_function_call=_is_function_call,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
## COMPLETION CALL
|
||||
if (
|
||||
stream and not _is_function_call
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
print_verbose("makes anthropic streaming POST request")
|
||||
data["stream"] = stream
|
||||
response = requests.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.iter_lines()
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
else:
|
||||
response = requests.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
|
||||
def embedding(self):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
|
@ -367,8 +528,3 @@ class ModelResponseIterator:
|
|||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -4,10 +4,12 @@ from enum import Enum
|
|||
import requests
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
from .base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
|
@ -94,91 +96,13 @@ def validate_environment(api_key, user_headers):
|
|||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
class AnthropicTextCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicTextConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = requests.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"],
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
return response.iter_lines()
|
||||
else:
|
||||
response = requests.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
def process_response(
|
||||
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
||||
):
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
|
@ -213,10 +137,208 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_response: ModelResponse,
|
||||
api_base: str,
|
||||
logging_obj,
|
||||
encoding,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
client=None,
|
||||
):
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=data["prompt"],
|
||||
api_key=headers.get("x-api-key"),
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = self.process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
prompt=data["prompt"],
|
||||
model=model,
|
||||
)
|
||||
return response
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
api_base: str,
|
||||
logging_obj,
|
||||
headers: dict,
|
||||
data: Optional[dict],
|
||||
client=None,
|
||||
):
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.aiter_lines()
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
acompletion: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client=None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicTextConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if acompletion == True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
headers=headers,
|
||||
data=data,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if client is None:
|
||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
# stream=optional_params["stream"],
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
completion_stream = response.iter_lines()
|
||||
stream_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return stream_response
|
||||
elif acompletion == True:
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
encoding=encoding,
|
||||
headers=headers,
|
||||
data=data,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
if client is None:
|
||||
client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
response = client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
response = self.process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
prompt=data["prompt"],
|
||||
model=model,
|
||||
)
|
||||
return response
|
||||
|
||||
def embedding(self):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -799,6 +799,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
optional_params: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
|
@ -817,8 +818,6 @@ class AzureChatCompletion(BaseLLM):
|
|||
"timeout": timeout,
|
||||
}
|
||||
|
||||
max_retries = optional_params.pop("max_retries", None)
|
||||
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from litellm.utils import (
|
|||
CustomStreamWrapper,
|
||||
convert_to_model_response_object,
|
||||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
from typing import Callable, Optional, BinaryIO
|
||||
from litellm import OpenAIConfig
|
||||
|
@ -15,11 +16,11 @@ import litellm, json
|
|||
import httpx
|
||||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||
from ..llms.openai import OpenAITextCompletion
|
||||
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
|
||||
import uuid
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
openai_text_completion = OpenAITextCompletion()
|
||||
openai_text_completion_config = OpenAITextCompletionConfig()
|
||||
|
||||
|
||||
class AzureOpenAIError(Exception):
|
||||
|
@ -300,9 +301,11 @@ class AzureTextCompletion(BaseLLM):
|
|||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
return openai_text_completion.convert_to_model_response_object(
|
||||
response_object=stringified_response,
|
||||
model_response_object=model_response,
|
||||
return (
|
||||
openai_text_completion_config.convert_to_chat_model_response_object(
|
||||
response_object=TextCompletionResponse(**stringified_response),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
|
@ -373,7 +376,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
},
|
||||
)
|
||||
response = await azure_client.completions.create(**data, timeout=timeout)
|
||||
return openai_text_completion.convert_to_model_response_object(
|
||||
return openai_text_completion_config.convert_to_chat_model_response_object(
|
||||
response_object=response.model_dump(),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
|
|
|
@ -55,9 +55,11 @@ def completion(
|
|||
"inputs": prompt,
|
||||
"prompt": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
"stream": (
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
|
@ -71,9 +73,11 @@ def completion(
|
|||
completion_url_fragment_1 + model + completion_url_fragment_2,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
stream=(
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
)
|
||||
if "text/event-stream" in response.headers["Content-Type"] or (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
|
@ -102,28 +106,28 @@ def completion(
|
|||
and "data" in completion_response["model_output"]
|
||||
and isinstance(completion_response["model_output"]["data"], list)
|
||||
):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["model_output"]["data"][0]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["model_output"]["data"][0]
|
||||
)
|
||||
elif isinstance(completion_response["model_output"], str):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["model_output"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["model_output"]
|
||||
)
|
||||
elif "completion" in completion_response and isinstance(
|
||||
completion_response["completion"], str
|
||||
):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["completion"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["completion"]
|
||||
)
|
||||
elif isinstance(completion_response, list) and len(completion_response) > 0:
|
||||
if "generated_text" not in completion_response:
|
||||
raise BasetenError(
|
||||
message=f"Unable to parse response. Original response: {response.text}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response[0]["generated_text"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response[0]["generated_text"]
|
||||
)
|
||||
## GETTING LOGPROBS
|
||||
if (
|
||||
"details" in completion_response[0]
|
||||
|
@ -155,7 +159,8 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -653,6 +653,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
|||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "meta":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
else:
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
|
@ -746,7 +750,7 @@ def completion(
|
|||
]
|
||||
# Format rest of message according to anthropic guidelines
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
model=model, messages=messages, custom_llm_provider="anthropic_xml"
|
||||
)
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicClaude3Config.get_config()
|
||||
|
@ -1008,7 +1012,7 @@ def completion(
|
|||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = model_response_iterator(
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
|
@ -1028,7 +1032,7 @@ def completion(
|
|||
total_tokens=response_body["usage"]["input_tokens"]
|
||||
+ response_body["usage"]["output_tokens"],
|
||||
)
|
||||
model_response.usage = _usage
|
||||
setattr(model_response, "usage", _usage)
|
||||
else:
|
||||
outputText = response_body["completion"]
|
||||
model_response["finish_reason"] = response_body["stop_reason"]
|
||||
|
@ -1071,8 +1075,10 @@ def completion(
|
|||
status_code=response_metadata.get("HTTPStatusCode", 500),
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
if getattr(model_response.usage, "total_tokens", None) is None:
|
||||
## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here.
|
||||
if not hasattr(model_response, "usage"):
|
||||
setattr(model_response, "usage", Usage())
|
||||
if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore
|
||||
prompt_tokens = response_metadata.get(
|
||||
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
|
||||
)
|
||||
|
@ -1089,7 +1095,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
|
@ -1109,8 +1115,30 @@ def completion(
|
|||
raise BedrockError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
|
||||
async def model_response_iterator(model_response):
|
||||
yield model_response
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
|
||||
def _embedding_func_single(
|
||||
|
|
|
@ -167,7 +167,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -237,7 +237,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ class CohereChatConfig:
|
|||
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
|
||||
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
|
||||
seed (int, optional): A seed to assist reproducibility of the model's response.
|
||||
"""
|
||||
|
||||
preamble: Optional[str] = None
|
||||
|
@ -62,6 +63,7 @@ class CohereChatConfig:
|
|||
presence_penalty: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_results: Optional[list] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -82,6 +84,7 @@ class CohereChatConfig:
|
|||
presence_penalty: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
|
@ -302,5 +305,5 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
|
96
litellm/llms/custom_httpx/http_handler.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import httpx, asyncio
|
||||
from typing import Optional, Union, Mapping, Any
|
||||
|
||||
# https://www.python-httpx.org/advanced/timeouts
|
||||
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
|
||||
|
||||
|
||||
class AsyncHTTPHandler:
|
||||
def __init__(
|
||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
||||
):
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
),
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
# Close the client when you're done with it
|
||||
await self.client.aclose()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.client
|
||||
|
||||
async def __aexit__(self):
|
||||
# close the client when exiting
|
||||
await self.client.aclose()
|
||||
|
||||
async def get(
|
||||
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||
):
|
||||
response = await self.client.get(url, params=params, headers=headers)
|
||||
return response
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[Union[dict, str]] = None, # type: ignore
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||
)
|
||||
response = await self.client.send(req, stream=stream)
|
||||
return response
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(self.close())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class HTTPHandler:
|
||||
def __init__(
|
||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
||||
):
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
),
|
||||
)
|
||||
|
||||
def close(self):
|
||||
# Close the client when you're done with it
|
||||
self.client.close()
|
||||
|
||||
def get(
|
||||
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||
):
|
||||
response = self.client.get(url, params=params, headers=headers)
|
||||
return response
|
||||
|
||||
def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
response = self.client.post(url, data=data, params=params, headers=headers)
|
||||
return response
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
|
@ -6,7 +6,8 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
|
||||
import litellm
|
||||
import sys, httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
class GeminiError(Exception):
|
||||
|
@ -103,6 +104,13 @@ class TextStreamer:
|
|||
break
|
||||
|
||||
|
||||
def supports_system_instruction():
|
||||
import google.generativeai as genai
|
||||
|
||||
gemini_pkg_version = Version(genai.__version__)
|
||||
return gemini_pkg_version >= Version("0.5.0")
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -124,7 +132,7 @@ def completion(
|
|||
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
)
|
||||
genai.configure(api_key=api_key)
|
||||
|
||||
system_prompt = ""
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
|
@ -135,6 +143,7 @@ def completion(
|
|||
messages=messages,
|
||||
)
|
||||
else:
|
||||
system_prompt, messages = get_system_prompt(messages=messages)
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="gemini"
|
||||
)
|
||||
|
@ -162,11 +171,20 @@ def completion(
|
|||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": {"inference_params": inference_params}},
|
||||
additional_args={
|
||||
"complete_input_dict": {
|
||||
"inference_params": inference_params,
|
||||
"system_prompt": system_prompt,
|
||||
}
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
_model = genai.GenerativeModel(f"models/{model}")
|
||||
_params = {"model_name": "models/{}".format(model)}
|
||||
_system_instruction = supports_system_instruction()
|
||||
if _system_instruction and len(system_prompt) > 0:
|
||||
_params["system_instruction"] = system_prompt
|
||||
_model = genai.GenerativeModel(**_params)
|
||||
if stream == True:
|
||||
if acompletion == True:
|
||||
|
||||
|
@ -213,11 +231,12 @@ def completion(
|
|||
encoding=encoding,
|
||||
)
|
||||
else:
|
||||
response = _model.generate_content(
|
||||
contents=prompt,
|
||||
generation_config=genai.types.GenerationConfig(**inference_params),
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
params = {
|
||||
"contents": prompt,
|
||||
"generation_config": genai.types.GenerationConfig(**inference_params),
|
||||
"safety_settings": safety_settings,
|
||||
}
|
||||
response = _model.generate_content(**params)
|
||||
except Exception as e:
|
||||
raise GeminiError(
|
||||
message=str(e),
|
||||
|
@ -292,7 +311,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -152,9 +152,9 @@ def completion(
|
|||
else:
|
||||
try:
|
||||
if len(completion_response["answer"]) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["answer"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["answer"]
|
||||
)
|
||||
except Exception as e:
|
||||
raise MaritalkError(
|
||||
message=response.text, status_code=response.status_code
|
||||
|
@ -174,7 +174,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -185,9 +185,9 @@ def completion(
|
|||
else:
|
||||
try:
|
||||
if len(completion_response["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["generated_text"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["generated_text"]
|
||||
)
|
||||
except:
|
||||
raise NLPCloudError(
|
||||
message=json.dumps(completion_response),
|
||||
|
@ -205,7 +205,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class OllamaError(Exception):
|
|||
|
||||
class OllamaConfig:
|
||||
"""
|
||||
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
|
||||
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
|
||||
|
||||
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
|
||||
|
||||
|
@ -69,7 +69,7 @@ class OllamaConfig:
|
|||
repeat_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[list] = (
|
||||
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
|
||||
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
|
||||
)
|
||||
tfs_z: Optional[float] = None
|
||||
num_predict: Optional[int] = None
|
||||
|
@ -228,8 +228,8 @@ def get_ollama_response(
|
|||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + model
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore
|
||||
completion_tokens = response_json["eval_count"]
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
|
@ -330,8 +330,8 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + data["model"]
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore
|
||||
completion_tokens = response_json["eval_count"]
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
|
|
|
@ -20,7 +20,7 @@ class OllamaError(Exception):
|
|||
|
||||
class OllamaChatConfig:
|
||||
"""
|
||||
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
|
||||
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#parameters
|
||||
|
||||
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
|
||||
|
||||
|
@ -69,7 +69,7 @@ class OllamaChatConfig:
|
|||
repeat_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[list] = (
|
||||
None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
|
||||
None # stop is a list based on this - https://github.com/ollama/ollama/pull/442
|
||||
)
|
||||
tfs_z: Optional[float] = None
|
||||
num_predict: Optional[int] = None
|
||||
|
@ -148,7 +148,7 @@ class OllamaChatConfig:
|
|||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["repeat_penalty"] = param
|
||||
optional_params["repeat_penalty"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
if param == "response_format" and value["type"] == "json_object":
|
||||
|
@ -184,6 +184,7 @@ class OllamaChatConfig:
|
|||
# ollama implementation
|
||||
def get_ollama_response(
|
||||
api_base="http://localhost:11434",
|
||||
api_key: Optional[str] = None,
|
||||
model="llama2",
|
||||
messages=None,
|
||||
optional_params=None,
|
||||
|
@ -236,6 +237,7 @@ def get_ollama_response(
|
|||
if stream == True:
|
||||
response = ollama_async_streaming(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
|
@ -244,6 +246,7 @@ def get_ollama_response(
|
|||
else:
|
||||
response = ollama_acompletion(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
|
@ -252,12 +255,17 @@ def get_ollama_response(
|
|||
)
|
||||
return response
|
||||
elif stream == True:
|
||||
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
|
||||
return ollama_completion_stream(
|
||||
url=url, api_key=api_key, data=data, logging_obj=logging_obj
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url=f"{url}",
|
||||
json=data,
|
||||
)
|
||||
_request = {
|
||||
"url": f"{url}",
|
||||
"json": data,
|
||||
}
|
||||
if api_key is not None:
|
||||
_request["headers"] = "Bearer {}".format(api_key)
|
||||
response = requests.post(**_request) # type: ignore
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
|
@ -307,10 +315,16 @@ def get_ollama_response(
|
|||
return model_response
|
||||
|
||||
|
||||
def ollama_completion_stream(url, data, logging_obj):
|
||||
with httpx.stream(
|
||||
url=url, json=data, method="POST", timeout=litellm.request_timeout
|
||||
) as response:
|
||||
def ollama_completion_stream(url, api_key, data, logging_obj):
|
||||
_request = {
|
||||
"url": f"{url}",
|
||||
"json": data,
|
||||
"method": "POST",
|
||||
"timeout": litellm.request_timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
_request["headers"] = "Bearer {}".format(api_key)
|
||||
with httpx.stream(**_request) as response:
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(
|
||||
|
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
|
|||
raise e
|
||||
|
||||
|
||||
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
|
||||
async def ollama_async_streaming(
|
||||
url, api_key, data, model_response, encoding, logging_obj
|
||||
):
|
||||
try:
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream(
|
||||
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
|
||||
) as response:
|
||||
_request = {
|
||||
"url": f"{url}",
|
||||
"json": data,
|
||||
"method": "POST",
|
||||
"timeout": litellm.request_timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
_request["headers"] = "Bearer {}".format(api_key)
|
||||
async with client.stream(**_request) as response:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(
|
||||
status_code=response.status_code, message=response.text
|
||||
|
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
|||
|
||||
|
||||
async def ollama_acompletion(
|
||||
url, data, model_response, encoding, logging_obj, function_name
|
||||
url,
|
||||
api_key: Optional[str],
|
||||
data,
|
||||
model_response,
|
||||
encoding,
|
||||
logging_obj,
|
||||
function_name,
|
||||
):
|
||||
data["stream"] = False
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
resp = await session.post(url, json=data)
|
||||
_request = {
|
||||
"url": f"{url}",
|
||||
"json": data,
|
||||
}
|
||||
if api_key is not None:
|
||||
_request["headers"] = "Bearer {}".format(api_key)
|
||||
resp = await session.post(**_request)
|
||||
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
|
|
|
@ -99,9 +99,9 @@ def completion(
|
|||
)
|
||||
else:
|
||||
try:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["choices"][0]["message"]["content"]
|
||||
model_response["choices"][0]["message"]["content"] = (
|
||||
completion_response["choices"][0]["message"]["content"]
|
||||
)
|
||||
except:
|
||||
raise OobaboogaError(
|
||||
message=json.dumps(completion_response),
|
||||
|
@ -115,7 +115,7 @@ def completion(
|
|||
completion_tokens=completion_response["usage"]["completion_tokens"],
|
||||
total_tokens=completion_response["usage"]["total_tokens"],
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from litellm.utils import (
|
|||
convert_to_model_response_object,
|
||||
Usage,
|
||||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
from typing import Callable, Optional
|
||||
import aiohttp, requests
|
||||
|
@ -200,6 +201,43 @@ class OpenAITextCompletionConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def convert_to_chat_model_response_object(
|
||||
self,
|
||||
response_object: Optional[TextCompletionResponse] = None,
|
||||
model_response_object: Optional[ModelResponse] = None,
|
||||
):
|
||||
try:
|
||||
## RESPONSE OBJECT
|
||||
if response_object is None or model_response_object is None:
|
||||
raise ValueError("Error in response object format")
|
||||
choice_list = []
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(
|
||||
content=choice["text"],
|
||||
role="assistant",
|
||||
)
|
||||
choice = Choices(
|
||||
finish_reason=choice["finish_reason"], index=idx, message=message
|
||||
)
|
||||
choice_list.append(choice)
|
||||
model_response_object.choices = choice_list
|
||||
|
||||
if "usage" in response_object:
|
||||
setattr(model_response_object, "usage", response_object["usage"])
|
||||
|
||||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
if "model" in response_object:
|
||||
model_response_object.model = response_object["model"]
|
||||
|
||||
model_response_object._hidden_params["original_response"] = (
|
||||
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
|
||||
)
|
||||
return model_response_object
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class OpenAIChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
|
@ -785,10 +823,10 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
optional_params: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None,
|
||||
atranscription: bool = False,
|
||||
):
|
||||
|
@ -962,40 +1000,6 @@ class OpenAITextCompletion(BaseLLM):
|
|||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def convert_to_model_response_object(
|
||||
self,
|
||||
response_object: Optional[dict] = None,
|
||||
model_response_object: Optional[ModelResponse] = None,
|
||||
):
|
||||
try:
|
||||
## RESPONSE OBJECT
|
||||
if response_object is None or model_response_object is None:
|
||||
raise ValueError("Error in response object format")
|
||||
choice_list = []
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(content=choice["text"], role="assistant")
|
||||
choice = Choices(
|
||||
finish_reason=choice["finish_reason"], index=idx, message=message
|
||||
)
|
||||
choice_list.append(choice)
|
||||
model_response_object.choices = choice_list
|
||||
|
||||
if "usage" in response_object:
|
||||
model_response_object.usage = response_object["usage"]
|
||||
|
||||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
if "model" in response_object:
|
||||
model_response_object.model = response_object["model"]
|
||||
|
||||
model_response_object._hidden_params["original_response"] = (
|
||||
response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
|
||||
)
|
||||
return model_response_object
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
|
@ -1010,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
|
|||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
client=None,
|
||||
organization: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
super().completion()
|
||||
|
@ -1020,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
|
|||
if model is None or messages is None:
|
||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
api_base = f"{api_base}/completions"
|
||||
|
||||
if (
|
||||
len(messages) > 0
|
||||
and "content" in messages[0]
|
||||
|
@ -1029,12 +1033,12 @@ class OpenAITextCompletion(BaseLLM):
|
|||
):
|
||||
prompt = messages[0]["content"]
|
||||
else:
|
||||
prompt = " ".join([message["content"] for message in messages]) # type: ignore
|
||||
prompt = [message["content"] for message in messages] # type: ignore
|
||||
|
||||
# don't send max retries to the api, if set
|
||||
optional_params.pop("max_retries", None)
|
||||
|
||||
data = {"model": model, "prompt": prompt, **optional_params}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
|
@ -1050,38 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
|
|||
return self.async_streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
|
||||
elif optional_params.get("stream", False):
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
client=client,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
response = httpx.post(
|
||||
url=f"{api_base}", json=data, headers=headers, timeout=timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
if client is None:
|
||||
openai_client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.client_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
response = openai_client.completions.create(**data) # type: ignore
|
||||
|
||||
response_json = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
original_response=response_json,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
|
@ -1089,10 +1108,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(
|
||||
response_object=response.json(),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
return TextCompletionResponse(**response_json)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -1107,101 +1123,112 @@ class OpenAITextCompletion(BaseLLM):
|
|||
api_key: str,
|
||||
model: str,
|
||||
timeout: float,
|
||||
max_retries=None,
|
||||
organization: Optional[str] = None,
|
||||
client=None,
|
||||
):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=litellm.request_timeout,
|
||||
)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
try:
|
||||
if client is None:
|
||||
openai_aclient = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
base_url=api_base,
|
||||
http_client=litellm.aclient_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_aclient = client
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(
|
||||
response_object=response_json, model_response_object=model_response
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
response = await openai_aclient.completions.create(**data)
|
||||
response_json = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
## RESPONSE OBJECT
|
||||
response_obj = TextCompletionResponse(**response_json)
|
||||
response_obj._hidden_params.original_response = json.dumps(response_json)
|
||||
return response_obj
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
timeout: float,
|
||||
api_base: Optional[str] = None,
|
||||
max_retries=None,
|
||||
client=None,
|
||||
organization=None,
|
||||
):
|
||||
with httpx.stream(
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
if client is None:
|
||||
openai_client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.client_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
else:
|
||||
openai_client = client
|
||||
response = openai_client.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
for chunk in streamwrapper:
|
||||
yield chunk
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
timeout: float,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
organization=None,
|
||||
):
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream(
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
if client is None:
|
||||
openai_client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
http_client=litellm.aclient_session,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response.aiter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
raise e
|
||||
response = await openai_client.completions.create(**data)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
|
|
@ -191,7 +191,7 @@ def completion(
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
|
|