Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-28 04:04:31 +00:00)

Merge branch 'BerriAI:main' into ollama-image-handling
This commit is contained in commit ea117fc859.

94 changed files with 10,050 additions and 828 deletions

.circleci/config.yml

@@ -129,6 +129,7 @@ jobs:
   build_and_test:
     machine:
       image: ubuntu-2204:2023.10.1
+    resource_class: xlarge
     working_directory: ~/project
     steps:
       - checkout
@@ -188,6 +189,9 @@ jobs:
             -p 4000:4000 \
             -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
             -e AZURE_API_KEY=$AZURE_API_KEY \
+            -e REDIS_HOST=$REDIS_HOST \
+            -e REDIS_PASSWORD=$REDIS_PASSWORD \
+            -e REDIS_PORT=$REDIS_PORT \
             -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
             -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
             -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \

.dockerignore

@@ -1,5 +1,5 @@
-/docs
+docs
-/cookbook
+cookbook
-/.circleci
+.circleci
-/.github
+.github
-/tests
+tests

.gitignore (vendored): 4 changes

@@ -46,3 +46,7 @@ deploy/charts/*.tgz
 litellm/proxy/vertex_key.json
 **/.vim/
 /node_modules
+kub.yaml
+loadtest_kub.yaml
+litellm/proxy/_new_secret_config.yaml
+litellm/proxy/_new_secret_config.yaml

@@ -70,5 +70,4 @@ EXPOSE 4000/tcp
 ENTRYPOINT ["litellm"]

 # Append "--detailed_debug" to the end of CMD to view detailed debug logs
-# CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
-CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
+CMD ["--port", "4000"]

README.md

@@ -205,7 +205,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [aws - bedrock](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [google - vertex_ai [Gemini]](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
 | [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
-| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | | ✅ | | |
+| [google AI Studio - gemini](https://docs.litellm.ai/docs/providers/gemini) | ✅ | ✅ | ✅ | ✅ | |
 | [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [cloudflare AI Workers](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | ✅ |
 | [cohere](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -220,7 +220,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [nlp_cloud](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | ✅ |
 | [aleph alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | ✅ |
 | [petals](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | ✅ |
-| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ |
+| [ollama](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
 | [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
 | [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |

cookbook/Proxy_Batch_Users.ipynb (vendored, new file): 204 lines

@@ -0,0 +1,204 @@
New Jupyter notebook (Python 3 kernel, nbformat 4, Colab provenance metadata). Cell contents:

Markdown cell:

# Environment Setup

Code cell:

    import csv
    from typing import Optional
    import httpx, json
    import asyncio

    proxy_base_url = "http://0.0.0.0:4000" # 👈 SET TO PROXY URL
    master_key = "sk-1234" # 👈 SET TO PROXY MASTER KEY

Code cell:

    ## GLOBAL HTTP CLIENT ## - faster http calls
    class HTTPHandler:
        def __init__(self, concurrent_limit=1000):
            # Create a client with a connection pool
            self.client = httpx.AsyncClient(
                limits=httpx.Limits(
                    max_connections=concurrent_limit,
                    max_keepalive_connections=concurrent_limit,
                )
            )

        async def close(self):
            # Close the client when you're done with it
            await self.client.aclose()

        async def get(
            self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
        ):
            response = await self.client.get(url, params=params, headers=headers)
            return response

        async def post(
            self,
            url: str,
            data: Optional[dict] = None,
            params: Optional[dict] = None,
            headers: Optional[dict] = None,
        ):
            try:
                response = await self.client.post(
                    url, data=data, params=params, headers=headers
                )
                return response
            except Exception as e:
                raise e

Markdown cell:

# Import Sheet

Format: | ID | Name | Max Budget |

Code cell:

    async def import_sheet():
        tasks = []
        http_client = HTTPHandler()
        with open('my-batch-sheet.csv', 'r') as file:
            csv_reader = csv.DictReader(file)
            for row in csv_reader:
                task = create_user(client=http_client, user_id=row['ID'], max_budget=row['Max Budget'], user_name=row['Name'])
                tasks.append(task)
                # print(f"ID: {row['ID']}, Name: {row['Name']}, Max Budget: {row['Max Budget']}")

        keys = await asyncio.gather(*tasks)

        with open('my-batch-sheet_new.csv', 'w', newline='') as new_file:
            fieldnames = ['ID', 'Name', 'Max Budget', 'keys']
            csv_writer = csv.DictWriter(new_file, fieldnames=fieldnames)
            csv_writer.writeheader()

            with open('my-batch-sheet.csv', 'r') as file:
                csv_reader = csv.DictReader(file)
                for i, row in enumerate(csv_reader):
                    row['keys'] = keys[i]  # Add the 'keys' value from the corresponding task result
                    csv_writer.writerow(row)

        await http_client.close()

    asyncio.run(import_sheet())

Markdown cell:

# Create Users + Keys

- Creates a user
- Creates a key with max budget

Code cell:

    async def create_key_with_alias(client: HTTPHandler, user_id: str, max_budget: float):
        global proxy_base_url
        if not proxy_base_url.endswith("/"):
            proxy_base_url += "/"
        url = proxy_base_url + "key/generate"

        # call /key/generate
        print("CALLING /KEY/GENERATE")
        response = await client.post(
            url=url,
            headers={"Authorization": f"Bearer {master_key}"},
            data=json.dumps({
                "user_id": user_id,
                "key_alias": f"{user_id}-key",
                "max_budget": max_budget # 👈 KEY CHANGE: SETS MAX BUDGET PER KEY
            })
        )
        print(f"response: {response.text}")
        return response.json()["key"]

    async def create_user(client: HTTPHandler, user_id: str, max_budget: float, user_name: str):
        """
        - call /user/new
        - create key for user
        """
        global proxy_base_url
        if not proxy_base_url.endswith("/"):
            proxy_base_url += "/"
        url = proxy_base_url + "user/new"

        # call /user/new
        await client.post(
            url=url,
            headers={"Authorization": f"Bearer {master_key}"},
            data=json.dumps({
                "user_id": user_id,
                "user_alias": user_name,
                "auto_create_key": False,
                # "max_budget": max_budget # 👈 [OPTIONAL] Sets max budget per user (if you want to set a max budget across keys)
            })
        )

        # create key for user
        return await create_key_with_alias(client=client, user_id=user_id, max_budget=max_budget)

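For reference, the notebook above reads `my-batch-sheet.csv` in the `| ID | Name | Max Budget |` format it describes. A minimal sketch of producing such a file with placeholder rows (the IDs, names, and budgets below are made up for illustration, not taken from the commit):

```python
# Sketch: write a sample my-batch-sheet.csv in the format the notebook expects.
# Header names (ID, Name, Max Budget) come from the notebook; rows are placeholders.
import csv

rows = [
    {"ID": "u-001", "Name": "Alice", "Max Budget": 10.0},
    {"ID": "u-002", "Name": "Bob", "Max Budget": 25.0},
]

with open("my-batch-sheet.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["ID", "Name", "Max Budget"])
    writer.writeheader()
    writer.writerows(rows)
```
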
cookbook/misc/config.yaml (new file): 73 lines

@@ -0,0 +1,73 @@
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
  - model_name: gpt-3.5-turbo-large
    litellm_params:
      model: "gpt-3.5-turbo-1106"
      api_key: os.environ/OPENAI_API_KEY
      rpm: 480
      timeout: 300
      stream_timeout: 60
  - model_name: gpt-4
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
      rpm: 480
      timeout: 300
      stream_timeout: 60
  - model_name: sagemaker-completion-model
    litellm_params:
      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
      input_cost_per_second: 0.000420
  - model_name: text-embedding-ada-002
    litellm_params:
      model: azure/azure-embedding-model
      api_key: os.environ/AZURE_API_KEY
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
    model_info:
      mode: embedding
      base_model: text-embedding-ada-002
  - model_name: dall-e-2
    litellm_params:
      model: azure/
      api_version: 2023-06-01-preview
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_key: os.environ/AZURE_API_KEY
  - model_name: openai-dall-e-3
    litellm_params:
      model: dall-e-3
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

litellm_settings:
  drop_params: True
  # max_budget: 100
  # budget_duration: 30d
  num_retries: 5
  request_timeout: 600
  telemetry: False
  context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

general_settings:
  master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
  store_model_in_db: True
  proxy_budget_rescheduler_min_time: 60
  proxy_budget_rescheduler_max_time: 64
  proxy_batch_write_at: 1
  # database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy

# environment_variables:
  # settings for using redis caching
  # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
  # REDIS_PORT: "16337"
  # REDIS_PASSWORD:

cookbook/misc/migrate_proxy_config.py (new file): 92 lines

@@ -0,0 +1,92 @@
"""
LiteLLM Migration Script!

Takes a config.yaml and calls /model/new

Inputs:
    - File path to config.yaml
    - Proxy base url to your hosted proxy

Step 1: Reads your config.yaml
Step 2: reads `model_list` and loops through all models
Step 3: calls `<proxy-base-url>/model/new` for each model
"""

import yaml
import requests

_in_memory_os_variables = {}


def migrate_models(config_file, proxy_base_url):
    # Step 1: Read the config.yaml file
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    # Step 2: Read the model_list and loop through all models
    model_list = config.get("model_list", [])
    print("model_list: ", model_list)
    for model in model_list:

        model_name = model.get("model_name")
        print("\nAdding model: ", model_name)
        litellm_params = model.get("litellm_params", {})
        api_base = litellm_params.get("api_base", "")
        print("api_base on config.yaml: ", api_base)

        litellm_model_name = litellm_params.get("model", "") or ""
        if "vertex_ai/" in litellm_model_name:
            print(f"\033[91m\nSkipping Vertex AI model\033[0m", model)
            continue

        for param, value in litellm_params.items():
            if isinstance(value, str) and value.startswith("os.environ/"):
                # check if value is in _in_memory_os_variables
                if value in _in_memory_os_variables:
                    new_value = _in_memory_os_variables[value]
                    print(
                        "\033[92mAlready entered value for \033[0m",
                        value,
                        "\033[92musing \033[0m",
                        new_value,
                    )
                else:
                    new_value = input(f"Enter value for {value}: ")
                    _in_memory_os_variables[value] = new_value
                litellm_params[param] = new_value

        print("\nlitellm_params: ", litellm_params)
        # Confirm before sending POST request
        confirm = input(
            "\033[92mDo you want to send the POST request with the above parameters? (y/n): \033[0m"
        )
        if confirm.lower() != "y":
            print("Aborting POST request.")
            exit()

        # Step 3: Call <proxy-base-url>/model/new for each model
        url = f"{proxy_base_url}/model/new"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {master_key}",
        }
        data = {"model_name": model_name, "litellm_params": litellm_params}
        print("POSTING data to proxy url", url)
        response = requests.post(url, headers=headers, json=data)
        if response.status_code != 200:
            print(f"Error: {response.status_code} - {response.text}")
            raise Exception(f"Error: {response.status_code} - {response.text}")

        # Print the response for each model
        print(
            f"Response for model '{model_name}': Status Code:{response.status_code} - {response.text}"
        )


# Usage
config_file = "config.yaml"
proxy_base_url = "http://0.0.0.0:4000"
master_key = "sk-1234"
print(f"config_file: {config_file}")
print(f"proxy_base_url: {proxy_base_url}")
migrate_models(config_file, proxy_base_url)

docker-compose.yml

@@ -1,10 +1,16 @@
 version: "3.9"
 services:
   litellm:
+    build:
+      context: .
+      args:
+        target: runtime
     image: ghcr.io/berriai/litellm:main-latest
-    volumes:
-      - ./proxy_server_config.yaml:/app/proxy_server_config.yaml # mount your litellm config.yaml
     ports:
-      - "4000:4000"
+      - "4000:4000" # Map the container port to the host, change the host port if necessary
-    environment:
-      - AZURE_API_KEY=sk-123
+    volumes:
+      - ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
+    # You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
+    command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
+
+# ...rest of your docker-compose config if any

@@ -1,5 +1,5 @@
 # Enterprise
-For companies that need better security, user management and professional support
+For companies that need SSO, user management and professional support for LiteLLM Proxy

 :::info

@@ -95,8 +95,8 @@ print(content)
 ```

 ## Chat Models
 | Model Name | Function Call | Required OS Variables |
-|------------------|--------------------------------------|-------------------------|
+|-----------------------|--------------------------------------------------------|--------------------------------|
 | gemini-pro | `completion('gemini/gemini-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
-| gemini-1.5-pro | `completion('gemini/gemini-1.5-pro', messages)` | `os.environ['GEMINI_API_KEY']` |
+| gemini-1.5-pro-latest | `completion('gemini/gemini-1.5-pro-latest', messages)` | `os.environ['GEMINI_API_KEY']` |
 | gemini-pro-vision | `completion('gemini/gemini-pro-vision', messages)` | `os.environ['GEMINI_API_KEY']` |

@@ -25,8 +25,11 @@ All models listed here https://docs.voyageai.com/embeddings/#models-and-specifics are supported

 | Model Name | Function Call |
 |--------------------------|------------------------------------------------------------|
+| voyage-2 | `embedding(model="voyage/voyage-2", input)` |
+| voyage-large-2 | `embedding(model="voyage/voyage-large-2", input)` |
+| voyage-law-2 | `embedding(model="voyage/voyage-law-2", input)` |
+| voyage-code-2 | `embedding(model="voyage/voyage-code-2", input)` |
+| voyage-lite-02-instruct | `embedding(model="voyage/voyage-lite-02-instruct", input)` |
 | voyage-01 | `embedding(model="voyage/voyage-01", input)` |
 | voyage-lite-01 | `embedding(model="voyage/voyage-lite-01", input)` |
 | voyage-lite-01-instruct | `embedding(model="voyage/voyage-lite-01-instruct", input)` |

docs/my-website/docs/proxy/demo.md (new file): 9 lines

@@ -0,0 +1,9 @@
# 🎉 Demo App

Here is a demo of the proxy. To log in pass in:

- Username: admin
- Password: sk-1234


[Demo UI](https://demo.litellm.ai/ui)

@@ -666,8 +666,8 @@ services:
   litellm:
     build:
       context: .
       args:
         target: runtime
     image: ghcr.io/berriai/litellm:main-latest
     ports:
       - "4000:4000" # Map the container port to the host, change the host port if necessary

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# ✨ Enterprise Features - Content Mod
+# ✨ Enterprise Features - Content Mod, SSO

 Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)

@@ -12,16 +12,18 @@ Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
 :::

 Features:
+- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
 - ✅ Content Moderation with LLM Guard
 - ✅ Content Moderation with LlamaGuard
 - ✅ Content Moderation with Google Text Moderations
 - ✅ Reject calls from Blocked User list
 - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
-- ✅ Don't log/store specific requests (eg confidential LLM requests)
+- ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (eg confidential LLM requests)
 - ✅ Tracking Spend for Custom Tags


 ## Content Moderation
 ### Content Moderation with LLM Guard

@@ -74,7 +76,7 @@ curl --location 'http://localhost:4000/key/generate' \
 # Returns {..'key': 'my-new-key'}
 ```

-**2. Test it!**
+**3. Test it!**

 ```bash
 curl --location 'http://0.0.0.0:4000/v1/chat/completions' \

@@ -87,6 +89,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
 }'
 ```

+#### Turn on/off per request
+
+**1. Update config**
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+  llm_guard_mode: "request-specific"
+```
+
+**2. Create new key**
+
+```bash
+curl --location 'http://localhost:4000/key/generate' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "models": ["fake-openai-endpoint"],
+    }'
+
+# Returns {..'key': 'my-new-key'}
+```
+
+**3. Test it!**
+
+<Tabs>
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(
+    model="gpt-3.5-turbo",
+    messages = [
+        {
+            "role": "user",
+            "content": "this is a test request, write a short poem"
+        }
+    ],
+    extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
+        "metadata": {
+            "permissions": {
+                "enable_llm_guard_check": True # 👈 KEY CHANGE
+            },
+        }
+    }
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="curl" label="Curl Request">
+
+```bash
+curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
+    --data '{"model": "fake-openai-endpoint", "messages": [
+        {"role": "system", "content": "Be helpful"},
+        {"role": "user", "content": "What do you know?"}
+    ]
+    }'
+```
+
+</TabItem>
+</Tabs>
+
 ### Content Moderation with LlamaGuard

@@ -99,7 +99,7 @@ Now, when you [generate keys](./virtual_keys.md) for this team-id
 curl -X POST 'http://0.0.0.0:4000/key/generate' \
 -H 'Authorization: Bearer sk-1234' \
 -H 'Content-Type: application/json' \
--D '{"team_id": "ishaans-secret-project"}'
+-d '{"team_id": "ishaans-secret-project"}'
 ```

 All requests made with these keys will log data to their team-specific logging.

@@ -108,6 +108,34 @@ general_settings:
   litellm_jwtauth:
     admin_jwt_scope: "litellm-proxy-admin"
 ```

+## Advanced - Spend Tracking (User / Team / Org)
+
+Set the field in the jwt token, which corresponds to a litellm user / team / org.
+
+```yaml
+general_settings:
+  master_key: sk-1234
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    admin_jwt_scope: "litellm-proxy-admin"
+    team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
+    user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
+    org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
+```
+
+Expected JWT:
+
+```
+{
+  "client_id": "my-unique-team",
+  "sub": "my-unique-user",
+  "org_id": "my-unique-org"
+}
+```
+
+Now litellm will automatically update the spend for the user/team/org in the db for each call.
+
 ### JWT Scopes

 Here's what scopes on JWT-Auth tokens look like

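As a hedged illustration of the token shape the new `*_jwt_field` settings expect, the claims above could be minted locally with PyJWT; the signing secret and algorithm below are placeholders for illustration, not part of the diff:

```python
# Sketch: mint a JWT whose claims match the fields configured above.
# "my-secret" and HS256 are placeholder signing details; a real deployment
# would use its identity provider's keys and algorithm.
import jwt  # PyJWT

payload = {
    "client_id": "my-unique-team",  # mapped via team_id_jwt_field
    "sub": "my-unique-user",        # mapped via user_id_jwt_field
    "org_id": "my-unique-org",      # mapped via org_id_jwt_field
}

token = jwt.encode(payload, "my-secret", algorithm="HS256")
print(token)
```
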
@@ -56,6 +56,9 @@ On accessing the LiteLLM UI, you will be prompted to enter your username, password

 ## ✨ Enterprise Features

+Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
+
+
 ### Setup SSO/Auth for UI

 #### Step 1: Set upperbounds for keys

@@ -95,12 +95,129 @@ print(response)
 - `router.image_generation()` - completion calls in OpenAI `/v1/images/generations` endpoint format
 - `router.aimage_generation()` - async image generation calls

-### Advanced
+### Advanced - Routing Strategies
 #### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based

 Router provides 4 strategies for routing your calls across multiple deployments:

 <Tabs>
+<TabItem value="usage-based-v2" label="Rate-Limit Aware v2 (ASYNC)">
+
+**🎉 NEW** This is an async implementation of usage-based-routing.
+
+**Filters out deployment if tpm/rpm limit exceeded** - If you pass in the deployment's tpm/rpm limits.
+
+Routes to **deployment with lowest TPM usage** for that minute.
+
+In production, we use Redis to track usage (TPM/RPM) across multiple deployments. This implementation uses **async redis calls** (redis.incr and redis.mget).
+
+For Azure, your RPM = TPM/6.
+
+<Tabs>
+<TabItem value="sdk" label="sdk">
+
+```python
+from litellm import Router
+
+
+model_list = [{ # list of model deployments
+    "model_name": "gpt-3.5-turbo", # model alias
+    "litellm_params": { # params for litellm completion/embedding call
+        "model": "azure/chatgpt-v-2", # actual model name
+        "api_key": os.getenv("AZURE_API_KEY"),
+        "api_version": os.getenv("AZURE_API_VERSION"),
+        "api_base": os.getenv("AZURE_API_BASE")
+    },
+    "tpm": 100000,
+    "rpm": 10000,
+}, {
+    "model_name": "gpt-3.5-turbo",
+    "litellm_params": { # params for litellm completion/embedding call
+        "model": "azure/chatgpt-functioncalling",
+        "api_key": os.getenv("AZURE_API_KEY"),
+        "api_version": os.getenv("AZURE_API_VERSION"),
+        "api_base": os.getenv("AZURE_API_BASE")
+    },
+    "tpm": 100000,
+    "rpm": 1000,
+}, {
+    "model_name": "gpt-3.5-turbo",
+    "litellm_params": { # params for litellm completion/embedding call
+        "model": "gpt-3.5-turbo",
+        "api_key": os.getenv("OPENAI_API_KEY"),
+    },
+    "tpm": 100000,
+    "rpm": 1000,
+}]
+router = Router(model_list=model_list,
+                redis_host=os.environ["REDIS_HOST"],
+                redis_password=os.environ["REDIS_PASSWORD"],
+                redis_port=os.environ["REDIS_PORT"],
+                routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
+                enable_pre_call_check=True, # enables router rate limits for concurrent calls
+                )
+
+response = await router.acompletion(model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}]
+
+print(response)
+```
+</TabItem>
+<TabItem value="proxy" label="proxy">
+
+**1. Set strategy in config**
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo # model alias
+    litellm_params: # params for litellm completion/embedding call
+      model: azure/chatgpt-v-2 # actual model name
+      api_key: os.environ/AZURE_API_KEY
+      api_version: os.environ/AZURE_API_VERSION
+      api_base: os.environ/AZURE_API_BASE
+    tpm: 100000
+    rpm: 10000
+  - model_name: gpt-3.5-turbo
+    litellm_params: # params for litellm completion/embedding call
+      model: gpt-3.5-turbo
+      api_key: os.getenv(OPENAI_API_KEY)
+    tpm: 100000
+    rpm: 1000
+
+router_settings:
+  routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
+  redis_host: <your-redis-host>
+  redis_password: <your-redis-password>
+  redis_port: <your-redis-port>
+  enable_pre_call_check: true
+
+general_settings:
+  master_key: sk-1234
+```
+
+**2. Start proxy**
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+**3. Test it!**
+
+```bash
+curl --location 'http://localhost:4000/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer sk-1234' \
+--data '{
+  "model": "gpt-3.5-turbo",
+  "messages": [{"role": "user", "content": "Hey, how's it going?"}]
+}'
+```
+
+</TabItem>
+</Tabs>
+
+
+</TabItem>
 <TabItem value="latency-based" label="Latency-Based">

@@ -117,7 +234,10 @@ import asyncio
 model_list = [{ ... }]

 # init router
-router = Router(model_list=model_list, routing_strategy="latency-based-routing") # 👈 set routing strategy
+router = Router(model_list=model_list,
+                routing_strategy="latency-based-routing",# 👈 set routing strategy
+                enable_pre_call_check=True, # enables router rate limits for concurrent calls
+                )

 ## CALL 1+2
 tasks = []
@@ -257,8 +377,9 @@ router = Router(model_list=model_list,
                 redis_host=os.environ["REDIS_HOST"],
                 redis_password=os.environ["REDIS_PASSWORD"],
                 redis_port=os.environ["REDIS_PORT"],
-                routing_strategy="usage-based-routing")
+                routing_strategy="usage-based-routing"
+                enable_pre_call_check=True, # enables router rate limits for concurrent calls
+                )

 response = await router.acompletion(model="gpt-3.5-turbo",
                 messages=[{"role": "user", "content": "Hey, how's it going?"}]
@@ -555,7 +676,11 @@ router = Router(model_list: Optional[list] = None,

 ## Pre-Call Checks (Context Window)

-Enable pre-call checks to filter out deployments with context window limit < messages for a call.
+Enable pre-call checks to filter out:
+1. deployments with context window limit < messages for a call.
+2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
+  router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
+])`)

 <Tabs>
 <TabItem value="sdk" label="SDK">

docs/my-website/sidebars.js

@@ -36,6 +36,7 @@ const sidebars = {
       label: "📖 All Endpoints (Swagger)",
       href: "https://litellm-api.up.railway.app/",
     },
+    "proxy/demo",
     "proxy/configs",
     "proxy/reliability",
     "proxy/users",
@@ -163,7+164,6 @@ const sidebars = {
     "debugging/local_debugging",
     "observability/callbacks",
     "observability/custom_callback",
-    "observability/lunary_integration",
     "observability/langfuse_integration",
     "observability/sentry",
     "observability/promptlayer_integration",
@@ -171,6 +171,7 @@ const sidebars = {
     "observability/langsmith_integration",
     "observability/slack_integration",
     "observability/traceloop_integration",
+    "observability/lunary_integration",
     "observability/athina_integration",
     "observability/helicone_integration",
     "observability/supabase_integration",

enterprise/enterprise_hooks/llm_guard.py

@@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             traceback.print_exc()
             raise e

-    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
+    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
         if self.llm_guard_mode == "key-specific":
             # check if llm guard enabled for specific keys only
             self.print_verbose(
@@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             return True
         elif self.llm_guard_mode == "all":
             return True
+        elif self.llm_guard_mode == "request-specific":
+            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
+            metadata = data.get("metadata", {})
+            permissions = metadata.get("permissions", {})
+            if (
+                "enable_llm_guard_check" in permissions
+                and permissions["enable_llm_guard_check"] == True
+            ):
+                return True
         return False

     async def async_moderation_hook(
@@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
         )

-        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
+        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
         if _proceed == False:
             return

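A small standalone sketch of the permission check the new `request-specific` branch performs on the request body; this mirrors the logic above rather than importing the enterprise class, and the payload values are illustrative:

```python
# Sketch of the request payload shape the "request-specific" branch inspects.
# Standalone illustration of the logic above, not the enterprise class itself.
def request_wants_llm_guard(data: dict) -> bool:
    metadata = data.get("metadata", {})
    permissions = metadata.get("permissions", {})
    return permissions.get("enable_llm_guard_check") is True

data = {
    "model": "fake-openai-endpoint",
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {"permissions": {"enable_llm_guard_check": True}},
}
print(request_wants_llm_guard(data))  # True
```
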
litellm/__init__.py

@@ -3,7 +3,11 @@ import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
 from litellm.caching import Cache
 from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
-from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
+from litellm.proxy._types import (
+    KeyManagementSystem,
+    KeyManagementSettings,
+    LiteLLM_UpperboundKeyGenerateParams,
+)
 import httpx
 import dotenv
@@ -64,7 +68,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
-llm_guard_mode: Literal["all", "key-specific"] = "all"
+llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
@@ -172,7 +176,7 @@ dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
 generic_logger_headers: Optional[Dict] = None
 default_key_generate_params: Optional[Dict] = None
-upperbound_key_generate_params: Optional[Dict] = None
+upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
 default_user_params: Optional[Dict] = None
 default_team_settings: Optional[List] = None
 max_user_budget: Optional[float] = None

litellm/caching.py

@@ -81,9 +81,30 @@ class InMemoryCache(BaseCache):
             return cached_response
         return None

+    def batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
     async def async_get_cache(self, key, **kwargs):
         return self.get_cache(key=key, **kwargs)

+    async def async_batch_get_cache(self, keys: list, **kwargs):
+        return_val = []
+        for k in keys:
+            val = self.get_cache(key=k, **kwargs)
+            return_val.append(val)
+        return return_val
+
+    async def async_increment(self, key, value: int, **kwargs) -> int:
+        # get the value
+        init_value = await self.async_get_cache(key=key) or 0
+        value = init_value + value
+        await self.async_set_cache(key, value, **kwargs)
+        return value
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
@@ -246,6 +267,21 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()

+    async def async_increment(self, key, value: int, **kwargs) -> int:
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                result = await redis_client.incr(name=key, amount=value)
+                return result
+        except Exception as e:
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            traceback.print_exc()
+            raise e
+
     async def flush_cache_buffer(self):
         print_verbose(
             f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
@@ -283,6 +319,32 @@ class RedisCache(BaseCache):
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    def batch_get_cache(self, key_list) -> dict:
+        """
+        Use Redis for bulk read operations
+        """
+        key_value_dict = {}
+        try:
+            _keys = []
+            for cache_key in key_list:
+                cache_key = self.check_and_fix_namespace(key=cache_key)
+                _keys.append(cache_key)
+            results = self.redis_client.mget(keys=_keys)
+
+            # Associate the results back with their keys.
+            # 'results' is a list of values corresponding to the order of keys in 'key_list'.
+            key_value_dict = dict(zip(key_list, results))
+
+            decoded_results = {
+                k.decode("utf-8"): self._get_cache_logic(v)
+                for k, v in key_value_dict.items()
+            }
+
+            return decoded_results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline read - {str(e)}")
+            return key_value_dict
+
     async def async_get_cache(self, key, **kwargs):
         _redis_client = self.init_async_client()
         key = self.check_and_fix_namespace(key=key)
@@ -301,7 +363,7 @@ class RedisCache(BaseCache):
                 f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
             )

-    async def async_get_cache_pipeline(self, key_list) -> dict:
+    async def async_batch_get_cache(self, key_list) -> dict:
         """
         Use Redis for bulk read operations
         """
@@ -309,14 +371,11 @@ class RedisCache(BaseCache):
         key_value_dict = {}
         try:
             async with _redis_client as redis_client:
-                async with redis_client.pipeline(transaction=True) as pipe:
-                    # Queue the get operations in the pipeline for all keys.
-                    for cache_key in key_list:
-                        cache_key = self.check_and_fix_namespace(key=cache_key)
-                        pipe.get(cache_key)  # Queue GET command in pipeline
-
-                    # Execute the pipeline and await the results.
-                    results = await pipe.execute()
+                _keys = []
+                for cache_key in key_list:
+                    cache_key = self.check_and_fix_namespace(key=cache_key)
+                    _keys.append(cache_key)
+                results = await redis_client.mget(keys=_keys)

             # Associate the results back with their keys.
             # 'results' is a list of values corresponding to the order of keys in 'key_list'.
@@ -897,6 +956,39 @@ class DualCache(BaseCache):
         except Exception as e:
             traceback.print_exc()

+    def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
+        try:
+            result = [None for _ in range(len(keys))]
+            if self.in_memory_cache is not None:
+                in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
+
+                print_verbose(f"in_memory_result: {in_memory_result}")
+                if in_memory_result is not None:
+                    result = in_memory_result
+
+            if None in result and self.redis_cache is not None and local_only == False:
+                """
+                - for the none values in the result
+                - check the redis cache
+                """
+                sublist_keys = [
+                    key for key, value in zip(keys, result) if value is None
+                ]
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    for key in redis_result:
+                        self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
+
+                for key, value in redis_result.items():
+                    result[sublist_keys.index(key)] = value
+
+            print_verbose(f"async batch get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
     async def async_get_cache(self, key, local_only: bool = False, **kwargs):
         # Try to fetch from in-memory cache first
         try:
@@ -930,6 +1022,50 @@ class DualCache(BaseCache):
         except Exception as e:
             traceback.print_exc()

+    async def async_batch_get_cache(
+        self, keys: list, local_only: bool = False, **kwargs
+    ):
+        try:
+            result = [None for _ in range(len(keys))]
+            if self.in_memory_cache is not None:
+                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
+                    keys, **kwargs
+                )
+
+                print_verbose(f"in_memory_result: {in_memory_result}")
+                if in_memory_result is not None:
+                    result = in_memory_result
+
+            if None in result and self.redis_cache is not None and local_only == False:
+                """
+                - for the none values in the result
+                - check the redis cache
+                """
+                sublist_keys = [
+                    key for key, value in zip(keys, result) if value is None
+                ]
+                # If not found in in-memory cache, try fetching from Redis
+                redis_result = await self.redis_cache.async_batch_get_cache(
+                    sublist_keys, **kwargs
+                )
+
+                if redis_result is not None:
+                    # Update in-memory cache with the value from Redis
+                    for key in redis_result:
+                        await self.in_memory_cache.async_set_cache(
+                            key, redis_result[key], **kwargs
+                        )
+
+                sublist_dict = dict(zip(sublist_keys, redis_result))
+
+                for key, value in sublist_dict.items():
+                    result[sublist_keys.index(key)] = value
+
+            print_verbose(f"async batch get cache: cache result: {result}")
+            return result
+        except Exception as e:
+            traceback.print_exc()
+
     async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
         try:
             if self.in_memory_cache is not None:
@@ -941,6 +1077,32 @@ class DualCache(BaseCache):
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()

+    async def async_increment_cache(
+        self, key, value: int, local_only: bool = False, **kwargs
+    ) -> int:
+        """
+        Key - the key in cache
+
+        Value - int - the value you want to increment by
+
+        Returns - int - the incremented value
+        """
+        try:
+            result: int = value
+            if self.in_memory_cache is not None:
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )
+
+            if self.redis_cache is not None and local_only == False:
+                result = await self.redis_cache.async_increment(key, value, **kwargs)
+
+            return result
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
+            traceback.print_exc()
+            raise e
+
     def flush_cache(self):
         if self.in_memory_cache is not None:
             self.in_memory_cache.flush_cache()

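For orientation, a rough sketch of how the `DualCache` batch and increment helpers added above might be exercised, assuming this version of `litellm.caching`; only the in-memory layer is configured here, since Redis credentials are deployment-specific:

```python
# Sketch: exercising the new DualCache batch + increment helpers added above.
# Only the in-memory layer is configured; a Redis-backed setup would also pass
# a RedisCache instance so async_increment / async_batch_get_cache hit Redis.
import asyncio
from litellm.caching import DualCache, InMemoryCache

async def main():
    cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)

    await cache.async_set_cache("user:1:spend", 5)
    total = await cache.async_increment_cache(key="user:1:spend", value=3, local_only=True)
    print(total)  # 8

    values = await cache.async_batch_get_cache(keys=["user:1:spend", "missing-key"], local_only=True)
    print(values)  # [8, None]

asyncio.run(main())
```
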
@@ -161,7 +161,7 @@ class LangFuseLogger:
             verbose_logger.info(f"Langfuse Layer Logging - logging success")
         except:
             traceback.print_exc()
-            print(f"Langfuse Layer Error - {traceback.format_exc()}")
+            verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
             pass
 
     async def _async_log_event(
@@ -190,7 +190,7 @@ class LangFuseLogger:
     ):
         from langfuse.model import CreateTrace, CreateGeneration
 
-        print(
+        verbose_logger.warning(
             "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
         )
 
@@ -247,7 +247,6 @@ class LangFuseLogger:
 
         print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
 
-        print(f"response_obj: {response_obj}")
         if supports_tags:
             metadata_tags = metadata.get("tags", [])
             tags = metadata_tags
@@ -312,13 +311,11 @@ class LangFuseLogger:
         usage = None
         if response_obj is not None and response_obj.get("id", None) is not None:
             generation_id = litellm.utils.get_logging_id(start_time, response_obj)
-            print(f"getting usage, cost={cost}")
             usage = {
                 "prompt_tokens": response_obj["usage"]["prompt_tokens"],
                 "completion_tokens": response_obj["usage"]["completion_tokens"],
                 "total_cost": cost if supports_costs else None,
             }
-            print(f"constructed usage - {usage}")
         generation_name = metadata.get("generation_name", None)
         if generation_name is None:
             # just log `litellm-{call_type}` as the generation name
@@ -351,4 +348,4 @@ class LangFuseLogger:
 
             trace.generation(**generation_params)
         except Exception as e:
-            print(f"Langfuse Layer Error - {traceback.format_exc()}")
+            verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
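Since the Langfuse integration now routes its diagnostics through verbose_logger instead of print, those lines only show up when logging is enabled. A hedged sketch of surfacing them, assuming verbose_logger is a standard logging.Logger registered under the "LiteLLM" name (an assumption, not shown in this diff):

import logging
import litellm

# Raise the log level so the debug/warning lines that used to be print() calls appear.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("LiteLLM").setLevel(logging.DEBUG)  # logger name is an assumption
litellm.set_verbose = True  # also enables litellm's own verbose output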
@@ -53,6 +53,8 @@ class LangsmithLogger:
             value = kwargs[key]
             if key == "start_time" or key == "end_time":
                 pass
+            elif type(value) == datetime.datetime:
+                new_kwargs[key] = value.isoformat()
             elif type(value) != dict:
                 new_kwargs[key] = value
 
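The new isoformat() branch matters because json.dumps cannot serialize datetime objects, so converting them to ISO-8601 strings keeps the Langsmith payload serializable. A minimal illustration (the payload key is hypothetical):

import datetime, json

payload = {"created_at": datetime.datetime(2024, 4, 15, 12, 30)}
payload = {
    k: v.isoformat() if isinstance(v, datetime.datetime) else v
    for k, v in payload.items()
}
print(json.dumps(payload))  # {"created_at": "2024-04-15T12:30:00"}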
@@ -7,7 +7,8 @@ from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from .base import BaseLLM
 import httpx
 
 
@@ -15,6 +16,8 @@ class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman: "
     AI_PROMPT = "\n\nAssistant: "
 
+    # constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
+
 
 class AnthropicError(Exception):
     def __init__(self, status_code, message):
@@ -36,7 +39,9 @@ class AnthropicConfig:
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """
 
-    max_tokens: Optional[int] = 4096  # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
+    max_tokens: Optional[int] = (
+        4096  # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
+    )
     stop_sequences: Optional[list] = None
     temperature: Optional[int] = None
     top_p: Optional[int] = None
@@ -46,7 +51,9 @@ class AnthropicConfig:
 
     def __init__(
         self,
-        max_tokens: Optional[int] = 4096,  # You can pass in a value yourself or use the default value 4096
+        max_tokens: Optional[
+            int
+        ] = 4096,  # You can pass in a value yourself or use the default value 4096
         stop_sequences: Optional[list] = None,
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
@@ -95,121 +102,23 @@ def validate_environment(api_key, user_headers):
     return headers
 
 
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    custom_prompt_dict: dict,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    optional_params=None,
-    litellm_params=None,
-    logger_fn=None,
-    headers={},
-):
-    headers = validate_environment(api_key, headers)
-    _is_function_call = False
-    messages = copy.deepcopy(messages)
-    optional_params = copy.deepcopy(optional_params)
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages,
-        )
-    else:
-        # Separate system prompt from rest of message
-        system_prompt_indices = []
-        system_prompt = ""
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                system_prompt += message["content"]
-                system_prompt_indices.append(idx)
-        if len(system_prompt_indices) > 0:
-            for idx in reversed(system_prompt_indices):
-                messages.pop(idx)
-        if len(system_prompt) > 0:
-            optional_params["system"] = system_prompt
-        # Format rest of message according to anthropic guidelines
-        try:
-            messages = prompt_factory(
-                model=model, messages=messages, custom_llm_provider="anthropic"
-            )
-        except Exception as e:
-            raise AnthropicError(status_code=400, message=str(e))
-
-    ## Load Config
-    config = litellm.AnthropicConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    ## Handle Tool Calling
-    if "tools" in optional_params:
-        _is_function_call = True
-        headers["anthropic-beta"] = "tools-2024-04-04"
-
-        anthropic_tools = []
-        for tool in optional_params["tools"]:
-            new_tool = tool["function"]
-            new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
-            anthropic_tools.append(new_tool)
-
-        optional_params["tools"] = anthropic_tools
-
-    stream = optional_params.pop("stream", None)
-
-    data = {
-        "model": model,
-        "messages": messages,
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=messages,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "api_base": api_base,
-            "headers": headers,
-        },
-    )
-    print_verbose(f"_is_function_call: {_is_function_call}")
-    ## COMPLETION CALL
-    if (
-        stream and not _is_function_call
-    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
-        print_verbose("makes anthropic streaming POST request")
-        data["stream"] = stream
-        response = requests.post(
-            api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=stream,
-        )
-
-        if response.status_code != 200:
-            raise AnthropicError(
-                status_code=response.status_code, message=response.text
-            )
-
-        return response.iter_lines()
-    else:
-        response = requests.post(api_base, headers=headers, data=json.dumps(data))
-        if response.status_code != 200:
-            raise AnthropicError(
-                status_code=response.status_code, message=response.text
-            )
-
+class AnthropicChatCompletion(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def process_response(
+        self,
+        model,
+        response,
+        model_response,
+        _is_function_call,
+        stream,
+        logging_obj,
+        api_key,
+        data,
+        messages,
+        print_verbose,
+    ):
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -327,6 +236,272 @@ def completion(
         model_response.usage = usage
         return model_response
 
+    async def acompletion_stream_function(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        _is_function_call,
+        data=None,
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ):
+        self.async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+        response = await self.async_handler.post(
+            api_base, headers=headers, data=json.dumps(data)
+        )
+
+        if response.status_code != 200:
+            raise AnthropicError(
+                status_code=response.status_code, message=response.text
+            )
+
+        completion_stream = response.aiter_lines()
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="anthropic",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    async def acompletion_function(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        _is_function_call,
+        data=None,
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ):
+        self.async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+        response = await self.async_handler.post(
+            api_base, headers=headers, data=json.dumps(data)
+        )
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            _is_function_call=_is_function_call,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+        )
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params=None,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ):
+        headers = validate_environment(api_key, headers)
+        _is_function_call = False
+        messages = copy.deepcopy(messages)
+        optional_params = copy.deepcopy(optional_params)
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
+            )
+        else:
+            # Separate system prompt from rest of message
+            system_prompt_indices = []
+            system_prompt = ""
+            for idx, message in enumerate(messages):
+                if message["role"] == "system":
+                    system_prompt += message["content"]
+                    system_prompt_indices.append(idx)
+            if len(system_prompt_indices) > 0:
+                for idx in reversed(system_prompt_indices):
+                    messages.pop(idx)
+            if len(system_prompt) > 0:
+                optional_params["system"] = system_prompt
+            # Format rest of message according to anthropic guidelines
+            try:
+                messages = prompt_factory(
+                    model=model, messages=messages, custom_llm_provider="anthropic"
+                )
+            except Exception as e:
+                raise AnthropicError(status_code=400, message=str(e))
+
+        ## Load Config
+        config = litellm.AnthropicConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            headers["anthropic-beta"] = "tools-2024-04-04"
+
+            anthropic_tools = []
+            for tool in optional_params["tools"]:
+                new_tool = tool["function"]
+                new_tool["input_schema"] = new_tool.pop("parameters")  # rename key
+                anthropic_tools.append(new_tool)
+
+            optional_params["tools"] = anthropic_tools
+
+        stream = optional_params.pop("stream", None)
+
+        data = {
+            "model": model,
+            "messages": messages,
+            **optional_params,
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+        print_verbose(f"_is_function_call: {_is_function_call}")
+        if acompletion == True:
+            if (
+                stream and not _is_function_call
+            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+                print_verbose("makes async anthropic streaming POST request")
+                data["stream"] = stream
+                return self.acompletion_stream_function(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=api_base,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    _is_function_call=_is_function_call,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                )
+            else:
+                return self.acompletion_function(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=api_base,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    _is_function_call=_is_function_call,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                )
+        else:
+            ## COMPLETION CALL
+            if (
+                stream and not _is_function_call
+            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+                print_verbose("makes anthropic streaming POST request")
+                data["stream"] = stream
+                response = requests.post(
+                    api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    stream=stream,
+                )
+
+                if response.status_code != 200:
+                    raise AnthropicError(
+                        status_code=response.status_code, message=response.text
+                    )
+
+                completion_stream = response.iter_lines()
+                streaming_response = CustomStreamWrapper(
+                    completion_stream=completion_stream,
+                    model=model,
+                    custom_llm_provider="anthropic",
+                    logging_obj=logging_obj,
+                )
+                return streaming_response
+
+            else:
+                response = requests.post(
+                    api_base, headers=headers, data=json.dumps(data)
+                )
+                if response.status_code != 200:
+                    raise AnthropicError(
+                        status_code=response.status_code, message=response.text
+                    )
+                return self.process_response(
+                    model=model,
+                    response=response,
+                    model_response=model_response,
+                    _is_function_call=_is_function_call,
+                    stream=stream,
+                    logging_obj=logging_obj,
+                    api_key=api_key,
+                    data=data,
+                    messages=messages,
+                    print_verbose=print_verbose,
+                )
+
+    def embedding(self):
+        # logic for parsing in - calling - parsing out model embedding calls
+        pass
+
 
 class ModelResponseIterator:
     def __init__(self, model_response):
@@ -352,8 +527,3 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
-
-
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
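One small but important piece of the new handler is the tool translation it performs when "tools" is passed: an OpenAI-style tool's "parameters" JSON schema is renamed to Anthropic's "input_schema" key. A standalone sketch of that rename, mirroring the loop in the diff (the example tool itself is illustrative):

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Same transformation as the handler: copy the function block, rename the schema key.
anthropic_tool = dict(openai_tool["function"])
anthropic_tool["input_schema"] = anthropic_tool.pop("parameters")
print(anthropic_tool["name"], sorted(anthropic_tool.keys()))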
@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 import httpx
@@ -162,8 +162,15 @@ def completion(
             raise AnthropicError(
                 status_code=response.status_code, message=response.text
             )
+        completion_stream = response.iter_lines()
+        stream_response = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="anthropic",
+            logging_obj=logging_obj,
+        )
+        return stream_response
 
-        return response.iter_lines()
     else:
         response = requests.post(api_base, headers=headers, data=json.dumps(data))
         if response.status_code != 200:
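With the text handler now returning a CustomStreamWrapper instead of raw iter_lines(), streaming callers receive normal chunk objects. A hedged usage sketch at the litellm level; the model name and chunk shape are assumptions, not taken from this diff:

import litellm

response = litellm.completion(
    model="anthropic/claude-instant-1.2",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in response:
    # Each chunk follows the OpenAI-style delta format produced by CustomStreamWrapper.
    print(chunk.choices[0].delta.content or "", end="")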
@@ -1,21 +1,34 @@
 import httpx, asyncio
-from typing import Optional
+from typing import Optional, Union, Mapping, Any
+
+# https://www.python-httpx.org/advanced/timeouts
+_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
 
 
 class AsyncHTTPHandler:
-    def __init__(self, concurrent_limit=1000):
+    def __init__(
+        self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+    ):
         # Create a client with a connection pool
         self.client = httpx.AsyncClient(
+            timeout=timeout,
             limits=httpx.Limits(
                 max_connections=concurrent_limit,
                 max_keepalive_connections=concurrent_limit,
-            )
+            ),
         )
 
     async def close(self):
         # Close the client when you're done with it
         await self.client.aclose()
 
+    async def __aenter__(self):
+        return self.client
+
+    async def __aexit__(self):
+        # close the client when exiting
+        await self.client.aclose()
+
     async def get(
         self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
     ):
@@ -25,12 +38,15 @@ class AsyncHTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,  # type: ignore
         params: Optional[dict] = None,
        headers: Optional[dict] = None,
     ):
         response = await self.client.post(
-            url, data=data, params=params, headers=headers
+            url,
+            data=data,  # type: ignore
+            params=params,
+            headers=headers,
         )
         return response
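A hedged usage sketch of the updated handler, grounded in the constructor and post() signature shown above; the URL and payload are placeholders, and the 600s/5s timeout mirrors the values used for the Anthropic calls in this commit:

import asyncio, httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def main():
    client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
    response = await client.post(
        "https://example.com/v1/messages",           # placeholder endpoint
        data='{"ping": true}',                       # str payloads are now accepted
        headers={"content-type": "application/json"},
    )
    print(response.status_code)
    await client.close()

asyncio.run(main())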
@@ -6,7 +6,8 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
 import litellm
 import sys, httpx
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
+from packaging.version import Version
 
 
 class GeminiError(Exception):
@@ -103,6 +104,13 @@ class TextStreamer:
                 break
 
 
+def supports_system_instruction():
+    import google.generativeai as genai
+
+    gemini_pkg_version = Version(genai.__version__)
+    return gemini_pkg_version >= Version("0.5.0")
+
+
 def completion(
     model: str,
     messages: list,
@@ -124,7 +132,7 @@ def completion(
             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
         )
     genai.configure(api_key=api_key)
+    system_prompt = ""
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -135,6 +143,7 @@ def completion(
             messages=messages,
         )
     else:
+        system_prompt, messages = get_system_prompt(messages=messages)
         prompt = prompt_factory(
             model=model, messages=messages, custom_llm_provider="gemini"
         )
@@ -162,11 +171,20 @@ def completion(
     logging_obj.pre_call(
         input=prompt,
         api_key="",
-        additional_args={"complete_input_dict": {"inference_params": inference_params}},
+        additional_args={
+            "complete_input_dict": {
+                "inference_params": inference_params,
+                "system_prompt": system_prompt,
+            }
+        },
     )
     ## COMPLETION CALL
     try:
-        _model = genai.GenerativeModel(f"models/{model}")
+        _params = {"model_name": "models/{}".format(model)}
+        _system_instruction = supports_system_instruction()
+        if _system_instruction and len(system_prompt) > 0:
+            _params["system_instruction"] = system_prompt
+        _model = genai.GenerativeModel(**_params)
         if stream == True:
             if acompletion == True:
 
@@ -213,11 +231,12 @@ def completion(
                 encoding=encoding,
             )
         else:
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-            )
+            params = {
+                "contents": prompt,
+                "generation_config": genai.types.GenerationConfig(**inference_params),
+                "safety_settings": safety_settings,
+            }
+            response = _model.generate_content(**params)
     except Exception as e:
         raise GeminiError(
             message=str(e),
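The system_instruction parameter is only forwarded when the installed google-generativeai package is new enough, which is what supports_system_instruction() checks. A minimal sketch of that version gate using packaging.version, with hard-coded version strings for illustration:

from packaging.version import Version

def supports_system_instruction(pkg_version: str) -> bool:
    # Mirrors the check in the diff: system_instruction requires google-generativeai >= 0.5.0.
    return Version(pkg_version) >= Version("0.5.0")

print(supports_system_instruction("0.4.1"))  # False -> system prompt stays merged into messages
print(supports_system_instruction("0.5.0"))  # True  -> passed as system_instruction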
@@ -254,7 +254,7 @@ def get_ollama_response(
     model_response["created"] = int(time.time())
     model_response["model"] = "ollama/" + model
     prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt)))  # type: ignore
-    completion_tokens = response_json["eval_count"]
+    completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
     model_response["usage"] = litellm.Usage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -356,7 +356,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     model_response["created"] = int(time.time())
     model_response["model"] = "ollama/" + data["model"]
     prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"])))  # type: ignore
-    completion_tokens = response_json["eval_count"]
+    completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
     model_response["usage"] = litellm.Usage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
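The point of the Ollama change is that a missing "eval_count" no longer raises a KeyError; the handler falls back to the length of the returned message content as a rough token stand-in. A self-contained illustration with a hypothetical response payload:

response_json = {"message": {"content": "Hello there"}}  # no "eval_count" key
completion_tokens = response_json.get(
    "eval_count", len(response_json.get("message", dict()).get("content", ""))
)
print(completion_tokens)  # 11 — character count used as the fallback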
@@ -1,9 +1,9 @@
 from enum import Enum
 import requests, traceback
 import json, re, xml.etree.ElementTree as ET
-from jinja2 import Template, exceptions, Environment, meta
+from jinja2 import Template, exceptions, meta, BaseLoader
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 from typing import Optional, Any
-import imghdr, base64
 from typing import List
 import litellm
 
@@ -219,6 +219,15 @@ def phind_codellama_pt(messages):
 
 
 def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
+    # Define Jinja2 environment
+    env = ImmutableSandboxedEnvironment()
+
+    def raise_exception(message):
+        raise Exception(f"Error message - {message}")
+
+    # Create a template object from the template text
+    env.globals["raise_exception"] = raise_exception
+
     ## get the tokenizer config from huggingface
     bos_token = ""
     eos_token = ""
@@ -249,12 +258,6 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         eos_token = tokenizer_config["eos_token"]
         chat_template = tokenizer_config["chat_template"]
 
-    def raise_exception(message):
-        raise Exception(f"Error message - {message}")
-
-    # Create a template object from the template text
-    env = Environment()
-    env.globals["raise_exception"] = raise_exception
     try:
         template = env.from_string(chat_template)
     except Exception as e:
@@ -959,7 +962,20 @@ def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
     return params
 
 
-###
+### GEMINI HELPER FUNCTIONS ###
+
+
+def get_system_prompt(messages):
+    system_prompt_indices = []
+    system_prompt = ""
+    for idx, message in enumerate(messages):
+        if message["role"] == "system":
+            system_prompt += message["content"]
+            system_prompt_indices.append(idx)
+    if len(system_prompt_indices) > 0:
+        for idx in reversed(system_prompt_indices):
+            messages.pop(idx)
+    return system_prompt, messages
 
 
 def convert_openai_message_to_cohere_tool_result(message):
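Switching from a plain jinja2 Environment to ImmutableSandboxedEnvironment keeps Hugging Face chat templates from calling arbitrary Python. A minimal sketch of rendering a template in the sandbox; the template string is illustrative, not a real tokenizer_config.json value:

from jinja2.sandbox import ImmutableSandboxedEnvironment

def raise_exception(message):
    raise Exception(f"Error message - {message}")

env = ImmutableSandboxedEnvironment()
env.globals["raise_exception"] = raise_exception
template = env.from_string(
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
)
print(template.render(messages=[{"role": "user", "content": "hi"}]))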
@@ -332,9 +332,12 @@ def completion(
         model_response["choices"][0]["message"]["content"] = result
 
     # Calculate usage
-    prompt_tokens = len(encoding.encode(prompt))
+    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
     completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        encoding.encode(
+            model_response["choices"][0]["message"].get("content", ""),
+            disallowed_special=(),
+        )
     )
     model_response["model"] = "replicate/" + model
     usage = Usage(
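The disallowed_special=() argument matters because, assuming the encoding object is tiktoken-based (litellm's default token counter), encode() raises by default when the text contains special-token markers such as "<|endoftext|>"; an empty tuple makes them encode as ordinary text instead of failing on user-provided prompts. A small standalone illustration:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode("hello <|endoftext|> world", disallowed_special=())
print(len(tokens))  # encodes without raising on the special marker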
@@ -3,10 +3,10 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
-import httpx
+import httpx, inspect
 
 
 class VertexAIError(Exception):
@@ -25,6 +25,7 @@ class VertexAIError(Exception):
 class VertexAIConfig:
     """
     Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
 
     The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
 
@@ -36,6 +37,12 @@ class VertexAIConfig:
 
     - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
 
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
     Note: Please make sure to modify the default parameters as required for your use case.
     """
 
@@ -43,6 +50,9 @@ class VertexAIConfig:
     max_output_tokens: Optional[int] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
 
     def __init__(
         self,
@@ -50,6 +60,9 @@ class VertexAIConfig:
         max_output_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -295,6 +308,42 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+                if "response_mime_type" in args_spec.args:
+                    self._raw_generation_config = gapic_content_types.GenerationConfig(
+                        temperature=temperature,
+                        top_p=top_p,
+                        top_k=top_k,
+                        candidate_count=candidate_count,
+                        max_output_tokens=max_output_tokens,
+                        stop_sequences=stop_sequences,
+                        response_mime_type=response_mime_type,
+                    )
+                else:
+                    self._raw_generation_config = gapic_content_types.GenerationConfig(
+                        temperature=temperature,
+                        top_p=top_p,
+                        top_k=top_k,
+                        candidate_count=candidate_count,
+                        max_output_tokens=max_output_tokens,
+                        stop_sequences=stop_sequences,
+                    )
+
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
@@ -417,7 +466,7 @@ def completion(
             return async_completion(**data)
 
         if mode == "vision":
-            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
+            print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
             tools = optional_params.pop("tools", None)
             prompt, images = _gemini_vision_convert_messages(messages=messages)
@@ -436,7 +485,7 @@ def completion(
 
                 model_response = llm_model.generate_content(
                     contents=content,
-                    generation_config=GenerationConfig(**optional_params),
+                    generation_config=ExtendedGenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
                     tools=tools,
@@ -458,7 +507,7 @@ def completion(
             ## LLM Call
             response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 tools=tools,
             )
@@ -698,6 +747,43 @@ async def async_completion(
     """
     try:
         from vertexai.preview.generative_models import GenerationConfig
+        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+                if "response_mime_type" in args_spec.args:
+                    self._raw_generation_config = gapic_content_types.GenerationConfig(
+                        temperature=temperature,
+                        top_p=top_p,
+                        top_k=top_k,
+                        candidate_count=candidate_count,
+                        max_output_tokens=max_output_tokens,
+                        stop_sequences=stop_sequences,
+                        response_mime_type=response_mime_type,
+                    )
+                else:
+                    self._raw_generation_config = gapic_content_types.GenerationConfig(
+                        temperature=temperature,
+                        top_p=top_p,
+                        top_k=top_k,
+                        candidate_count=candidate_count,
+                        max_output_tokens=max_output_tokens,
+                        stop_sequences=stop_sequences,
+                    )
+
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
 
@@ -721,7 +807,7 @@ async def async_completion(
             ## LLM Call
             response = await llm_model._generate_content_async(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 tools=tools,
             )
 
@@ -906,6 +992,43 @@ async def async_streaming(
     Add support for async streaming calls for gemini-pro
     """
     from vertexai.preview.generative_models import GenerationConfig
+    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+    class ExtendedGenerationConfig(GenerationConfig):
+        """Extended parameters for the generation."""
+
+        def __init__(
+            self,
+            *,
+            temperature: Optional[float] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
+        ):
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )
+
     if mode == "vision":
         stream = optional_params.pop("stream")
 
@@ -927,7 +1050,7 @@ async def async_streaming(
 
         response = await llm_model._generate_content_streaming_async(
             contents=content,
-            generation_config=GenerationConfig(**optional_params),
+            generation_config=ExtendedGenerationConfig(**optional_params),
             tools=tools,
         )
         optional_params["stream"] = True
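The recurring ExtendedGenerationConfig class relies on a simple feature-detection pattern: inspect the installed GenerationConfig's signature and only pass response_mime_type when the field exists. A minimal sketch of that pattern with a stand-in function in place of the real gapic type:

import inspect

def new_generation_config(temperature=None, top_p=None, response_mime_type=None):
    # Stand-in for gapic_content_types.GenerationConfig; older SDKs lack response_mime_type.
    return {"temperature": temperature, "top_p": top_p, "response_mime_type": response_mime_type}

args_spec = inspect.getfullargspec(new_generation_config)
kwargs = {"temperature": 0.2}
if "response_mime_type" in args_spec.args:  # only forward the field when it is supported
    kwargs["response_mime_type"] = "application/json"
print(new_generation_config(**kwargs))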
@@ -39,7 +39,6 @@ from litellm.utils import (
     get_optional_params_image_gen,
 )
 from .llms import (
-    anthropic,
     anthropic_text,
     together_ai,
     ai21,
@@ -68,6 +67,7 @@ from .llms import (
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
 from .llms.azure_text import AzureTextCompletion
+from .llms.anthropic import AnthropicChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -99,6 +99,7 @@ from litellm.utils import (
 dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
+anthropic_chat_completions = AnthropicChatCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
@@ -304,6 +305,7 @@ async def acompletion(
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
+            or custom_llm_provider == "anthropic"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -315,6 +317,14 @@ async def acompletion(
                 response = await init_response
             else:
                 response = init_response  # type: ignore
+
+            if custom_llm_provider == "text-completion-openai" and isinstance(
+                response, TextCompletionResponse
+            ):
+                response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
+                    response_object=response,
+                    model_response_object=litellm.ModelResponse(),
+                )
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)  # type: ignore
@@ -1180,10 +1190,11 @@ def completion(
                 or get_secret("ANTHROPIC_API_BASE")
                 or "https://api.anthropic.com/v1/messages"
             )
-            response = anthropic.completion(
+            response = anthropic_chat_completions.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
+                acompletion=acompletion,
                 custom_prompt_dict=litellm.custom_prompt_dict,
                 model_response=model_response,
                 print_verbose=print_verbose,
@@ -1195,19 +1206,6 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if (
-                "stream" in optional_params
-                and optional_params["stream"] == True
-                and not isinstance(response, CustomStreamWrapper)
-            ):
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    response,
-                    model,
-                    custom_llm_provider="anthropic",
-                    logging_obj=logging,
-                )
 
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
                 logging.post_call(
@@ -3786,6 +3784,9 @@ async def ahealth_check(
 
             api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
 
+            if custom_llm_provider == "text-completion-openai":
+                mode = "completion"
+
             response = await openai_chat_completions.ahealth_check(
                 model=model,
                 messages=model_params.get(
@@ -3819,11 +3820,15 @@ async def ahealth_check(
         return response
     except Exception as e:
         traceback.print_exc()
+        stack_trace = traceback.format_exc()
+        if isinstance(stack_trace, str):
+            stack_trace = stack_trace[:1000]
        if model not in litellm.model_cost and mode is None:
             raise Exception(
                 "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
             )
-        return {"error": f"{str(e)}"}
+        error_to_return = str(e) + " stack trace: " + stack_trace
+        return {"error": error_to_return}
 
 
 ####### HELPER FUNCTIONS ################
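With "anthropic" added to the providers that support a native async path, acompletion no longer wraps the sync call for Claude models. A hedged usage sketch; the model name is an assumption and ANTHROPIC_API_KEY must be set in the environment:

import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())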
@ -66,6 +66,28 @@
|
||||||
"litellm_provider": "openai",
|
"litellm_provider": "openai",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
"gpt-4-turbo": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00001,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true
|
||||||
|
},
|
||||||
|
"gpt-4-turbo-2024-04-09": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00001,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true
|
||||||
|
},
|
||||||
"gpt-4-1106-preview": {
|
"gpt-4-1106-preview": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
@ -948,6 +970,28 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
|
"gemini-1.0-pro-001": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 32760,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00000025,
|
||||||
|
"output_cost_per_token": 0.0000005,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
|
"gemini-1.0-pro-002": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 32760,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00000025,
|
||||||
|
"output_cost_per_token": 0.0000005,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
"gemini-1.5-pro": {
|
"gemini-1.5-pro": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 1000000,
|
"max_input_tokens": 1000000,
|
||||||
|
@ -970,6 +1014,17 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
|
"gemini-1.5-pro-preview-0409": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 1000000,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
"gemini-experimental": {
|
"gemini-experimental": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 1000000,
|
"max_input_tokens": 1000000,
|
||||||
|
@@ -2808,6 +2863,46 @@
         "output_cost_per_token": 0.000000,
         "litellm_provider": "voyage",
         "mode": "embedding"
+    },
+    "voyage/voyage-large-2": {
+        "max_tokens": 16000,
+        "max_input_tokens": 16000,
+        "input_cost_per_token": 0.00000012,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "voyage",
+        "mode": "embedding"
+    },
+    "voyage/voyage-law-2": {
+        "max_tokens": 16000,
+        "max_input_tokens": 16000,
+        "input_cost_per_token": 0.00000012,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "voyage",
+        "mode": "embedding"
+    },
+    "voyage/voyage-code-2": {
+        "max_tokens": 16000,
+        "max_input_tokens": 16000,
+        "input_cost_per_token": 0.00000012,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "voyage",
+        "mode": "embedding"
+    },
+    "voyage/voyage-2": {
+        "max_tokens": 4000,
+        "max_input_tokens": 4000,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "voyage",
+        "mode": "embedding"
+    },
+    "voyage/voyage-lite-02-instruct": {
+        "max_tokens": 4000,
+        "max_input_tokens": 4000,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000,
+        "litellm_provider": "voyage",
+        "mode": "embedding"
     }
 
 }
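The Gemini and Voyage entries added above feed LiteLLM's per-token cost tracking. As a quick illustration (plain arithmetic over the values added above, not the library's internal cost-tracking code), pricing a 16k-token call against the new voyage/voyage-large-2 entry:

# Pricing a call against the voyage/voyage-large-2 entry added above.
entry = {
    "input_cost_per_token": 0.00000012,
    "output_cost_per_token": 0.000000,
}

def request_cost(prompt_tokens: int, completion_tokens: int) -> float:
    return (
        prompt_tokens * entry["input_cost_per_token"]
        + completion_tokens * entry["output_cost_per_token"]
    )

print(f"${request_cost(prompt_tokens=16000, completion_tokens=0):.6f}")  # $0.001920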
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
(regenerated Next.js webpack runtime chunk for the proxy admin UI; the minified one-line file is unchanged except for the stylesheet it references: static/css/04eb0ce8764f86fe.css -> static/css/a282d1bfd6ed4df8.css)
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
(regenerated single-line static HTML for the LiteLLM Dashboard; same markup with new build asset references: webpack-68f14392aea51f63.js -> webpack-11b043d6a7ef78fa.js, css/04eb0ce8764f86fe.css -> css/a282d1bfd6ed4df8.css, chunks/253-8ab6133ad5f92675.js + app/page-a485c9c659128852.js -> chunks/823-2ada48e2e6a5ab39.js + app/page-e16bcf8bdc356530.js, buildId KnyD0lgLk9_a0erHwSSu- -> BNBzATtnAelV8BpmzRdfL)
@@ -1,7 +1,7 @@
 2:I[77831,[],""]
-3:I[46502,["253","static/chunks/253-8ab6133ad5f92675.js","931","static/chunks/app/page-a485c9c659128852.js"],""]
+3:I[29306,["823","static/chunks/823-2ada48e2e6a5ab39.js","931","static/chunks/app/page-e16bcf8bdc356530.js"],""]
 4:I[5613,[],""]
 5:I[31778,[],""]
-0:["KnyD0lgLk9_a0erHwSSu-", ...]  (single-line RSC flight payload; the visible changes are the build id and the stylesheet href /ui/_next/static/css/04eb0ce8764f86fe.css)
+0:["BNBzATtnAelV8BpmzRdfL", ...]  (same payload with the new build id and stylesheet /ui/_next/static/css/a282d1bfd6ed4df8.css)
 6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
 1:null
@@ -5,6 +5,7 @@ model_list:
       api_key: my-fake-key
       api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
       stream_timeout: 0.001
+      rpm: 10
   - litellm_params:
       model: azure/chatgpt-v-2
       api_base: os.environ/AZURE_API_BASE
@@ -12,30 +13,38 @@ model_list:
       api_version: "2023-07-01-preview"
       stream_timeout: 0.001
     model_name: azure-gpt-3.5
+  # - model_name: text-embedding-ada-002
+  #   litellm_params:
+  #     model: text-embedding-ada-002
+  #     api_key: os.environ/OPENAI_API_KEY
   - model_name: gpt-instruct
     litellm_params:
-      model: gpt-3.5-turbo-instruct
+      model: text-completion-openai/gpt-3.5-turbo-instruct
       # api_key: my-fake-key
       # api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
 litellm_settings:
   success_callback: ["prometheus"]
+  upperbound_key_generate_params:
+    max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
 
-# litellm_settings:
-#   drop_params: True
-#   max_budget: 800021
-#   budget_duration: 30d
-# #   cache: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: True
 
 general_settings:
   master_key: sk-1234
   allow_user_auth: true
   alerting: ["slack"]
-  store_model_in_db: True
+  # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
-  # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
+  proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
   enable_jwt_auth: True
   alerting: ["slack"]
   litellm_jwtauth:
     admin_jwt_scope: "litellm_proxy_admin"
-    public_key_ttl: 600
+    public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
+    user_id_jwt_field: "sub"
+    org_id_jwt_field: "azp"
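Several of the values added above (the redis credentials, the key-generation upper bound, the JWT public-key TTL) use the os.environ/<VAR> convention so the committed YAML carries no secrets. A minimal sketch of that convention, using a hypothetical helper rather than the proxy's actual config loader:

import os

def resolve_env_reference(value):
    # Values written as "os.environ/<VAR>" are read from the environment at startup;
    # anything else is passed through unchanged. Hypothetical helper, not the proxy's loader.
    prefix = "os.environ/"
    if isinstance(value, str) and value.startswith(prefix):
        return os.environ.get(value[len(prefix):])
    return value

print(resolve_env_reference("os.environ/REDIS_HOST"))   # -> contents of $REDIS_HOST
print(resolve_env_reference("usage-based-routing-v2"))  # -> unchanged literal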
@@ -38,6 +38,18 @@ class LiteLLMBase(BaseModel):
         protected_namespaces = ()


+class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
+    """
+    Set default upperbound to max budget a key called via `/key/generate` can be.
+    """
+
+    max_budget: Optional[float] = None
+    budget_duration: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None
+
+
 class LiteLLMRoutes(enum.Enum):
     openai_routes: List = [  # chat completions
         "/openai/deployments/{model}/chat/completions",
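A sketch of how the new LiteLLM_UpperboundKeyGenerateParams model could be used to reject an over-limit /key/generate request. The check below is illustrative only; the proxy's real enforcement lives in its key-management endpoints.

from typing import Optional
from litellm.proxy._types import LiteLLM_UpperboundKeyGenerateParams

upperbound = LiteLLM_UpperboundKeyGenerateParams(max_budget=100.0, rpm_limit=1000)

def check_requested_budget(requested_max_budget: Optional[float]) -> None:
    # Reject /key/generate requests that ask for more budget than the configured upper bound.
    if (
        upperbound.max_budget is not None
        and requested_max_budget is not None
        and requested_max_budget > upperbound.max_budget
    ):
        raise ValueError(
            f"max_budget={requested_max_budget} exceeds upperbound_key_generate_params.max_budget={upperbound.max_budget}"
        )

check_requested_budget(50.0)     # accepted
# check_requested_budget(500.0)  # would raise ValueError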
@@ -112,7 +124,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
    - team_jwt_scope: The JWT scope required for proxy team roles.
    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
    - team_allowed_routes: list of allowed routes for proxy team roles.
-    - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
+    - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
+    - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
    - public_key_ttl: Default - 600s. TTL for caching public JWT keys.

    See `auth_checks.py` for the specific routes
@@ -127,7 +140,9 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     team_allowed_routes: List[
         Literal["openai_routes", "info_routes", "management_routes"]
     ] = ["openai_routes", "info_routes"]
-    end_user_id_jwt_field: Optional[str] = "sub"
+    org_id_jwt_field: Optional[str] = None
+    user_id_jwt_field: Optional[str] = None
+    end_user_id_jwt_field: Optional[str] = None
     public_key_ttl: float = 600

     def __init__(self, **kwargs: Any) -> None:
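Taken together with the proxy config change earlier in this commit (user_id_jwt_field: "sub", org_id_jwt_field: "azp"), a decoded token maps onto the new fields roughly like this (illustrative claim values, not part of the commit):

# Example decoded JWT payload (illustrative values only).
token = {
    "sub": "user-1234",              # user_id_jwt_field -> LiteLLM_UserTable.user_id
    "azp": "org-5678",               # org_id_jwt_field  -> LiteLLM_OrganizationTable.organization_id
    "client_id": "team-91011",       # team_id_jwt_field (default) -> LiteLLM_TeamTable.team_id
    "scope": "litellm_proxy_admin",  # admin_jwt_scope -> proxy-admin routes
}

user_id = token.get("sub")  # what JWTHandler.get_user_id pulls out when user_id_jwt_field="sub"
org_id = token.get("azp")   # what JWTHandler.get_org_id pulls out when org_id_jwt_field="azp"
print(user_id, org_id)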
@@ -363,6 +378,8 @@ class NewUserRequest(GenerateKeyRequest):
     max_budget: Optional[float] = None
     user_email: Optional[str] = None
     user_role: Optional[str] = None
+    teams: Optional[list] = None
+    organization_id: Optional[str] = None
     auto_create_key: bool = (
         True  # flag used for returning a key as part of the /user/new response
     )
@@ -498,6 +515,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):


 class NewOrganizationRequest(LiteLLM_BudgetTable):
+    organization_id: Optional[str] = None
     organization_alias: str
     models: List = []
     budget_id: Optional[str] = None
@@ -506,6 +524,7 @@ class NewOrganizationRequest(LiteLLM_BudgetTable):
 class LiteLLM_OrganizationTable(LiteLLMBase):
     """Represents user-controllable params for a LiteLLM_OrganizationTable record"""

+    organization_id: Optional[str] = None
     organization_alias: Optional[str] = None
     budget_id: str
     metadata: Optional[dict] = None
@@ -690,6 +709,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
     soft_budget_cooldown: bool = False
     litellm_budget_table: Optional[dict] = None

+    org_id: Optional[str] = None  # org id for a given key
+
     # hidden params used for parallel request limiting, not required to create a token
     user_id_rate_limits: Optional[dict] = None
     team_id_rate_limits: Optional[dict] = None
@@ -14,6 +14,7 @@ from litellm.proxy._types import (
     LiteLLM_JWTAuth,
     LiteLLM_TeamTable,
     LiteLLMRoutes,
+    LiteLLM_OrganizationTable,
 )
 from typing import Optional, Literal, Union
 from litellm.proxy.utils import PrismaClient
@@ -26,6 +27,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
 def common_checks(
     request_body: dict,
     team_object: LiteLLM_TeamTable,
+    user_object: Optional[LiteLLM_UserTable],
     end_user_object: Optional[LiteLLM_EndUserTable],
     global_proxy_spend: Optional[float],
     general_settings: dict,
@@ -37,7 +39,8 @@ def common_checks(
    1. If team is blocked
    2. If team can call model
    3. If team is in budget
-    4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    5. If user passed in (JWT or key.user_id) - is in budget
+    4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
    6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    """
@@ -69,14 +72,20 @@ def common_checks(
         raise Exception(
             f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
         )
-    # 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    if user_object is not None and user_object.max_budget is not None:
+        user_budget = user_object.max_budget
+        if user_budget > user_object.spend:
+            raise Exception(
+                f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
+            )
+    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
     if end_user_object is not None and end_user_object.litellm_budget_table is not None:
         end_user_budget = end_user_object.litellm_budget_table.max_budget
         if end_user_budget is not None and end_user_object.spend > end_user_budget:
             raise Exception(
                 f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
             )
-    # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
     if (
         general_settings.get("enforce_user_param", None) is not None
         and general_settings["enforce_user_param"] == True
@@ -85,7 +94,7 @@ def common_checks(
         raise Exception(
             f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
         )
-    # 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
     if litellm.max_budget > 0 and global_proxy_spend is not None:
         if global_proxy_spend > litellm.max_budget:
             raise Exception(
@@ -204,19 +213,24 @@ async def get_end_user_object(
         return None


-async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
+async def get_user_object(
+    user_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> Optional[LiteLLM_UserTable]:
     """
     - Check if user id in proxy User Table
     - if valid, return LiteLLM_UserTable object with defined limits
     - if not, then raise an error
     """
-    if self.prisma_client is None:
-        raise Exception(
-            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
-        )
+    if prisma_client is None:
+        raise Exception("No db connected")
+    if user_id is None:
+        return None
+
     # check if in cache
-    cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
+    cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_UserTable(**cached_user_obj)
@@ -224,7 +238,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
             return cached_user_obj
     # else, check db
     try:
-        response = await self.prisma_client.db.litellm_usertable.find_unique(
+        response = await prisma_client.db.litellm_usertable.find_unique(
             where={"user_id": user_id}
         )

@@ -232,9 +246,9 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
             raise Exception

         return LiteLLM_UserTable(**response.dict())
-    except Exception as e:
+    except Exception as e:  # if end-user not in db
         raise Exception(
-            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
+            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
         )


@@ -274,3 +288,41 @@ async def get_team_object(
         raise Exception(
             f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
         )
+
+
+async def get_org_object(
+    org_id: str,
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+):
+    """
+    - Check if org id in proxy Org Table
+    - if valid, return LiteLLM_OrganizationTable object
+    - if not, then raise an error
+    """
+    if prisma_client is None:
+        raise Exception(
+            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+        )
+
+    # check if in cache
+    cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
+    if cached_org_obj is not None:
+        if isinstance(cached_org_obj, dict):
+            return cached_org_obj
+        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
+            return cached_org_obj
+    # else, check db
+    try:
+        response = await prisma_client.db.litellm_organizationtable.find_unique(
+            where={"organization_id": org_id}
+        )
+
+        if response is None:
+            raise Exception
+
+        return response
+    except Exception as e:
+        raise Exception(
+            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
+        )
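get_user_object above and the new get_org_object share the same cache-then-database lookup shape. A stripped-down, generic sketch of that pattern (hypothetical helper, not the proxy's code path):

from typing import Any, Awaitable, Callable, Optional

async def cached_lookup(
    key: str,
    cache_get: Callable[[str], Awaitable[Optional[Any]]],
    db_get: Callable[[str], Awaitable[Optional[Any]]],
) -> Any:
    # 1. consult the in-memory / redis cache first
    cached = await cache_get(key)
    if cached is not None:
        return cached
    # 2. fall back to the database; a missing row surfaces as an exception to the caller
    row = await db_get(key)
    if row is None:
        raise Exception(f"{key} doesn't exist in db.")
    return row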
@@ -74,6 +74,26 @@ class JWTHandler:
             team_id = default_value
         return team_id

+    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+        try:
+            if self.litellm_jwtauth.user_id_jwt_field is not None:
+                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
+            else:
+                user_id = None
+        except KeyError:
+            user_id = default_value
+        return user_id
+
+    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+        try:
+            if self.litellm_jwtauth.org_id_jwt_field is not None:
+                org_id = token[self.litellm_jwtauth.org_id_jwt_field]
+            else:
+                org_id = None
+        except KeyError:
+            org_id = default_value
+        return org_id
+
     def get_scopes(self, token: dict) -> list:
         try:
             if isinstance(token["scope"], str):
@@ -101,7 +121,11 @@ class JWTHandler:
         if cached_keys is None:
             response = await self.http_handler.get(keys_url)

-            keys = response.json()["keys"]
+            response_json = response.json()
+            if "keys" in response_json:
+                keys = response.json()["keys"]
+            else:
+                keys = response_json

             await self.user_api_key_cache.async_set_cache(
                 key="litellm_jwt_auth_keys",
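The change above accepts JWKS endpoints that return either the standard {"keys": [...]} wrapper or a bare list of keys. The same normalization in isolation (a sketch that uses httpx directly rather than the proxy's HTTP handler):

import httpx

async def fetch_jwks(keys_url: str) -> list:
    # Some identity providers serve {"keys": [...]}, others serve the key list directly.
    async with httpx.AsyncClient() as client:
        response = await client.get(keys_url)
    response_json = response.json()
    if isinstance(response_json, dict) and "keys" in response_json:
        return response_json["keys"]
    return response_json

# asyncio.run(fetch_jwks("https://example.com/.well-known/jwks.json"))  # hypothetical issuer URL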
@@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
            self.print_verbose(f"redis keys: {keys}")
            if len(keys) > 0:
                key_value_dict = (
-                    await litellm.cache.cache.async_get_cache_pipeline(
+                    await litellm.cache.cache.async_batch_get_cache(
                        key_list=keys
                    )
                )
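The rename above follows the caching API: the batch helper fetches many keys in one round trip instead of one pipeline call per key. A hedged usage sketch, assuming a locally reachable redis cache (host/port and key names are placeholders):

import asyncio
import litellm
from litellm.caching import Cache

async def main():
    litellm.cache = Cache(type="redis", host="localhost", port="6379")  # assumed local redis
    keys = ["request_count:team-a", "request_count:team-b"]  # illustrative key names
    key_value_dict = await litellm.cache.cache.async_batch_get_cache(key_list=keys)
    print(key_value_dict)  # mapping of key -> cached value (or None when absent)

asyncio.run(main())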
@@ -425,9 +425,10 @@ def run_server(
            )

            proxy_config = ProxyConfig()
-            _, _, general_settings = asyncio.run(
-                proxy_config.load_config(router=None, config_file_path=config)
-            )
+            _config = asyncio.run(proxy_config.get_config(config_file_path=config))
+            general_settings = _config.get("general_settings", {})
+            if general_settings is None:
+                general_settings = {}
            database_url = general_settings.get("database_url", None)
            db_connection_pool_limit = general_settings.get(
                "database_connection_pool_limit", 100
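The CLI change above assumes ProxyConfig.get_config() hands back the parsed YAML as a dict, from which general_settings is read defensively. A rough stand-in showing the shape of that flow (hypothetical loader, not the proxy's implementation):

import yaml

def get_config(config_file_path: str) -> dict:
    # Hypothetical stand-in for ProxyConfig.get_config(), only to show the dict shape relied on above.
    with open(config_file_path) as f:
        return yaml.safe_load(f) or {}

_config = get_config("proxy_server_config.yaml")
general_settings = _config.get("general_settings", {})
if general_settings is None:  # an explicit `general_settings:` key with no children parses to None
    general_settings = {}
database_url = general_settings.get("database_url", None)
db_connection_pool_limit = general_settings.get("database_connection_pool_limit", 100)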
@@ -1,49 +1,9 @@
 model_list:
-  - model_name: gpt-3.5-turbo
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
-  - model_name: gpt-3.5-turbo-large
-    litellm_params:
-      model: "gpt-3.5-turbo-1106"
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: gpt-4
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
-  - model_name: sagemaker-completion-model
-    litellm_params:
-      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
-      input_cost_per_second: 0.000420
-  - model_name: text-embedding-ada-002
-    litellm_params:
-      model: azure/azure-embedding-model
-      api_key: os.environ/AZURE_API_KEY
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-    model_info:
-      mode: embedding
-      base_model: text-embedding-ada-002
-  - model_name: dall-e-2
-    litellm_params:
-      model: azure/
-      api_version: 2023-06-01-preview
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_key: os.environ/AZURE_API_KEY
-  - model_name: openai-dall-e-3
-    litellm_params:
-      model: dall-e-3
   - model_name: fake-openai-endpoint
     litellm_params:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
-litellm_settings:
-  success_callback: ["prometheus"]
 general_settings:
+  store_model_in_db: true
   master_key: sk-1234
File diff suppressed because it is too large
@@ -53,6 +53,7 @@ model LiteLLM_OrganizationTable {
   updated_by String
   litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
   teams LiteLLM_TeamTable[]
+  users LiteLLM_UserTable[]
 }

 // Model info for teams, just has model aliases for now.
@@ -99,6 +100,7 @@ model LiteLLM_UserTable {
   user_id String @id
   user_alias String?
   team_id String?
+  organization_id String?
   teams String[] @default([])
   user_role String?
   max_budget Float?
@@ -113,6 +115,7 @@ model LiteLLM_UserTable {
   allowed_cache_controls String[] @default([])
   model_spend Json @default("{}")
   model_max_budget Json @default("{}")
+  litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
 }

 // Generate Tokens for Proxy
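With the relation added above, a user row can be read together with its organization. A hedged sketch against the generated Prisma Python client (the relation name comes from the schema above; include-style relation loading is assumed from the client's standard API):

from prisma import Prisma

async def get_user_with_org(user_id: str):
    db = Prisma()
    await db.connect()
    try:
        # `litellm_organization_table` is the relation field added to LiteLLM_UserTable above.
        return await db.litellm_usertable.find_unique(
            where={"user_id": user_id},
            include={"litellm_organization_table": True},
        )
    finally:
        await db.disconnect()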
litellm/proxy/tests/test_openai_embedding.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+import openai
+import asyncio
+
+
+async def async_request(client, model, input_data):
+    response = await client.embeddings.create(model=model, input=input_data)
+    response = response.dict()
+    data_list = response["data"]
+    for i, embedding in enumerate(data_list):
+        embedding["embedding"] = []
+        current_index = embedding["index"]
+        assert i == current_index
+    return response
+
+
+async def main():
+    client = openai.AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+    models = [
+        "text-embedding-ada-002",
+        "text-embedding-ada-002",
+        "text-embedding-ada-002",
+    ]
+    inputs = [
+        ["5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"],
+        ["1", "2", "3", "4", "5", "6"],
+        ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"],
+        ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"],
+        ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"],
+        ["1", "2", "3"],
+    ]
+
+    tasks = []
+    for model, input_data in zip(models, inputs):
+        task = async_request(client, model, input_data)
+        tasks.append(task)
+
+    responses = await asyncio.gather(*tasks)
+    print(responses)
+    for response in responses:
+        data_list = response["data"]
+        for embedding in data_list:
+            embedding["embedding"] = []
+        print(response)
+
+
+asyncio.run(main())
litellm/proxy/tests/test_openai_simple_embedding.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+import openai
+
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+# # request sent to model set on litellm proxy, `litellm --model`
+response = client.embeddings.create(
+    model="text-embedding-ada-002", input=["test"], encoding_format="base64"
+)
+
+print(response)
@@ -461,7 +461,12 @@ class ProxyLogging:
        """
        ### ALERTING ###
        if isinstance(original_exception, HTTPException):
-            error_message = original_exception.detail
+            if isinstance(original_exception.detail, str):
+                error_message = original_exception.detail
+            elif isinstance(original_exception.detail, dict):
+                error_message = json.dumps(original_exception.detail)
+            else:
+                error_message = str(original_exception)
        else:
            error_message = str(original_exception)
        if isinstance(traceback_str, str):
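The alerting fix above copes with HTTPException.detail being either a string or a dict (the dict form is raised by the new foreign-key handling further down this file). The same branching in isolation:

import json
from fastapi import HTTPException

def alert_message(exc: Exception) -> str:
    # Mirror of the branching added above: stringify whatever shape the exception carries.
    if isinstance(exc, HTTPException):
        if isinstance(exc.detail, str):
            return exc.detail
        if isinstance(exc.detail, dict):
            return json.dumps(exc.detail)
    return str(exc)

print(alert_message(HTTPException(status_code=400, detail={"error": "Foreign Key Constraint failed."})))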
@@ -562,6 +567,7 @@ class PrismaClient:
     end_user_list_transactons: dict = {}
     key_list_transactons: dict = {}
     team_list_transactons: dict = {}
+    org_list_transactons: dict = {}
     spend_log_transactions: List = []

     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
@@ -1159,13 +1165,26 @@ class PrismaClient:
                return new_verification_token
            elif table_name == "user":
                db_data = self.jsonify_object(data=data)
-                new_user_row = await self.db.litellm_usertable.upsert(
-                    where={"user_id": data["user_id"]},
-                    data={
-                        "create": {**db_data},  # type: ignore
-                        "update": {},  # don't do anything if it already exists
-                    },
-                )
+                try:
+                    new_user_row = await self.db.litellm_usertable.upsert(
+                        where={"user_id": data["user_id"]},
+                        data={
+                            "create": {**db_data},  # type: ignore
+                            "update": {},  # don't do anything if it already exists
+                        },
+                    )
+                except Exception as e:
+                    if (
+                        "Foreign key constraint failed on the field: `LiteLLM_UserTable_organization_id_fkey (index)`"
+                        in str(e)
+                    ):
+                        raise HTTPException(
+                            status_code=400,
+                            detail={
+                                "error": f"Foreign Key Constraint failed. Organization ID={db_data['organization_id']} does not exist in LiteLLM_OrganizationTable. Create via `/organization/new`."
+                            },
+                        )
+                    raise e
                verbose_proxy_logger.info("Data Inserted into User Table")
                return new_user_row
            elif table_name == "team":
@@ -2132,6 +2151,46 @@ async def update_spend(
                )
                raise e

+    ### UPDATE ORG TABLE ###
+    if len(prisma_client.org_list_transactons.keys()) > 0:
+        for i in range(n_retry_times + 1):
+            try:
+                async with prisma_client.db.tx(
+                    timeout=timedelta(seconds=60)
+                ) as transaction:
+                    async with transaction.batch_() as batcher:
+                        for (
+                            org_id,
+                            response_cost,
+                        ) in prisma_client.org_list_transactons.items():
+                            batcher.litellm_organizationtable.update_many(  # 'update_many' prevents error from being raised if no row exists
+                                where={"organization_id": org_id},
+                                data={"spend": {"increment": response_cost}},
+                            )
+                prisma_client.org_list_transactons = (
+                    {}
+                )  # Clear the remaining transactions after processing all batches in the loop.
+                break
+            except httpx.ReadTimeout:
+                if i >= n_retry_times:  # If we've reached the maximum number of retries
+                    raise  # Re-raise the last exception
+                # Optionally, sleep for a bit before retrying
+                await asyncio.sleep(2**i)  # Exponential backoff
+            except Exception as e:
+                import traceback
+
+                error_msg = (
+                    f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
+                )
+                print_verbose(error_msg)
+                error_traceback = error_msg + "\n" + traceback.format_exc()
+                asyncio.create_task(
+                    proxy_logging_obj.failure_handler(
+                        original_exception=e, traceback_str=error_traceback
+                    )
+                )
+                raise e
+
     ### UPDATE SPEND LOGS ###
     verbose_proxy_logger.debug(
         "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
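The new ### UPDATE ORG TABLE ### block follows the existing write-batching pattern: per-request costs accumulate in an in-memory dict keyed by organization id and are flushed in one transaction on the proxy_batch_write_at timer. A toy, database-free sketch of that accumulate-then-flush shape:

import asyncio

org_list_transactons: dict = {}  # mirrors the PrismaClient attribute name used above

def record_org_spend(org_id: str, response_cost: float) -> None:
    # called once per request: a cheap in-memory increment
    org_list_transactons[org_id] = org_list_transactons.get(org_id, 0.0) + response_cost

async def flush_org_spend(write_one) -> None:
    # called on the batch-write timer: one batched write per organization
    for org_id, response_cost in org_list_transactons.items():
        write_one(org_id, {"spend": {"increment": response_cost}})
    org_list_transactons.clear()

record_org_spend("org-1", 0.002)
record_org_spend("org-1", 0.003)
asyncio.run(flush_org_spend(lambda org_id, data: print(org_id, data)))  # org-1, increment ~0.005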
@@ -11,9 +11,9 @@ import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
 import random, threading, time, traceback, uuid
-import litellm, openai
+import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache
+import datetime as datetime_og
 import logging, asyncio
 import inspect, concurrent
 from openai import AsyncOpenAI
@@ -21,15 +21,16 @@ from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
 )
-from litellm.utils import ModelResponse, CustomStreamWrapper
+from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors


 class Router:
@ -77,6 +78,7 @@ class Router:
|
||||||
"latency-based-routing",
|
"latency-based-routing",
|
||||||
] = "simple-shuffle",
|
] = "simple-shuffle",
|
||||||
routing_strategy_args: dict = {}, # just for latency-based routing
|
routing_strategy_args: dict = {}, # just for latency-based routing
|
||||||
|
semaphore: Optional[asyncio.Semaphore] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
|
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
|
||||||
|
@ -142,6 +144,8 @@ class Router:
|
||||||
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
|
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if semaphore:
|
||||||
|
self.semaphore = semaphore
|
||||||
self.set_verbose = set_verbose
|
self.set_verbose = set_verbose
|
||||||
self.debug_level = debug_level
|
self.debug_level = debug_level
|
||||||
self.enable_pre_call_checks = enable_pre_call_checks
|
self.enable_pre_call_checks = enable_pre_call_checks
|
||||||
|
@ -273,6 +277,12 @@ class Router:
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
||||||
|
elif routing_strategy == "usage-based-routing-v2":
|
||||||
|
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
|
||||||
|
router_cache=self.cache, model_list=self.model_list
|
||||||
|
)
|
||||||
|
if isinstance(litellm.callbacks, list):
|
||||||
|
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
|
||||||
elif routing_strategy == "latency-based-routing":
|
elif routing_strategy == "latency-based-routing":
|
||||||
self.lowestlatency_logger = LowestLatencyLoggingHandler(
|
self.lowestlatency_logger = LowestLatencyLoggingHandler(
|
||||||
router_cache=self.cache,
|
router_cache=self.cache,
|
||||||
|
@ -402,12 +412,19 @@ class Router:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
|
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
|
||||||
|
"""
|
||||||
|
- Get an available deployment
|
||||||
|
- call it with a semaphore over the call
|
||||||
|
- semaphore specific to it's rpm
|
||||||
|
- in the semaphore, make a check against it's local rpm before running
|
||||||
|
"""
|
||||||
model_name = None
|
model_name = None
|
||||||
try:
|
try:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
|
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
|
||||||
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -436,6 +453,7 @@ class Router:
|
||||||
potential_model_client = self._get_client(
|
potential_model_client = self._get_client(
|
||||||
deployment=deployment, kwargs=kwargs, client_type="async"
|
deployment=deployment, kwargs=kwargs, client_type="async"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if provided keys == client keys #
|
# check if provided keys == client keys #
|
||||||
dynamic_api_key = kwargs.get("api_key", None)
|
dynamic_api_key = kwargs.get("api_key", None)
|
||||||
if (
|
if (
|
||||||
|
@ -458,7 +476,7 @@ class Router:
|
||||||
) # this uses default_litellm_params when nothing is set
|
) # this uses default_litellm_params when nothing is set
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await litellm.acompletion(
|
_response = litellm.acompletion(
|
||||||
**{
|
**{
|
||||||
**data,
|
**data,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
@ -468,6 +486,25 @@ class Router:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rpm_semaphore = self._get_client(
|
||||||
|
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
rpm_semaphore is not None
|
||||||
|
and isinstance(rpm_semaphore, asyncio.Semaphore)
|
||||||
|
and self.routing_strategy == "usage-based-routing-v2"
|
||||||
|
):
|
||||||
|
async with rpm_semaphore:
|
||||||
|
"""
|
||||||
|
- Check rpm limits before making the call
|
||||||
|
"""
|
||||||
|
await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
|
||||||
|
response = await _response
|
||||||
|
else:
|
||||||
|
response = await _response
|
||||||
|
|
||||||
self.success_calls[model_name] += 1
|
self.success_calls[model_name] += 1
|
||||||
verbose_router_logger.info(
|
verbose_router_logger.info(
|
||||||
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
|
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
|
||||||
|
@ -581,7 +618,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": "prompt"}],
|
messages=[{"role": "user", "content": "prompt"}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -681,7 +718,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
|
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": "prompt"}],
|
messages=[{"role": "user", "content": "prompt"}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -761,7 +798,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
|
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -904,7 +941,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -1070,7 +1107,7 @@ class Router:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
|
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
|
||||||
)
|
)
|
||||||
deployment = self.get_available_deployment(
|
deployment = await self.async_get_available_deployment(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
@ -1258,6 +1295,8 @@ class Router:
|
||||||
min_timeout=self.retry_after,
|
min_timeout=self.retry_after,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(timeout)
|
await asyncio.sleep(timeout)
|
||||||
|
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
|
||||||
|
raise e # don't wait to retry if deployment hits user-defined rate-limit
|
||||||
elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
||||||
status_code=original_exception.status_code
|
status_code=original_exception.status_code
|
||||||
):
|
):
|
||||||
|
@ -1598,7 +1637,8 @@ class Router:
|
||||||
if deployment is None:
|
if deployment is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
# get current fails for deployment
|
# get current fails for deployment
|
||||||
# update the number of failed calls
|
# update the number of failed calls
|
||||||
# if it's > allowed fails
|
# if it's > allowed fails
|
||||||
|
@ -1636,11 +1676,29 @@ class Router:
|
||||||
key=deployment, value=updated_fails, ttl=cooldown_time
|
key=deployment, value=updated_fails, ttl=cooldown_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_get_cooldown_deployments(self):
|
||||||
|
"""
|
||||||
|
Async implementation of '_get_cooldown_deployments'
|
||||||
|
"""
|
||||||
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
|
# get the current cooldown list for that minute
|
||||||
|
cooldown_key = f"{current_minute}:cooldown_models"
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Return cooldown models
|
||||||
|
# ----------------------
|
||||||
|
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
|
||||||
|
|
||||||
|
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
|
||||||
|
return cooldown_models
|
||||||
|
|
||||||
def _get_cooldown_deployments(self):
|
def _get_cooldown_deployments(self):
|
||||||
"""
|
"""
|
||||||
Get the list of models being cooled down for this minute
|
Get the list of models being cooled down for this minute
|
||||||
"""
|
"""
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
# get the current cooldown list for that minute
|
# get the current cooldown list for that minute
|
||||||
cooldown_key = f"{current_minute}:cooldown_models"
|
cooldown_key = f"{current_minute}:cooldown_models"
|
||||||
|
|
||||||
|
@ -1654,12 +1712,26 @@ class Router:
|
||||||
|
|
||||||
def set_client(self, model: dict):
|
def set_client(self, model: dict):
|
||||||
"""
|
"""
|
||||||
Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
|
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
|
||||||
|
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
|
||||||
"""
|
"""
|
||||||
client_ttl = self.client_ttl
|
client_ttl = self.client_ttl
|
||||||
litellm_params = model.get("litellm_params", {})
|
litellm_params = model.get("litellm_params", {})
|
||||||
model_name = litellm_params.get("model")
|
model_name = litellm_params.get("model")
|
||||||
model_id = model["model_info"]["id"]
|
model_id = model["model_info"]["id"]
|
||||||
|
# ### IF RPM SET - initialize a semaphore ###
|
||||||
|
rpm = litellm_params.get("rpm", None)
|
||||||
|
if rpm:
|
||||||
|
semaphore = asyncio.Semaphore(rpm)
|
||||||
|
cache_key = f"{model_id}_rpm_client"
|
||||||
|
self.cache.set_cache(
|
||||||
|
key=cache_key,
|
||||||
|
value=semaphore,
|
||||||
|
local_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# print("STORES SEMAPHORE IN CACHE")
|
||||||
|
|
||||||
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
||||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||||
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
|
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
|
||||||
|
@ -1874,8 +1946,12 @@ class Router:
|
||||||
local_only=True,
|
local_only=True,
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
else:
|
else:
|
||||||
|
_api_key = api_key
|
||||||
|
if _api_key is not None and isinstance(_api_key, str):
|
||||||
|
# only show first 5 chars of api_key
|
||||||
|
_api_key = _api_key[:8] + "*" * 15
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
|
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
|
||||||
)
|
)
|
||||||
azure_client_params = {
|
azure_client_params = {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
@ -1972,8 +2048,12 @@ class Router:
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
_api_key = api_key
|
||||||
|
if _api_key is not None and isinstance(_api_key, str):
|
||||||
|
# only show first 5 chars of api_key
|
||||||
|
_api_key = _api_key[:8] + "*" * 15
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}"
|
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
|
||||||
)
|
)
|
||||||
cache_key = f"{model_id}_async_client"
|
cache_key = f"{model_id}_async_client"
|
||||||
_client = openai.AsyncOpenAI( # type: ignore
|
_client = openai.AsyncOpenAI( # type: ignore
|
||||||
|
@ -2065,6 +2145,34 @@ class Router:
|
||||||
local_only=True,
|
local_only=True,
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
|
|
||||||
|
def _generate_model_id(self, model_group: str, litellm_params: dict):
|
||||||
|
"""
|
||||||
|
Helper function to consistently generate the same id for a deployment
|
||||||
|
|
||||||
|
- create a string from all the litellm params
|
||||||
|
- hash
|
||||||
|
- use hash as id
|
||||||
|
"""
|
||||||
|
concat_str = model_group
|
||||||
|
for k, v in litellm_params.items():
|
||||||
|
if isinstance(k, str):
|
||||||
|
concat_str += k
|
||||||
|
elif isinstance(k, dict):
|
||||||
|
concat_str += json.dumps(k)
|
||||||
|
else:
|
||||||
|
concat_str += str(k)
|
||||||
|
|
||||||
|
if isinstance(v, str):
|
||||||
|
concat_str += v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
concat_str += json.dumps(v)
|
||||||
|
else:
|
||||||
|
concat_str += str(v)
|
||||||
|
|
||||||
|
hash_object = hashlib.sha256(concat_str.encode())
|
||||||
|
|
||||||
|
return hash_object.hexdigest()
|
||||||
|
|
||||||
def set_model_list(self, model_list: list):
|
def set_model_list(self, model_list: list):
|
||||||
original_model_list = copy.deepcopy(model_list)
|
original_model_list = copy.deepcopy(model_list)
|
||||||
self.model_list = []
|
self.model_list = []
|
||||||
|
@ -2080,7 +2188,13 @@ class Router:
|
||||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||||
_litellm_params[k] = litellm.get_secret(v)
|
_litellm_params[k] = litellm.get_secret(v)
|
||||||
|
|
||||||
_model_info = model.pop("model_info", {})
|
_model_info: dict = model.pop("model_info", {})
|
||||||
|
|
||||||
|
# check if model info has id
|
||||||
|
if "id" not in _model_info:
|
||||||
|
_id = self._generate_model_id(_model_name, _litellm_params)
|
||||||
|
_model_info["id"] = _id
|
||||||
|
|
||||||
deployment = Deployment(
|
deployment = Deployment(
|
||||||
**model,
|
**model,
|
||||||
model_name=_model_name,
|
model_name=_model_name,
|
||||||
|
@ -2207,7 +2321,11 @@ class Router:
|
||||||
The appropriate client based on the given client_type and kwargs.
|
The appropriate client based on the given client_type and kwargs.
|
||||||
"""
|
"""
|
||||||
model_id = deployment["model_info"]["id"]
|
model_id = deployment["model_info"]["id"]
|
||||||
if client_type == "async":
|
if client_type == "rpm_client":
|
||||||
|
cache_key = "{}_rpm_client".format(model_id)
|
||||||
|
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||||
|
return client
|
||||||
|
elif client_type == "async":
|
||||||
if kwargs.get("stream") == True:
|
if kwargs.get("stream") == True:
|
||||||
cache_key = f"{model_id}_stream_async_client"
|
cache_key = f"{model_id}_stream_async_client"
|
||||||
client = self.cache.get_cache(key=cache_key, local_only=True)
|
client = self.cache.get_cache(key=cache_key, local_only=True)
|
||||||
|
@ -2260,6 +2378,7 @@ class Router:
|
||||||
Filter out model in model group, if:
|
Filter out model in model group, if:
|
||||||
|
|
||||||
- model context window < message length
|
- model context window < message length
|
||||||
|
- filter models above rpm limits
|
||||||
- [TODO] function call and model doesn't support function calling
|
- [TODO] function call and model doesn't support function calling
|
||||||
"""
|
"""
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
|
@ -2279,11 +2398,12 @@ class Router:
|
||||||
_rate_limit_error = False
|
_rate_limit_error = False
|
||||||
|
|
||||||
## get model group RPM ##
|
## get model group RPM ##
|
||||||
current_minute = datetime.now().strftime("%H-%M")
|
dt = get_utc_datetime()
|
||||||
|
current_minute = dt.strftime("%H-%M")
|
||||||
rpm_key = f"{model}:rpm:{current_minute}"
|
rpm_key = f"{model}:rpm:{current_minute}"
|
||||||
model_group_cache = (
|
model_group_cache = (
|
||||||
self.cache.get_cache(key=rpm_key, local_only=True) or {}
|
self.cache.get_cache(key=rpm_key, local_only=True) or {}
|
||||||
) # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
|
) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
|
||||||
for idx, deployment in enumerate(_returned_deployments):
|
for idx, deployment in enumerate(_returned_deployments):
|
||||||
# see if we have the info for this model
|
# see if we have the info for this model
|
||||||
try:
|
try:
|
||||||
|
@ -2296,20 +2416,20 @@ class Router:
|
||||||
"model", None
|
"model", None
|
||||||
)
|
)
|
||||||
model_info = litellm.get_model_info(model=model)
|
model_info = litellm.get_model_info(model=model)
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(model_info, dict)
|
|
||||||
and model_info.get("max_input_tokens", None) is not None
|
|
||||||
):
|
|
||||||
if (
|
if (
|
||||||
isinstance(model_info["max_input_tokens"], int)
|
isinstance(model_info, dict)
|
||||||
and input_tokens > model_info["max_input_tokens"]
|
and model_info.get("max_input_tokens", None) is not None
|
||||||
):
|
):
|
||||||
invalid_model_indices.append(idx)
|
if (
|
||||||
_context_window_error = True
|
isinstance(model_info["max_input_tokens"], int)
|
||||||
continue
|
and input_tokens > model_info["max_input_tokens"]
|
||||||
|
):
|
||||||
|
invalid_model_indices.append(idx)
|
||||||
|
_context_window_error = True
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
|
||||||
|
|
||||||
## RPM CHECK ##
|
## RPM CHECK ##
|
||||||
_litellm_params = deployment.get("litellm_params", {})
|
_litellm_params = deployment.get("litellm_params", {})
|
||||||
|
@ -2319,23 +2439,24 @@ class Router:
|
||||||
self.cache.get_cache(key=model_id, local_only=True) or 0
|
self.cache.get_cache(key=model_id, local_only=True) or 0
|
||||||
)
|
)
|
||||||
### get usage based cache ###
|
### get usage based cache ###
|
||||||
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
|
if isinstance(model_group_cache, dict):
|
||||||
|
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
|
||||||
|
|
||||||
current_request = max(
|
current_request = max(
|
||||||
current_request_cache_local, model_group_cache[model_id]
|
current_request_cache_local, model_group_cache[model_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(_litellm_params, dict)
|
|
||||||
and _litellm_params.get("rpm", None) is not None
|
|
||||||
):
|
|
||||||
if (
|
if (
|
||||||
isinstance(_litellm_params["rpm"], int)
|
isinstance(_litellm_params, dict)
|
||||||
and _litellm_params["rpm"] <= current_request
|
and _litellm_params.get("rpm", None) is not None
|
||||||
):
|
):
|
||||||
invalid_model_indices.append(idx)
|
if (
|
||||||
_rate_limit_error = True
|
isinstance(_litellm_params["rpm"], int)
|
||||||
continue
|
and _litellm_params["rpm"] <= current_request
|
||||||
|
):
|
||||||
|
invalid_model_indices.append(idx)
|
||||||
|
_rate_limit_error = True
|
||||||
|
continue
|
||||||
|
|
||||||
if len(invalid_model_indices) == len(_returned_deployments):
|
if len(invalid_model_indices) == len(_returned_deployments):
|
||||||
"""
|
"""
|
||||||
|
@ -2364,7 +2485,7 @@ class Router:
|
||||||
|
|
||||||
return _returned_deployments
|
return _returned_deployments
|
||||||
|
|
||||||
def get_available_deployment(
|
def _common_checks_available_deployment(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Optional[List[Dict[str, str]]] = None,
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
@ -2372,11 +2493,11 @@ class Router:
|
||||||
specific_deployment: Optional[bool] = False,
|
specific_deployment: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Returns the deployment based on routing strategy
|
Common checks for 'get_available_deployment' across sync + async call.
|
||||||
"""
|
|
||||||
|
|
||||||
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
|
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
|
||||||
# When this was no explicit we had several issues with fallbacks timing out
|
"""
|
||||||
|
# check if aliases set on litellm model alias map
|
||||||
if specific_deployment == True:
|
if specific_deployment == True:
|
||||||
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
||||||
for deployment in self.model_list:
|
for deployment in self.model_list:
|
||||||
|
@ -2384,12 +2505,11 @@ class Router:
|
||||||
if deployment_model == model:
|
if deployment_model == model:
|
||||||
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
||||||
# return the first deployment where the `model` matches the specificed deployment name
|
# return the first deployment where the `model` matches the specificed deployment name
|
||||||
return deployment
|
return deployment, None
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if aliases set on litellm model alias map
|
|
||||||
if model in self.model_group_alias:
|
if model in self.model_group_alias:
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
|
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
|
||||||
|
@ -2401,7 +2521,7 @@ class Router:
|
||||||
self.default_deployment
|
self.default_deployment
|
||||||
) # self.default_deployment
|
) # self.default_deployment
|
||||||
updated_deployment["litellm_params"]["model"] = model
|
updated_deployment["litellm_params"]["model"] = model
|
||||||
return updated_deployment
|
return updated_deployment, None
|
||||||
|
|
||||||
## get healthy deployments
|
## get healthy deployments
|
||||||
### get all deployments
|
### get all deployments
|
||||||
|
@ -2416,6 +2536,118 @@ class Router:
|
||||||
f"initial list of deployments: {healthy_deployments}"
|
f"initial list of deployments: {healthy_deployments}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
|
||||||
|
)
|
||||||
|
if len(healthy_deployments) == 0:
|
||||||
|
raise ValueError(f"No healthy deployment available, passed model={model}")
|
||||||
|
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[
|
||||||
|
model
|
||||||
|
] # update the model to the actual value if an alias has been passed in
|
||||||
|
|
||||||
|
return model, healthy_deployments
|
||||||
|
|
||||||
|
async def async_get_available_deployment(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
specific_deployment: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Async implementation of 'get_available_deployments'.
|
||||||
|
|
||||||
|
Allows all cache calls to be made async => 10x perf impact (8rps -> 100 rps).
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.routing_strategy != "usage-based-routing-v2"
|
||||||
|
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
|
||||||
|
return self.get_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
model, healthy_deployments = self._common_checks_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if healthy_deployments is None:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# filter out the deployments currently cooling down
|
||||||
|
deployments_to_remove = []
|
||||||
|
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
||||||
|
cooldown_deployments = await self._async_get_cooldown_deployments()
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"async cooldown deployments: {cooldown_deployments}"
|
||||||
|
)
|
||||||
|
# Find deployments in model_list whose model_id is cooling down
|
||||||
|
for deployment in healthy_deployments:
|
||||||
|
deployment_id = deployment["model_info"]["id"]
|
||||||
|
if deployment_id in cooldown_deployments:
|
||||||
|
deployments_to_remove.append(deployment)
|
||||||
|
# remove unhealthy deployments from healthy deployments
|
||||||
|
for deployment in deployments_to_remove:
|
||||||
|
healthy_deployments.remove(deployment)
|
||||||
|
|
||||||
|
# filter pre-call checks
|
||||||
|
if self.enable_pre_call_checks and messages is not None:
|
||||||
|
healthy_deployments = self._pre_call_checks(
|
||||||
|
model=model, healthy_deployments=healthy_deployments, messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.routing_strategy == "usage-based-routing-v2"
|
||||||
|
and self.lowesttpm_logger_v2 is not None
|
||||||
|
):
|
||||||
|
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
|
||||||
|
model_group=model,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
if deployment is None:
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"get_available_deployment for model: {model}, No deployment available"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"No deployments available for selected model, passed model={model}"
|
||||||
|
)
|
||||||
|
verbose_router_logger.info(
|
||||||
|
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||||
|
)
|
||||||
|
return deployment
|
||||||
|
|
||||||
|
def get_available_deployment(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None,
|
||||||
|
specific_deployment: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns the deployment based on routing strategy
|
||||||
|
"""
|
||||||
|
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
|
||||||
|
# When this was no explicit we had several issues with fallbacks timing out
|
||||||
|
|
||||||
|
model, healthy_deployments = self._common_checks_available_deployment(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=specific_deployment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if healthy_deployments is None:
|
||||||
|
return model
|
||||||
|
|
||||||
# filter out the deployments currently cooling down
|
# filter out the deployments currently cooling down
|
||||||
deployments_to_remove = []
|
deployments_to_remove = []
|
||||||
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
||||||
|
@ -2436,16 +2668,6 @@ class Router:
|
||||||
model=model, healthy_deployments=healthy_deployments, messages=messages
|
model=model, healthy_deployments=healthy_deployments, messages=messages
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_router_logger.debug(
|
|
||||||
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
|
|
||||||
)
|
|
||||||
if len(healthy_deployments) == 0:
|
|
||||||
raise ValueError(f"No healthy deployment available, passed model={model}")
|
|
||||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
|
||||||
model = litellm.model_alias_map[
|
|
||||||
model
|
|
||||||
] # update the model to the actual value if an alias has been passed in
|
|
||||||
|
|
||||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||||
deployment = self.leastbusy_logger.get_available_deployments(
|
deployment = self.leastbusy_logger.get_available_deployments(
|
||||||
model_group=model, healthy_deployments=healthy_deployments
|
model_group=model, healthy_deployments=healthy_deployments
|
||||||
|
@ -2507,7 +2729,16 @@ class Router:
|
||||||
messages=messages,
|
messages=messages,
|
||||||
input=input,
|
input=input,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
self.routing_strategy == "usage-based-routing-v2"
|
||||||
|
and self.lowesttpm_logger_v2 is not None
|
||||||
|
):
|
||||||
|
deployment = self.lowesttpm_logger_v2.get_available_deployments(
|
||||||
|
model_group=model,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
messages=messages,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
if deployment is None:
|
if deployment is None:
|
||||||
verbose_router_logger.info(
|
verbose_router_logger.info(
|
||||||
f"get_available_deployment for model: {model}, No deployment available"
|
f"get_available_deployment for model: {model}, No deployment available"
|
||||||
|
|
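Taken together, the router changes above wire the new `usage-based-routing-v2` strategy into deployment selection: `async_get_available_deployment` filters cooled-down deployments asynchronously, and a per-deployment RPM semaphore plus `pre_call_rpm_check` gate each call. A rough usage sketch, assuming the `Router(model_list=..., routing_strategy=...)` constructor shown in the docstring above; model names, keys, and endpoints are placeholders:

```python
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chat-gpt-v-2",
            "api_key": "my-azure-key",  # placeholder
            "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
            "rpm": 100,  # per-deployment limit -> rpm semaphore + pre-call rpm check
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-...", "rpm": 60},
    },
]

router = Router(model_list=model_list, routing_strategy="usage-based-routing-v2")


async def main():
    # the router picks the lowest-usage deployment that is still under its tpm/rpm limits
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)


asyncio.run(main())
```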
litellm/router_strategy/lowest_tpm_rpm_v2.py (new file, 403 lines)
@@ -0,0 +1,403 @@
#### What this does ####
# identifies lowest tpm deployment

import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
import datetime as datetime_og
from datetime import datetime

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback, asyncio, httpx
import litellm
from litellm import token_counter
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose, get_utc_datetime
from litellm.types.router import RouterErrors


class LowestTPMLoggingHandler_v2(CustomLogger):
    """
    Updated version of TPM/RPM Logging.

    Meant to work across instances.

    Caches individual models, not model_groups

    Uses batch get (redis.mget)

    Increments tpm/rpm limit using redis.incr
    """

    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

    def __init__(self, router_cache: DualCache, model_list: list):
        self.router_cache = router_cache
        self.model_list = model_list

    async def pre_call_rpm_check(self, deployment: dict) -> dict:
        """
        Pre-call check + update model rpm
        - Used inside semaphore
        - raise rate limit error if deployment over limit

        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:

            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_group = deployment.get("model_name", "")
            rpm_key = f"{model_group}:rpm:{current_minute}"
            local_result = await self.router_cache.async_get_cache(
                key=rpm_key, local_only=True
            )  # check local result first

            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")

            if local_result is not None and local_result >= deployment_rpm:
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                        ),
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks
                result = await self.router_cache.async_increment_cache(
                    key=rpm_key, value=1
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                    )
            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM/RPM usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                dt = get_utc_datetime()
                current_minute = dt.strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"

                # ------------
                # Update usage
                # ------------

                ## TPM
                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

                self.router_cache.set_cache(key=tpm_key, value=request_count_dict)

                ## RPM
                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + 1

                self.router_cache.set_cache(key=rpm_key, value=request_count_dict)

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            traceback.print_exc()
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                dt = get_utc_datetime()
                current_minute = dt.strftime(
                    "%H-%M"
                )  # use the same timezone regardless of system clock

                tpm_key = f"{id}:tpm:{current_minute}"
                # ------------
                # Update usage
                # ------------
                # update cache

                ## TPM
                await self.router_cache.async_increment_cache(
                    key=tpm_key, value=total_tokens
                )

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            traceback.print_exc()
            pass

    def _common_checks_available_deployment(
        self,
        model_group: str,
        healthy_deployments: list,
        tpm_keys: list,
        tpm_values: list,
        rpm_keys: list,
        rpm_values: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Common checks for get available deployment, across sync + async implementations
        """
        tpm_dict = {}  # {model_id: 1, ..}
        for idx, key in enumerate(tpm_keys):
            tpm_dict[tpm_keys[idx]] = tpm_values[idx]

        rpm_dict = {}  # {model_id: 1, ..}
        for idx, key in enumerate(rpm_keys):
            rpm_dict[rpm_keys[idx]] = rpm_values[idx]

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except:
            input_tokens = 0
        verbose_router_logger.debug(f"input_tokens={input_tokens}")
        # -----------------------
        # Find lowest used model
        # ----------------------
        lowest_tpm = float("inf")

        if tpm_dict is None:  # base case - none of the deployments have been used
            # initialize a tpm dict with {model_id: 0}
            tpm_dict = {}
            for deployment in healthy_deployments:
                tpm_dict[deployment["model_info"]["id"]] = 0
        else:
            for d in healthy_deployments:
                ## if healthy deployment not yet used
                if d["model_info"]["id"] not in tpm_dict:
                    tpm_dict[d["model_info"]["id"]] = 0

        all_deployments = tpm_dict

        deployment = None
        for item, item_tpm in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # skip to next one

            _deployment_tpm = None
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = float("inf")

            _deployment_rpm = None
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = float("inf")

            if item_tpm + input_tokens > _deployment_tpm:
                continue
            elif (rpm_dict is not None and item in rpm_dict) and (
                rpm_dict[item] + 1 > _deployment_rpm
            ):
                continue
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                deployment = _deployment
        print_verbose("returning picked lowest tpm/rpm deployment.")
        return deployment

    async def async_get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Async implementation of get deployments.

        Reduces time to retrieve the tpm/rpm values from cache
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )

        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        tpm_keys = []
        rpm_keys = []
        for m in healthy_deployments:
            if isinstance(m, dict):
                id = m.get("model_info", {}).get(
                    "id"
                )  # a deployment should always have an 'id'. this is set in router.py
                tpm_key = "{}:tpm:{}".format(id, current_minute)
                rpm_key = "{}:rpm:{}".format(id, current_minute)

                tpm_keys.append(tpm_key)
                rpm_keys.append(rpm_key)

        tpm_values = await self.router_cache.async_batch_get_cache(
            keys=tpm_keys
        )  # [1, 2, None, ..]
        rpm_values = await self.router_cache.async_batch_get_cache(
            keys=rpm_keys
        )  # [1, 2, None, ..]

        return self._common_checks_available_deployment(
            model_group=model_group,
            healthy_deployments=healthy_deployments,
            tpm_keys=tpm_keys,
            tpm_values=tpm_values,
            rpm_keys=rpm_keys,
            rpm_values=rpm_values,
            messages=messages,
            input=input,
        )

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Returns a deployment with the lowest TPM/RPM usage.
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )

        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        tpm_keys = []
        rpm_keys = []
        for m in healthy_deployments:
            if isinstance(m, dict):
                id = m.get("model_info", {}).get(
                    "id"
                )  # a deployment should always have an 'id'. this is set in router.py
                tpm_key = "{}:tpm:{}".format(id, current_minute)
                rpm_key = "{}:rpm:{}".format(id, current_minute)

                tpm_keys.append(tpm_key)
                rpm_keys.append(rpm_key)

        tpm_values = self.router_cache.batch_get_cache(
            keys=tpm_keys
        )  # [1, 2, None, ..]
        rpm_values = self.router_cache.batch_get_cache(
            keys=rpm_keys
        )  # [1, 2, None, ..]

        return self._common_checks_available_deployment(
            model_group=model_group,
            healthy_deployments=healthy_deployments,
            tpm_keys=tpm_keys,
            tpm_values=tpm_values,
            rpm_keys=rpm_keys,
            rpm_values=rpm_values,
            messages=messages,
            input=input,
        )
File diff suppressed because it is too large.
@@ -345,6 +345,83 @@ async def test_embedding_caching_azure_individual_items():
     assert embedding_val_2._hidden_params["cache_hit"] == True


+@pytest.mark.asyncio
+async def test_embedding_caching_azure_individual_items_reordered():
+    """
+    Tests caching for individual items in an embedding list
+
+    - Cache an item
+    - call aembedding(..) with the item + 1 unique item
+    - compare to a 2nd aembedding (...) with 2 unique items
+    ```
+    embedding_1 = ["hey how's it going", "I'm doing well"]
+    embedding_val_1 = embedding(...)
+
+    embedding_2 = ["hey how's it going", "I'm fine"]
+    embedding_val_2 = embedding(...)
+
+    assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
+    ```
+    """
+    litellm.cache = Cache()
+    common_msg = f"{uuid.uuid4()}"
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg_2, common_msg]
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+    ]
+
+    embedding_val_1 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_1, caching=True
+    )
+    embedding_val_2 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_2, caching=True
+    )
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True
+
+    assert embedding_val_2.data[0]["embedding"] == embedding_val_1.data[1]["embedding"]
+    assert embedding_val_2.data[0]["index"] != embedding_val_1.data[1]["index"]
+    assert embedding_val_2.data[0]["index"] == 0
+    assert embedding_val_1.data[1]["index"] == 1
+
+
+@pytest.mark.asyncio
+async def test_embedding_caching_base_64():
+    """ """
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+    )
+    import uuid
+
+    inputs = [
+        f"{uuid.uuid4()} hello this is ishaan",
+        f"{uuid.uuid4()} hello this is ishaan again",
+    ]
+
+    embedding_val_1 = await aembedding(
+        model="azure/azure-embedding-model",
+        input=inputs,
+        caching=True,
+        encoding_format="base64",
+    )
+    embedding_val_2 = await aembedding(
+        model="azure/azure-embedding-model",
+        input=inputs,
+        caching=True,
+        encoding_format="base64",
+    )
+
+    assert embedding_val_2._hidden_params["cache_hit"] == True
+    print(embedding_val_2)
+    print(embedding_val_1)
+    assert embedding_val_2.data[0]["embedding"] == embedding_val_1.data[0]["embedding"]
+    assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_basic():
     """
@@ -630,6 +707,39 @@ async def test_redis_cache_acompletion_stream():
 # test_redis_cache_acompletion_stream()


+@pytest.mark.asyncio
+async def test_redis_cache_atext_completion():
+    try:
+        litellm.set_verbose = True
+        prompt = f"write a one sentence poem about: {uuid.uuid4()}"
+        litellm.cache = Cache(
+            type="redis",
+            host=os.environ["REDIS_HOST"],
+            port=os.environ["REDIS_PORT"],
+            password=os.environ["REDIS_PASSWORD"],
+            supported_call_types=["atext_completion"],
+        )
+        print("test for caching, atext_completion")
+
+        response1 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        await asyncio.sleep(0.5)
+        print("\n\n Response 1 content: ", response1, "\n\n")
+
+        response2 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        print(response2)
+
+        assert response1.id == response2.id
+    except Exception as e:
+        print(f"{str(e)}\n\n{traceback.format_exc()}")
+        raise e
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_acompletion_stream_bedrock():
     import asyncio
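The new tests above exercise per-item embedding caching and `atext_completion` caching through `litellm.cache`. A minimal sketch of the same setup with the in-memory cache (the Redis-backed variant in the tests additionally needs `REDIS_HOST`/`REDIS_PORT`/`REDIS_PASSWORD`; the embedding model here is a placeholder and real credentials are required for the call to succeed):

```python
import asyncio

import litellm
from litellm import aembedding
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache; pass type="redis" + host/port/password for Redis


async def main():
    texts = ["hello world", "hello again"]
    first = await aembedding(model="text-embedding-ada-002", input=texts, caching=True)
    second = await aembedding(model="text-embedding-ada-002", input=texts, caching=True)
    # the second call should be served from cache
    print(second._hidden_params.get("cache_hit"))


asyncio.run(main())
```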
@@ -596,7 +596,7 @@ def test_completion_gpt4_vision():


 def test_completion_azure_gpt4_vision():
-    # azure/gpt-4, vision takes 5 seconds to respond
+    # azure/gpt-4, vision takes 5-seconds to respond
     try:
         litellm.set_verbose = True
         response = completion(
@@ -975,6 +975,19 @@ def test_completion_text_openai():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_completion_text_openai_async():
+    try:
+        # litellm.set_verbose =True
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo-instruct", messages=messages
+        )
+        print(response["choices"][0]["message"]["content"])
+    except Exception as e:
+        print(e)
+        pytest.fail(f"Error occurred: {e}")
+
+
 def custom_callback(
     kwargs,  # kwargs to completion
     completion_response,  # response from completion
@@ -1619,9 +1632,9 @@ def test_completion_replicate_vicuna():


 def test_replicate_custom_prompt_dict():
     litellm.set_verbose = True
-    model_name = "replicate/meta/llama-2-7b-chat"
+    model_name = "replicate/meta/llama-2-70b-chat"
     litellm.register_prompt_template(
-        model="replicate/meta/llama-2-7b-chat",
+        model="replicate/meta/llama-2-70b-chat",
         initial_prompt_value="You are a good assistant",  # [OPTIONAL]
         roles={
             "system": {
@@ -1639,16 +1652,24 @@ def test_replicate_custom_prompt_dict():
         },
         final_prompt_value="Now answer as best you can:",  # [OPTIONAL]
     )
-    response = completion(
-        model=model_name,
-        messages=[
-            {
-                "role": "user",
-                "content": "what is yc write 1 paragraph",
-            }
-        ],
-        num_retries=3,
-    )
+    try:
+        response = completion(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "what is yc write 1 paragraph",
+                }
+            ],
+            repetition_penalty=0.1,
+            num_retries=3,
+        )
+    except litellm.APIError as e:
+        pass
+    except litellm.APIConnectionError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
     print(f"response: {response}")
     litellm.custom_prompt_dict = {}  # reset
@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client):
|
||||||
assert team_result.team_tpm_limit == 100
|
assert team_result.team_tpm_limit == 100
|
||||||
assert team_result.team_rpm_limit == 99
|
assert team_result.team_rpm_limit == 99
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_token_output(prisma_client):
|
||||||
|
"""
|
||||||
|
- If user required, check if it exists
|
||||||
|
- fail initial request (when user doesn't exist)
|
||||||
|
- create user
|
||||||
|
- retry -> it should pass now
|
||||||
|
"""
|
||||||
|
import jwt, json
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
|
||||||
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
||||||
|
import litellm
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
# Generate a private / public key pair using RSA algorithm
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
)
|
||||||
|
# Get private key in PEM format
|
||||||
|
private_key = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public key in PEM format
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key_obj = serialization.load_pem_public_key(
|
||||||
|
public_key, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert RSA public key object to JWK (JSON Web Key)
|
||||||
|
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
|
||||||
|
|
||||||
|
assert isinstance(public_jwk, dict)
|
||||||
|
|
||||||
|
# set cache
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||||
|
|
||||||
|
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||||
|
|
||||||
|
# VALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
team_id = f"team123_{uuid.uuid4()}"
|
||||||
|
user_id = f"user123_{uuid.uuid4()}"
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_team",
|
||||||
|
"client_id": team_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
|
||||||
|
## team token
|
||||||
|
token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## admin token
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_proxy_admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## VERIFY IT WORKS
|
||||||
|
|
||||||
|
# verify token
|
||||||
|
|
||||||
|
response = await jwt_handler.auth_jwt(token=token)
|
||||||
|
|
||||||
|
## RUN IT THROUGH USER API KEY AUTH
|
||||||
|
|
||||||
|
"""
|
||||||
|
- 1. Initial call should fail -> team doesn't exist
|
||||||
|
- 2. Create team via admin token
|
||||||
|
- 3. 2nd call w/ same team -> call should fail -> user doesn't exist
|
||||||
|
- 4. Create user via admin token
|
||||||
|
- 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
|
- 6. assert user api key auth format
|
||||||
|
"""
|
||||||
|
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
## 1. INITIAL TEAM CALL - should fail
|
||||||
|
# use generated key to auth in
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
|
||||||
|
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||||
|
try:
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
pytest.fail("Team doesn't exist. This should fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=team_id,
|
||||||
|
tpm_limit=100,
|
||||||
|
rpm_limit=99,
|
||||||
|
models=["gpt-3.5-turbo", "gpt-4"],
|
||||||
|
),
|
||||||
|
user_api_key_dict=result,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
## 3. 2nd CALL W/ TEAM TOKEN - should fail
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
try:
|
||||||
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
|
request=request, api_key=bearer_token
|
||||||
|
)
|
||||||
|
pytest.fail(f"User doesn't exist. this should fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## 4. Create user
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_user(
|
||||||
|
data=NewUserRequest(
|
||||||
|
user_id=user_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
## 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
try:
|
||||||
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
|
request=request, api_key=bearer_token
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Team exists. This should not fail - {e}")
|
||||||
|
|
||||||
|
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||||
|
|
||||||
|
assert team_result.team_tpm_limit == 100
|
||||||
|
assert team_result.team_rpm_limit == 99
|
||||||
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
assert team_result.user_id == user_id
|
||||||
|
|
|
@@ -66,6 +66,7 @@ from litellm.proxy._types import (
     GenerateKeyRequest,
     NewTeamRequest,
     UserAPIKeyAuth,
+    LiteLLM_UpperboundKeyGenerateParams,
 )
 from litellm.proxy.utils import DBClient
 from starlette.datastructures import URL
@@ -1627,10 +1628,9 @@ async def test_upperbound_key_params(prisma_client):
     """
     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
-    litellm.upperbound_key_generate_params = {
-        "max_budget": 0.001,
-        "budget_duration": "1m",
-    }
+    litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
+        max_budget=0.001, budget_duration="1m"
+    )
     await litellm.proxy.proxy_server.prisma_client.connect()
     try:
         request = GenerateKeyRequest(
@@ -1638,18 +1638,9 @@ async def test_upperbound_key_params(prisma_client):
             budget_duration="30d",
         )
         key = await generate_key_fn(request)
-        generated_key = key.key
-
-        result = await info_key_fn(key=generated_key)
-        key_info = result["info"]
-        # assert it used the upper bound for max_budget, and budget_duration
-        assert key_info["max_budget"] == 0.001
-        assert key_info["budget_duration"] == "1m"
-
-        print(result)
+        # print(result)
     except Exception as e:
-        print("Got Exception", e)
-        pytest.fail(f"Got exception {e}")
+        assert e.code == 400


 def test_get_bearer_token():
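Note on the change above: the upper-bound settings move from a raw dict to the typed LiteLLM_UpperboundKeyGenerateParams model, and the test now expects a key request that exceeds the bound to be rejected with a 400 error rather than silently clamped to the bound. A minimal sketch of that reject-style check, assuming plain dicts for the requested and configured values (a hypothetical helper, not the proxy's actual implementation):

    def enforce_upperbound(requested: dict, upperbound: dict) -> None:
        # Reject any requested value that exceeds the configured upper bound;
        # a real server would surface this as an HTTP 400 response.
        for field in ("max_budget",):
            limit = upperbound.get(field)
            value = requested.get(field)
            if limit is not None and value is not None and value > limit:
                raise ValueError(f"{field}={value} exceeds allowed upper bound {limit}")

    # a key request asking for a budget far above the 0.001 bound is refused
    try:
        enforce_upperbound({"max_budget": 10}, {"max_budget": 0.001})
    except ValueError as err:
        print(err)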
@@ -1686,6 +1677,28 @@ def test_get_bearer_token():
     assert result == "sk-1234", f"Expected 'valid_token', got '{result}'"


+def test_update_logs_with_spend_logs_url(prisma_client):
+    """
+    Unit test for making sure spend logs list is still updated when url passed in
+    """
+    from litellm.proxy.proxy_server import _set_spend_logs_payload
+
+    payload = {"startTime": datetime.now(), "endTime": datetime.now()}
+    _set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
+
+    assert len(prisma_client.spend_log_transactions) > 0
+
+    prisma_client.spend_log_transactions = []
+
+    spend_logs_url = ""
+    payload = {"startTime": datetime.now(), "endTime": datetime.now()}
+    _set_spend_logs_payload(
+        payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
+    )
+
+    assert len(prisma_client.spend_log_transactions) > 0
+
+
 @pytest.mark.asyncio
 async def test_user_api_key_auth(prisma_client):
     from litellm.proxy.proxy_server import ProxyException
@@ -111,7 +111,10 @@ def test_llm_guard_key_specific_mode():
         api_key=_api_key,
     )

-    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    request_data = {}
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )

     assert should_proceed == False

@@ -120,6 +123,46 @@ def test_llm_guard_key_specific_mode():
         api_key=_api_key, permissions={"enable_llm_guard_check": True}
     )

-    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    request_data = {}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+
+    assert should_proceed == True
+
+
+def test_llm_guard_request_specific_mode():
+    """
+    Tests to see if llm guard 'request-specific' permissions work
+    """
+    litellm.llm_guard_mode = "request-specific"
+
+    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
+
+    _api_key = "sk-12345"
+    # NOT ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+    )
+
+    request_data = {}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+
+    assert should_proceed == False
+
+    # ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, permissions={"enable_llm_guard_check": True}
+    )
+
+    request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+
     assert should_proceed == True
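The new test above exercises a "request-specific" mode in which the LLM Guard check is switched on per request via metadata rather than per API key. A minimal sketch of that kind of gate, assuming the metadata shape used in the test (a hypothetical helper, not the enterprise hook's implementation):

    def request_enables_llm_guard(data: dict) -> bool:
        # In "request-specific" mode the caller opts in per request via metadata.
        return (
            data.get("metadata", {})
            .get("permissions", {})
            .get("enable_llm_guard_check", False)
        )

    assert request_enables_llm_guard({}) is False
    assert request_enables_llm_guard(
        {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
    ) is True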
@@ -398,6 +398,40 @@ async def test_async_router_context_window_fallback():
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")


+def test_router_rpm_pre_call_check():
+    """
+    - for a given model not in model cost map
+    - with rpm set
+    - check if rpm check is run
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "fake-openai-endpoint",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "openai/my-fake-model",
+                    "api_key": "my-fake-key",
+                    "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
+                    "rpm": 0,
+                },
+            },
+        ]
+
+        router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True, num_retries=0)  # type: ignore
+
+        try:
+            router._pre_call_checks(
+                model="fake-openai-endpoint",
+                healthy_deployments=model_list,
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+            pytest.fail("Expected this to fail")
+        except:
+            pass
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")
+
+
 def test_router_context_window_check_pre_call_check_in_group():
     """
     - Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
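The deployment above is registered with `rpm: 0`, so every request is already over its user-defined rate limit and `_pre_call_checks` is expected to reject it. A rough sketch of that filtering idea, assuming a per-deployment usage counter is available (the `current_rpm` dict here is hypothetical; the router's real bookkeeping lives in its cache):

    def filter_by_rpm(healthy_deployments: list, current_rpm: dict) -> list:
        # Drop any deployment whose user-defined rpm limit is already reached.
        allowed = []
        for deployment in healthy_deployments:
            limit = deployment["litellm_params"].get("rpm")
            used = current_rpm.get(deployment["model_name"], 0)
            if limit is not None and used >= limit:
                continue  # over the user-defined ratelimit, skip it
            allowed.append(deployment)
        if not allowed:
            raise RuntimeError("Deployment over user-defined ratelimit.")
        return allowed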
@@ -932,6 +966,35 @@ def test_openai_completion_on_router():
 # test_openai_completion_on_router()


+def test_consistent_model_id():
+    """
+    - For a given model group + litellm params, assert the model id is always the same
+
+    Test on `_generate_model_id`
+
+    Test on `set_model_list`
+
+    Test on `_add_deployment`
+    """
+    model_group = "gpt-3.5-turbo"
+    litellm_params = {
+        "model": "openai/my-fake-model",
+        "api_key": "my-fake-key",
+        "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
+        "stream_timeout": 0.001,
+    }
+
+    id1 = Router()._generate_model_id(
+        model_group=model_group, litellm_params=litellm_params
+    )
+
+    id2 = Router()._generate_model_id(
+        model_group=model_group, litellm_params=litellm_params
+    )
+
+    assert id1 == id2
+
+
 def test_reading_keys_os_environ():
     import openai
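For `_generate_model_id` to satisfy this test it has to be a pure function of the model group and the litellm params, so two Router instances built from the same config agree on the deployment id. One deterministic scheme, shown only as a sketch (not necessarily how the router derives its ids), is hashing a canonical serialization:

    import hashlib
    import json

    def generate_model_id(model_group: str, litellm_params: dict) -> str:
        # Canonical serialization (sorted keys) makes the hash order-independent.
        payload = json.dumps({"model_group": model_group, **litellm_params}, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    params = {"model": "openai/my-fake-model", "api_key": "my-fake-key"}
    assert generate_model_id("gpt-3.5-turbo", params) == generate_model_id("gpt-3.5-turbo", params)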
@@ -831,22 +831,25 @@ def test_bedrock_claude_3_streaming():
         pytest.fail(f"Error occurred: {e}")


-def test_claude_3_streaming_finish_reason():
+@pytest.mark.asyncio
+async def test_claude_3_streaming_finish_reason():
     try:
         litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "Be helpful"},
             {"role": "user", "content": "What do you know?"},
         ]
-        response: ModelResponse = completion(  # type: ignore
+        response: ModelResponse = await litellm.acompletion(  # type: ignore
             model="claude-3-opus-20240229",
             messages=messages,
             stream=True,
+            max_tokens=10,
         )
         complete_response = ""
-        # Add any assertions here to check the response
+        # Add any assertions here to-check the response
         num_finish_reason = 0
-        for idx, chunk in enumerate(response):
+        async for chunk in response:
+            print(f"chunk: {chunk}")
             if isinstance(chunk, ModelResponse):
                 if chunk.choices[0].finish_reason is not None:
                     num_finish_reason += 1
@@ -2285,7 +2288,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
             elif chunk.choices[0].finish_reason is not None:  # last chunk
                 validate_final_streaming_function_calling_chunk(chunk=chunk)
             idx += 1
-        # raise Exception("it worked!")
+        # raise Exception("it worked! ")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -32,5 +32,5 @@ class CompletionRequest(BaseModel):
     model_list: Optional[List[str]] = None

     class Config:
-        # allow kwargs
         extra = "allow"
+        protected_namespaces = ()
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum


 class ModelConfig(BaseModel):
@@ -12,6 +12,9 @@ class ModelConfig(BaseModel):
     tpm: int
     rpm: int

+    class Config:
+        protected_namespaces = ()
+

 class RouterConfig(BaseModel):
     model_list: List[ModelConfig]
@@ -41,6 +44,9 @@ class RouterConfig(BaseModel):
         "latency-based-routing",
     ] = "simple-shuffle"

+    class Config:
+        protected_namespaces = ()
+

 class ModelInfo(BaseModel):
     id: Optional[
@@ -127,9 +133,11 @@ class Deployment(BaseModel):
     litellm_params: LiteLLM_Params
     model_info: ModelInfo

-    def __init__(self, model_info: Optional[ModelInfo] = None, **params):
+    def __init__(self, model_info: Optional[Union[ModelInfo, dict]] = None, **params):
         if model_info is None:
             model_info = ModelInfo()
+        elif isinstance(model_info, dict):
+            model_info = ModelInfo(**model_info)
         super().__init__(model_info=model_info, **params)

     def to_json(self, **kwargs):
@@ -141,6 +149,7 @@ class Deployment(BaseModel):

     class Config:
         extra = "allow"
+        protected_namespaces = ()

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -157,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
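With `model_info` typed as `Optional[Union[ModelInfo, dict]]`, a Deployment can be built straight from config parsed out of YAML/JSON, and the constructor coerces the raw dict into the typed model. The same pattern on a small standalone pydantic model, as a hedged sketch (the class and field names below are illustrative, not litellm's):

    from typing import Optional, Union
    from pydantic import BaseModel

    class Info(BaseModel):
        id: Optional[str] = None

    class Entry(BaseModel):
        info: Info

        def __init__(self, info: Optional[Union[Info, dict]] = None, **params):
            if info is None:
                info = Info()
            elif isinstance(info, dict):
                info = Info(**info)  # coerce raw config dicts into the typed model
            super().__init__(info=info, **params)

    print(Entry(info={"id": "abc"}).info.id)  # -> "abc"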
@@ -20,6 +20,7 @@ import datetime, time
 import tiktoken
 import uuid
 import aiohttp
+import textwrap
 import logging
 import asyncio, httpx, inspect
 from inspect import iscoroutine
@@ -236,6 +237,7 @@ class HiddenParams(OpenAIObject):

     class Config:
         extra = "allow"
+        protected_namespaces = ()

     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -605,7 +607,7 @@ class ModelResponse(OpenAIObject):


 class Embedding(OpenAIObject):
-    embedding: list = []
+    embedding: Union[list, str] = []
     index: int
     object: str
@ -1104,7 +1106,6 @@ class Logging:
|
||||||
curl_command = self.model_call_details
|
curl_command = self.model_call_details
|
||||||
|
|
||||||
# only print verbose if verbose logger is not set
|
# only print verbose if verbose logger is not set
|
||||||
|
|
||||||
if verbose_logger.level == 0:
|
if verbose_logger.level == 0:
|
||||||
# this means verbose logger was not switched on - user is in litellm.set_verbose=True
|
# this means verbose logger was not switched on - user is in litellm.set_verbose=True
|
||||||
print_verbose(f"\033[92m{curl_command}\033[0m\n")
|
print_verbose(f"\033[92m{curl_command}\033[0m\n")
|
||||||
|
@ -1989,9 +1990,6 @@ class Logging:
|
||||||
else:
|
else:
|
||||||
litellm.cache.add_cache(result, **kwargs)
|
litellm.cache.add_cache(result, **kwargs)
|
||||||
if isinstance(callback, CustomLogger): # custom logger class
|
if isinstance(callback, CustomLogger): # custom logger class
|
||||||
print_verbose(
|
|
||||||
f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
|
|
||||||
)
|
|
||||||
if self.stream == True:
|
if self.stream == True:
|
||||||
if (
|
if (
|
||||||
"async_complete_streaming_response"
|
"async_complete_streaming_response"
|
||||||
|
@ -2375,7 +2373,6 @@ def client(original_function):
|
||||||
if litellm.use_client or (
|
if litellm.use_client or (
|
||||||
"use_client" in kwargs and kwargs["use_client"] == True
|
"use_client" in kwargs and kwargs["use_client"] == True
|
||||||
):
|
):
|
||||||
print_verbose(f"litedebugger initialized")
|
|
||||||
if "lite_debugger" not in litellm.input_callback:
|
if "lite_debugger" not in litellm.input_callback:
|
||||||
litellm.input_callback.append("lite_debugger")
|
litellm.input_callback.append("lite_debugger")
|
||||||
if "lite_debugger" not in litellm.success_callback:
|
if "lite_debugger" not in litellm.success_callback:
|
||||||
|
@ -2999,7 +2996,7 @@ def client(original_function):
|
||||||
)
|
)
|
||||||
): # allow users to control returning cached responses from the completion function
|
): # allow users to control returning cached responses from the completion function
|
||||||
# checking cache
|
# checking cache
|
||||||
print_verbose(f"INSIDE CHECKING CACHE")
|
print_verbose("INSIDE CHECKING CACHE")
|
||||||
if (
|
if (
|
||||||
litellm.cache is not None
|
litellm.cache is not None
|
||||||
and str(original_function.__name__)
|
and str(original_function.__name__)
|
||||||
|
@ -3106,6 +3103,22 @@ def client(original_function):
|
||||||
response_object=cached_result,
|
response_object=cached_result,
|
||||||
model_response_object=ModelResponse(),
|
model_response_object=ModelResponse(),
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
call_type == CallTypes.atext_completion.value
|
||||||
|
and isinstance(cached_result, dict)
|
||||||
|
):
|
||||||
|
if kwargs.get("stream", False) == True:
|
||||||
|
cached_result = convert_to_streaming_response_async(
|
||||||
|
response_object=cached_result,
|
||||||
|
)
|
||||||
|
cached_result = CustomStreamWrapper(
|
||||||
|
completion_stream=cached_result,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="cached_response",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cached_result = TextCompletionResponse(**cached_result)
|
||||||
elif call_type == CallTypes.aembedding.value and isinstance(
|
elif call_type == CallTypes.aembedding.value and isinstance(
|
||||||
cached_result, dict
|
cached_result, dict
|
||||||
):
|
):
|
||||||
|
@ -3174,7 +3187,13 @@ def client(original_function):
|
||||||
for val in non_null_list:
|
for val in non_null_list:
|
||||||
idx, cr = val # (idx, cr) tuple
|
idx, cr = val # (idx, cr) tuple
|
||||||
if cr is not None:
|
if cr is not None:
|
||||||
final_embedding_cached_response.data[idx] = cr
|
final_embedding_cached_response.data[idx] = (
|
||||||
|
Embedding(
|
||||||
|
embedding=cr["embedding"],
|
||||||
|
index=idx,
|
||||||
|
object="embedding",
|
||||||
|
)
|
||||||
|
)
|
||||||
if len(remaining_list) == 0:
|
if len(remaining_list) == 0:
|
||||||
# LOG SUCCESS
|
# LOG SUCCESS
|
||||||
cache_hit = True
|
cache_hit = True
|
||||||
|
@@ -4837,8 +4856,17 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stream:
             optional_params["stream"] = stream
+        if n is not None:
+            optional_params["candidate_count"] = n
+        if stop is not None:
+            if isinstance(stop, str):
+                optional_params["stop_sequences"] = [stop]
+            elif isinstance(stop, list):
+                optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
+        if response_format is not None and response_format["type"] == "json_object":
+            optional_params["response_mime_type"] = "application/json"
         if tools is not None and isinstance(tools, list):
             from vertexai.preview import generative_models
@@ -5525,6 +5553,9 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "stream",
             "tools",
             "tool_choice",
+            "response_format",
+            "n",
+            "stop",
         ]
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
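Taken together, these two hunks let OpenAI-style `n`, `stop`, and `response_format` flow through to Gemini, where they become `candidate_count`, `stop_sequences`, and `response_mime_type`. A hedged usage sketch (assuming the list above is the Vertex AI branch, which the neighboring `vertexai.preview` import suggests; the model name is illustrative, the call needs Vertex credentials, and whether a given Gemini version honors JSON mode depends on the backend):

    import litellm
    from litellm.utils import get_supported_openai_params

    # the provider branch now reports the newly mapped params
    params = get_supported_openai_params(model="gemini-pro", custom_llm_provider="vertex_ai")
    print("n" in params, "stop" in params, "response_format" in params)

    # and a completion call can pass them in OpenAI form
    response = litellm.completion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "Reply with a JSON object {\"ok\": true}"}],
        n=1,
        stop=["END"],
        response_format={"type": "json_object"},
    )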
@@ -5905,6 +5936,16 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     return api_key


+def get_utc_datetime():
+    import datetime as dt
+    from datetime import datetime
+
+    if hasattr(dt, "UTC"):
+        return datetime.now(dt.UTC)  # type: ignore
+    else:
+        return datetime.utcnow()  # type: ignore
+
+
 def get_max_tokens(model: str):
     """
     Get the maximum number of output tokens allowed for a given model.
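`datetime.UTC` only exists on Python 3.11+, and naive `datetime.utcnow()` is deprecated from Python 3.12, which is why the helper branches on `hasattr`. One portable alternative that needs no version check is `datetime.now(timezone.utc)`; note it always returns an aware datetime, unlike the `utcnow()` fallback (shown only as a sketch, not a proposed change):

    from datetime import datetime, timezone

    def get_utc_datetime_portable():
        # timezone.utc exists on every supported Python version,
        # so no hasattr() branch is needed; the result is timezone-aware.
        return datetime.now(timezone.utc)

    print(get_utc_datetime_portable().isoformat())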
@@ -6523,8 +6564,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
             for detail in additional_details:
                 slack_msg += f"{detail}: {additional_details[detail]}\n"
             slack_msg += f"Traceback: {traceback_exception}"
+            truncated_slack_msg = textwrap.shorten(slack_msg, width=512, placeholder="...")
             slack_app.client.chat_postMessage(
-                channel=alerts_channel, text=slack_msg
+                channel=alerts_channel, text=truncated_slack_msg
             )
         elif callback == "sentry":
             capture_exception(exception)
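`textwrap.shorten` caps the Slack alert at 512 characters, which keeps very long tracebacks from blowing up the message; be aware that it also collapses internal whitespace, so a multi-line traceback comes out as a single line. A quick illustration:

    import textwrap

    slack_msg = "Traceback (most recent call last):\n" + "  File 'x.py', line 1\n" * 100
    truncated = textwrap.shorten(slack_msg, width=512, placeholder="...")
    print(len(truncated) <= 512, "\n" in truncated)  # True False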
@@ -7741,7 +7783,7 @@ def exception_type(
                 )
             elif (
                 "429 Quota exceeded" in error_str
-                or "IndexError: list index out of range"
+                or "IndexError: list index out of range" in error_str
             ):
                 exception_mapping_worked = True
                 raise RateLimitError(
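The one-word fix above matters: in the old condition the second operand was a bare non-empty string literal, which is always truthy in Python, so the `elif` matched every error and mapped it to a RateLimitError. A quick illustration of the difference:

    error_str = "some unrelated provider error"

    old_check = bool("429 Quota exceeded" in error_str or "IndexError: list index out of range")
    new_check = (
        "429 Quota exceeded" in error_str
        or "IndexError: list index out of range" in error_str
    )

    assert old_check is True   # the bare string literal made the old check always fire
    assert new_check is False  # the fixed check only fires when a substring actually matches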
@@ -8764,7 +8806,9 @@ class CustomStreamWrapper:
         return hold, curr_chunk

     def handle_anthropic_chunk(self, chunk):
-        str_line = chunk.decode("utf-8")  # Convert bytes to string
+        str_line = chunk
+        if isinstance(chunk, bytes):  # Handle binary data
+            str_line = chunk.decode("utf-8")  # Convert bytes to string
         text = ""
         is_finished = False
         finish_reason = None
@@ -10024,6 +10068,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
                 or self.custom_llm_provider == "azure_text"
+                or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "huggingface"
                 or self.custom_llm_provider == "ollama"
                 or self.custom_llm_provider == "ollama_chat"
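Two related changes here: `handle_anthropic_chunk` now tolerates chunks that arrive either as raw bytes or as already-decoded strings, and "anthropic" joins the providers routed through the common chunk-handling path. The normalization step in isolation, as a standalone sketch:

    def normalize_chunk(chunk) -> str:
        # Accept either raw bytes off the wire or an already-decoded string.
        if isinstance(chunk, bytes):
            return chunk.decode("utf-8")
        return chunk

    assert normalize_chunk(b'data: {"type": "message_start"}') == 'data: {"type": "message_start"}'
    assert normalize_chunk("data: [DONE]") == "data: [DONE]"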
@ -66,6 +66,28 @@
|
||||||
"litellm_provider": "openai",
|
"litellm_provider": "openai",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
"gpt-4-turbo": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00001,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true
|
||||||
|
},
|
||||||
|
"gpt-4-turbo-2024-04-09": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00001,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true
|
||||||
|
},
|
||||||
"gpt-4-1106-preview": {
|
"gpt-4-1106-preview": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
@ -948,6 +970,28 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
|
"gemini-1.0-pro-001": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 32760,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00000025,
|
||||||
|
"output_cost_per_token": 0.0000005,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
|
"gemini-1.0-pro-002": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 32760,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00000025,
|
||||||
|
"output_cost_per_token": 0.0000005,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
"gemini-1.5-pro": {
|
"gemini-1.5-pro": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 1000000,
|
"max_input_tokens": 1000000,
|
||||||
|
@ -970,6 +1014,17 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
|
"gemini-1.5-pro-preview-0409": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"max_input_tokens": 1000000,
|
||||||
|
"max_output_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
|
"mode": "chat",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
|
},
|
||||||
"gemini-experimental": {
|
"gemini-experimental": {
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_input_tokens": 1000000,
|
"max_input_tokens": 1000000,
|
||||||
|
@ -2808,6 +2863,46 @@
|
||||||
"output_cost_per_token": 0.000000,
|
"output_cost_per_token": 0.000000,
|
||||||
"litellm_provider": "voyage",
|
"litellm_provider": "voyage",
|
||||||
"mode": "embedding"
|
"mode": "embedding"
|
||||||
|
},
|
||||||
|
"voyage/voyage-large-2": {
|
||||||
|
"max_tokens": 16000,
|
||||||
|
"max_input_tokens": 16000,
|
||||||
|
"input_cost_per_token": 0.00000012,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "voyage",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
|
"voyage/voyage-law-2": {
|
||||||
|
"max_tokens": 16000,
|
||||||
|
"max_input_tokens": 16000,
|
||||||
|
"input_cost_per_token": 0.00000012,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "voyage",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
|
"voyage/voyage-code-2": {
|
||||||
|
"max_tokens": 16000,
|
||||||
|
"max_input_tokens": 16000,
|
||||||
|
"input_cost_per_token": 0.00000012,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "voyage",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
|
"voyage/voyage-2": {
|
||||||
|
"max_tokens": 4000,
|
||||||
|
"max_input_tokens": 4000,
|
||||||
|
"input_cost_per_token": 0.0000001,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "voyage",
|
||||||
|
"mode": "embedding"
|
||||||
|
},
|
||||||
|
"voyage/voyage-lite-02-instruct": {
|
||||||
|
"max_tokens": 4000,
|
||||||
|
"max_input_tokens": 4000,
|
||||||
|
"input_cost_per_token": 0.0000001,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "voyage",
|
||||||
|
"mode": "embedding"
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@@ -48,7 +48,16 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: fake-openai-endpoint-2
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      stream_timeout: 0.001
+      rpm: 1
+  - model_name: gpt-instruct  # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
+    litellm_params:
+      model: text-completion-openai/gpt-3.5-turbo-instruct
 litellm_settings:
   drop_params: True
   # max_budget: 100
@@ -58,6 +67,13 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true
+
 general_settings:
   master_key: sk-1234  # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
   store_model_in_db: True
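The same routing setup can be expressed when constructing a `Router` in Python instead of through the proxy config. A hedged sketch (the keyword names mirror the YAML keys and the documented Router constructor, but verify them against your litellm version; the endpoint and key values below are the fake test values from the config above):

    import os
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "fake-openai-endpoint-2",
                "litellm_params": {
                    "model": "openai/my-fake-model",
                    "api_key": "my-fake-key",
                    "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
                    "stream_timeout": 0.001,
                    "rpm": 1,
                },
            }
        ],
        routing_strategy="usage-based-routing-v2",
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT", "6379")),
        enable_pre_call_checks=True,
    )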
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.34.33"
+version = "1.35.4"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.34.33"
+version = "1.35.4"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -14,9 +14,9 @@ pandas==2.1.1 # for viewing clickhouse spend analytics
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions
 pynacl==1.5.0 # for encrypting keys
-google-cloud-aiplatform==1.43.0 # for vertex ai calls
+google-cloud-aiplatform==1.47.0 # for vertex ai calls
 anthropic[vertex]==0.21.3
-google-generativeai==0.3.2 # for vertex ai calls
+google-generativeai==0.5.0 # for vertex ai calls
 async_generator==1.10.0 # for async ollama calls
 langfuse>=2.6.3 # for langfuse self-hosted logging
 datadog-api-client==2.23.0 # for datadog logging
@@ -53,6 +53,7 @@ model LiteLLM_OrganizationTable {
   updated_by String
   litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
   teams LiteLLM_TeamTable[]
+  users LiteLLM_UserTable[]
 }

 // Model info for teams, just has model aliases for now.
@@ -99,6 +100,7 @@ model LiteLLM_UserTable {
   user_id String @id
   user_alias String?
   team_id String?
+  organization_id String?
   teams String[] @default([])
   user_role String?
   max_budget Float?
@@ -113,6 +115,7 @@ model LiteLLM_UserTable {
   allowed_cache_controls String[] @default([])
   model_spend Json @default("{}")
   model_max_budget Json @default("{}")
+  litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
 }

 // Generate Tokens for Proxy
@@ -127,20 +127,6 @@ async def chat_completion(session, key):
         raise Exception(f"Request did not return a 200 status code: {status}")


-@pytest.mark.asyncio
-async def test_add_models():
-    """
-    Add model
-    Call new model
-    """
-    async with aiohttp.ClientSession() as session:
-        key_gen = await generate_key(session=session)
-        key = key_gen["key"]
-        await add_models(session=session)
-        await asyncio.sleep(60)
-        await chat_completion(session=session, key=key)
-
-
 @pytest.mark.asyncio
 async def test_get_models():
     """
@@ -178,14 +164,15 @@ async def delete_model(session, model_id="123"):


 @pytest.mark.asyncio
-async def test_delete_models():
+async def test_add_and_delete_models():
     """
-    Get models user has access to
+    Add model
+    Call new model
     """
-    model_id = "12345"
     async with aiohttp.ClientSession() as session:
         key_gen = await generate_key(session=session)
         key = key_gen["key"]
+        model_id = "1234"
         await add_models(session=session, model_id=model_id)
         await asyncio.sleep(60)
         await chat_completion(session=session, key=key)
@@ -18,7 +18,12 @@ async def generate_key(session):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
     data = {
-        "models": ["gpt-4", "text-embedding-ada-002", "dall-e-2"],
+        "models": [
+            "gpt-4",
+            "text-embedding-ada-002",
+            "dall-e-2",
+            "fake-openai-endpoint-2",
+        ],
         "duration": None,
     }

@@ -63,14 +68,14 @@ async def new_user(session):
     return await response.json()


-async def chat_completion(session, key):
+async def chat_completion(session, key, model="gpt-4"):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
     }
     data = {
-        "model": "gpt-4",
+        "model": model,
         "messages": [
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "Hello!"},
@@ -189,6 +194,31 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)


+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+@pytest.mark.asyncio
+async def test_chat_completion_ratelimit():
+    """
+    - call model with rpm 1
+    - make 2 parallel calls
+    - make sure 1 fails
+    """
+    async with aiohttp.ClientSession() as session:
+        # key_gen = await generate_key(session=session)
+        key = "sk-1234"
+        tasks = []
+        tasks.append(
+            chat_completion(session=session, key=key, model="fake-openai-endpoint-2")
+        )
+        tasks.append(
+            chat_completion(session=session, key=key, model="fake-openai-endpoint-2")
+        )
+        try:
+            await asyncio.gather(*tasks)
+            pytest.fail("Expected at least 1 call to fail")
+        except Exception as e:
+            pass
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_old_key():
     """
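The test above leans on `asyncio.gather` propagating the first exception from the two parallel calls. An alternative sketch that counts failures explicitly with `return_exceptions=True` (hypothetical, not what the test does), which makes "exactly one of the two calls was rate limited" easy to assert:

    import asyncio

    async def flaky(i: int):
        # Stand-in for a chat completion call that may be rate limited.
        if i == 1:
            raise RuntimeError("429: rate limited")
        return "ok"

    async def main():
        results = await asyncio.gather(*(flaky(i) for i in range(2)), return_exceptions=True)
        failures = [r for r in results if isinstance(r, Exception)]
        assert len(failures) == 1

    asyncio.run(main())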
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
||||||
|
!function(){"use strict";var e,t,n,r,o,u,i,c,f,a={},l={};function d(e){var t=l[e];if(void 0!==t)return t.exports;var n=l[e]={id:e,loaded:!1,exports:{}},r=!0;try{a[e](n,n.exports,d),r=!1}finally{r&&delete l[e]}return n.loaded=!0,n.exports}d.m=a,e=[],d.O=function(t,n,r,o){if(n){o=o||0;for(var u=e.length;u>0&&e[u-1][2]>o;u--)e[u]=e[u-1];e[u]=[n,r,o];return}for(var i=1/0,u=0;u<e.length;u++){for(var n=e[u][0],r=e[u][1],o=e[u][2],c=!0,f=0;f<n.length;f++)i>=o&&Object.keys(d.O).every(function(e){return d.O[e](n[f])})?n.splice(f--,1):(c=!1,o<i&&(i=o));if(c){e.splice(u--,1);var a=r();void 0!==a&&(t=a)}}return t},d.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return d.d(t,{a:t}),t},n=Object.getPrototypeOf?function(e){return Object.getPrototypeOf(e)}:function(e){return e.__proto__},d.t=function(e,r){if(1&r&&(e=this(e)),8&r||"object"==typeof e&&e&&(4&r&&e.__esModule||16&r&&"function"==typeof e.then))return e;var o=Object.create(null);d.r(o);var u={};t=t||[null,n({}),n([]),n(n)];for(var i=2&r&&e;"object"==typeof i&&!~t.indexOf(i);i=n(i))Object.getOwnPropertyNames(i).forEach(function(t){u[t]=function(){return e[t]}});return u.default=function(){return e},d.d(o,u),o},d.d=function(e,t){for(var n in t)d.o(t,n)&&!d.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},d.f={},d.e=function(e){return Promise.all(Object.keys(d.f).reduce(function(t,n){return d.f[n](e,t),t},[]))},d.u=function(e){},d.miniCssF=function(e){return"static/css/a282d1bfd6ed4df8.css"},d.g=function(){if("object"==typeof globalThis)return globalThis;try{return this||Function("return this")()}catch(e){if("object"==typeof window)return window}}(),d.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r={},o="_N_E:",d.l=function(e,t,n,u){if(r[e]){r[e].push(t);return}if(void 0!==n)for(var i,c,f=document.getElementsByTagName("script"),a=0;a<f.length;a++){var l=f[a];if(l.getAttribute("src")==e||l.getAttribute("data-webpack")==o+n){i=l;break}}i||(c=!0,(i=document.createElement("script")).charset="utf-8",i.timeout=120,d.nc&&i.setAttribute("nonce",d.nc),i.setAttribute("data-webpack",o+n),i.src=d.tu(e)),r[e]=[t];var s=function(t,n){i.onerror=i.onload=null,clearTimeout(p);var o=r[e];if(delete r[e],i.parentNode&&i.parentNode.removeChild(i),o&&o.forEach(function(e){return e(n)}),t)return t(n)},p=setTimeout(s.bind(null,void 0,{type:"timeout",target:i}),12e4);i.onerror=s.bind(null,i.onerror),i.onload=s.bind(null,i.onload),c&&document.head.appendChild(i)},d.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},d.nmd=function(e){return e.paths=[],e.children||(e.children=[]),e},d.tt=function(){return void 0===u&&(u={createScriptURL:function(e){return e}},"undefined"!=typeof trustedTypes&&trustedTypes.createPolicy&&(u=trustedTypes.createPolicy("nextjs#bundler",u))),u},d.tu=function(e){return d.tt().createScriptURL(e)},d.p="/ui/_next/",i={272:0},d.f.j=function(e,t){var n=d.o(i,e)?i[e]:void 0;if(0!==n){if(n)t.push(n[2]);else if(272!=e){var r=new Promise(function(t,r){n=i[e]=[t,r]});t.push(n[2]=r);var o=d.p+d.u(e),u=Error();d.l(o,function(t){if(d.o(i,e)&&(0!==(n=i[e])&&(i[e]=void 0),n)){var r=t&&("load"===t.type?"missing":t.type),o=t&&t.target&&t.target.src;u.message="Loading chunk "+e+" failed.\n("+r+": "+o+")",u.name="ChunkLoadError",u.type=r,u.request=o,n[1](u)}},"chunk-"+e,e)}else i[e]=0}},d.O.j=function(e){return 0===i[e]},c=function(e,t){var 
n,r,o=t[0],u=t[1],c=t[2],f=0;if(o.some(function(e){return 0!==i[e]})){for(n in u)d.o(u,n)&&(d.m[n]=u[n]);if(c)var a=c(d)}for(e&&e(t);f<o.length;f++)r=o[f],d.o(i,r)&&i[r]&&i[r][0](),i[r]=0;return d.O(a)},(f=self.webpackChunk_N_E=self.webpackChunk_N_E||[]).forEach(c.bind(null,0)),f.push=c.bind(null,f.push.bind(f))}();
|
File diff suppressed because one or more lines are too long
|
@ -1 +1 @@
|
||||||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-68f14392aea51f63.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-a507ee9e75a3be72.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-589b47e7a69d316f.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-096338c8e1915716.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-68f14392aea51f63.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/04eb0ce8764f86fe.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[46502,[\"253\",\"static/chunks/253-8ab6133ad5f92675.js\",\"931\",\"static/chunks/app/page-a485c9c659128852.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/04eb0ce8764f86fe.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"KnyD0lgLk9_a0erHwSSu-\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-11b043d6a7ef78fa.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-a507ee9e75a3be72.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-589b47e7a69d316f.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-096338c8e1915716.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-11b043d6a7ef78fa.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/a282d1bfd6ed4df8.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[29306,[\"823\",\"static/chunks/823-2ada48e2e6a5ab39.js\",\"931\",\"static/chunks/app/page-e16bcf8bdc356530.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/a282d1bfd6ed4df8.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"BNBzATtnAelV8BpmzRdfL\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
|
@ -1,7 +1,7 @@
|
||||||
2:I[77831,[],""]
|
2:I[77831,[],""]
|
||||||
3:I[46502,["253","static/chunks/253-8ab6133ad5f92675.js","931","static/chunks/app/page-a485c9c659128852.js"],""]
|
3:I[29306,["823","static/chunks/823-2ada48e2e6a5ab39.js","931","static/chunks/app/page-e16bcf8bdc356530.js"],""]
|
||||||
4:I[5613,[],""]
|
4:I[5613,[],""]
|
||||||
5:I[31778,[],""]
|
5:I[31778,[],""]
|
||||||
0:["KnyD0lgLk9_a0erHwSSu-",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/04eb0ce8764f86fe.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
0:["BNBzATtnAelV8BpmzRdfL",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/a282d1bfd6ed4df8.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||||
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
|
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
|
||||||
1:null
|
1:null
|
||||||
|
|
|
@ -7,6 +7,7 @@ import ModelDashboard from "@/components/model_dashboard";
|
||||||
import ViewUserDashboard from "@/components/view_users";
|
import ViewUserDashboard from "@/components/view_users";
|
||||||
import Teams from "@/components/teams";
|
import Teams from "@/components/teams";
|
||||||
import AdminPanel from "@/components/admins";
|
import AdminPanel from "@/components/admins";
|
||||||
|
import Settings from "@/components/settings";
|
||||||
import ChatUI from "@/components/chat_ui";
|
import ChatUI from "@/components/chat_ui";
|
||||||
import Sidebar from "../components/leftnav";
|
import Sidebar from "../components/leftnav";
|
||||||
import Usage from "../components/usage";
|
import Usage from "../components/usage";
|
||||||
|
@ -160,6 +161,13 @@ const CreateKeyPage = () => {
|
||||||
setTeams={setTeams}
|
setTeams={setTeams}
|
||||||
searchParams={searchParams}
|
searchParams={searchParams}
|
||||||
accessToken={accessToken}
|
accessToken={accessToken}
|
||||||
|
showSSOBanner={showSSOBanner}
|
||||||
|
/>
|
||||||
|
) : page == "settings" ? (
|
||||||
|
<Settings
|
||||||
|
userID={userID}
|
||||||
|
userRole={userRole}
|
||||||
|
accessToken={accessToken}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<Usage
|
<Usage
|
||||||
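Note on the page.tsx hunk above: the new "settings" branch passes userID, userRole and accessToken down to the Settings component. A minimal sketch of the props contract this implies; the interface name SettingsProps and the nullable types are assumptions, since the actual settings.tsx definition is not part of this diff.

// Hypothetical sketch only - the real settings.tsx is not shown in this diff.
// Props mirror exactly what page.tsx passes down.
interface SettingsProps {
  userID: string | null;
  userRole: string | null;
  accessToken: string | null;
}

// Usage as wired up in page.tsx:
// <Settings userID={userID} userRole={userRole} accessToken={accessToken} />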
|
|
|
@ -27,23 +27,27 @@ import {
|
||||||
Col,
|
Col,
|
||||||
Text,
|
Text,
|
||||||
Grid,
|
Grid,
|
||||||
|
Callout,
|
||||||
} from "@tremor/react";
|
} from "@tremor/react";
|
||||||
import { CogIcon } from "@heroicons/react/outline";
|
import { PencilAltIcon } from "@heroicons/react/outline";
|
||||||
interface AdminPanelProps {
|
interface AdminPanelProps {
|
||||||
searchParams: any;
|
searchParams: any;
|
||||||
accessToken: string | null;
|
accessToken: string | null;
|
||||||
setTeams: React.Dispatch<React.SetStateAction<Object[] | null>>;
|
setTeams: React.Dispatch<React.SetStateAction<Object[] | null>>;
|
||||||
|
showSSOBanner: boolean;
|
||||||
}
|
}
|
||||||
import {
|
import {
|
||||||
userUpdateUserCall,
|
userUpdateUserCall,
|
||||||
Member,
|
Member,
|
||||||
userGetAllUsersCall,
|
userGetAllUsersCall,
|
||||||
User,
|
User,
|
||||||
|
setCallbacksCall,
|
||||||
} from "./networking";
|
} from "./networking";
|
||||||
|
|
||||||
const AdminPanel: React.FC<AdminPanelProps> = ({
|
const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
searchParams,
|
searchParams,
|
||||||
accessToken,
|
accessToken,
|
||||||
|
showSSOBanner
|
||||||
}) => {
|
}) => {
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
const [memberForm] = Form.useForm();
|
const [memberForm] = Form.useForm();
|
||||||
|
@ -52,6 +56,47 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
const [admins, setAdmins] = useState<null | any[]>(null);
|
const [admins, setAdmins] = useState<null | any[]>(null);
|
||||||
|
|
||||||
const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false);
|
const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false);
|
||||||
|
const [isAddAdminModalVisible, setIsAddAdminModalVisible] = useState(false);
|
||||||
|
const [isUpdateMemberModalVisible, setIsUpdateModalModalVisible] = useState(false);
|
||||||
|
const [isAddSSOModalVisible, setIsAddSSOModalVisible] = useState(false);
|
||||||
|
const [isInstructionsModalVisible, setIsInstructionsModalVisible] = useState(false);
|
||||||
|
|
||||||
|
let nonSssoUrl;
|
||||||
|
try {
|
||||||
|
nonSssoUrl = window.location.origin;
|
||||||
|
} catch (error) {
|
||||||
|
nonSssoUrl = '<your-proxy-url>';
|
||||||
|
}
|
||||||
|
nonSssoUrl += '/fallback/login';
|
||||||
|
|
||||||
|
const handleAddSSOOk = () => {
|
||||||
|
|
||||||
|
setIsAddSSOModalVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleAddSSOCancel = () => {
|
||||||
|
setIsAddSSOModalVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleShowInstructions = (formValues: Record<string, any>) => {
|
||||||
|
handleAdminCreate(formValues);
|
||||||
|
handleSSOUpdate(formValues);
|
||||||
|
setIsAddSSOModalVisible(false);
|
||||||
|
setIsInstructionsModalVisible(true);
|
||||||
|
// Optionally, you can call handleSSOUpdate here with the formValues
// handleSSOUpdate is already called above with these formValues
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleInstructionsOk = () => {
|
||||||
|
setIsInstructionsModalVisible(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleInstructionsCancel = () => {
|
||||||
|
setIsInstructionsModalVisible(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const roles = ["proxy_admin", "proxy_admin_viewer"]
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Fetch model info and set the default selected model
|
// Fetch model info and set the default selected model
|
||||||
|
@ -94,26 +139,138 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
fetchProxyAdminInfo();
|
fetchProxyAdminInfo();
|
||||||
}, [accessToken]);
|
}, [accessToken]);
|
||||||
|
|
||||||
|
const handleMemberUpdateOk = () => {
|
||||||
|
setIsUpdateModalModalVisible(false);
|
||||||
|
memberForm.resetFields();
|
||||||
|
};
|
||||||
|
|
||||||
const handleMemberOk = () => {
|
const handleMemberOk = () => {
|
||||||
setIsAddMemberModalVisible(false);
|
setIsAddMemberModalVisible(false);
|
||||||
memberForm.resetFields();
|
memberForm.resetFields();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleAdminOk = () => {
|
||||||
|
setIsAddAdminModalVisible(false);
|
||||||
|
memberForm.resetFields();
|
||||||
|
};
|
||||||
|
|
||||||
const handleMemberCancel = () => {
|
const handleMemberCancel = () => {
|
||||||
setIsAddMemberModalVisible(false);
|
setIsAddMemberModalVisible(false);
|
||||||
memberForm.resetFields();
|
memberForm.resetFields();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleAdminCancel = () => {
|
||||||
|
setIsAddAdminModalVisible(false);
|
||||||
|
memberForm.resetFields();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleMemberUpdateCancel = () => {
|
||||||
|
setIsUpdateModalModalVisible(false);
|
||||||
|
memberForm.resetFields();
|
||||||
|
}
|
||||||
|
// Define the type for the handleMemberCreate function
|
||||||
|
type HandleMemberCreate = (formValues: Record<string, any>) => Promise<void>;
|
||||||
|
|
||||||
|
const addMemberForm = (handleMemberCreate: HandleMemberCreate,) => {
|
||||||
|
return <Form
|
||||||
|
form={form}
|
||||||
|
onFinish={handleMemberCreate}
|
||||||
|
labelCol={{ span: 8 }}
|
||||||
|
wrapperCol={{ span: 16 }}
|
||||||
|
labelAlign="left"
|
||||||
|
>
|
||||||
|
<>
|
||||||
|
<Form.Item label="Email" name="user_email" className="mb-4">
|
||||||
|
<Input
|
||||||
|
name="user_email"
|
||||||
|
className="px-3 py-2 border rounded-md w-full"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<div className="text-center mb-4">OR</div>
|
||||||
|
<Form.Item label="User ID" name="user_id" className="mb-4">
|
||||||
|
<Input
|
||||||
|
name="user_id"
|
||||||
|
className="px-3 py-2 border rounded-md w-full"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
|
<Button2 htmlType="submit">Add member</Button2>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
}
|
||||||
|
|
||||||
|
const modifyMemberForm = (handleMemberUpdate: HandleMemberCreate, currentRole: string, userID: string) => {
|
||||||
|
return <Form
|
||||||
|
form={form}
|
||||||
|
onFinish={handleMemberUpdate}
|
||||||
|
labelCol={{ span: 8 }}
|
||||||
|
wrapperCol={{ span: 16 }}
|
||||||
|
labelAlign="left"
|
||||||
|
>
|
||||||
|
<>
|
||||||
|
<Form.Item rules={[{ required: true, message: 'Required' }]} label="User Role" name="user_role" labelCol={{ span: 10 }} labelAlign="left">
|
||||||
|
<Select value={currentRole}>
|
||||||
|
{roles.map((role, index) => (
|
||||||
|
<SelectItem
|
||||||
|
key={index}
|
||||||
|
value={role}
|
||||||
|
>
|
||||||
|
{role}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
label="Team ID"
|
||||||
|
name="user_id"
|
||||||
|
hidden={true}
|
||||||
|
initialValue={userID}
|
||||||
|
valuePropName="user_id"
|
||||||
|
className="mt-8"
|
||||||
|
>
|
||||||
|
<Input value={userID} disabled />
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
|
<Button2 htmlType="submit">Update role</Button2>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleMemberUpdate = async (formValues: Record<string, any>) => {
|
||||||
|
try{
|
||||||
|
if (accessToken != null && admins != null) {
|
||||||
|
message.info("Making API Call");
|
||||||
|
const response: any = await userUpdateUserCall(accessToken, formValues, null);
|
||||||
|
console.log(`response for team create call: ${response}`);
|
||||||
|
// Checking if the team exists in the list and updating or adding accordingly
|
||||||
|
const foundIndex = admins.findIndex((user) => {
|
||||||
|
console.log(
|
||||||
|
`user.user_id=${user.user_id}; response.user_id=${response.user_id}`
|
||||||
|
);
|
||||||
|
return user.user_id === response.user_id;
|
||||||
|
});
|
||||||
|
console.log(`foundIndex: ${foundIndex}`);
|
||||||
|
if (foundIndex == -1) {
|
||||||
|
console.log(`updates admin with new user`);
|
||||||
|
admins.push(response);
|
||||||
|
// If new user is found, update it
|
||||||
|
setAdmins(admins); // Set the new state
|
||||||
|
}
|
||||||
|
message.success("Refresh tab to see updated user role")
|
||||||
|
setIsUpdateModalModalVisible(false);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error creating the key:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleMemberCreate = async (formValues: Record<string, any>) => {
|
const handleMemberCreate = async (formValues: Record<string, any>) => {
|
||||||
try {
|
try {
|
||||||
if (accessToken != null && admins != null) {
|
if (accessToken != null && admins != null) {
|
||||||
message.info("Making API Call");
|
message.info("Making API Call");
|
||||||
const user_role: Member = {
|
const response: any = await userUpdateUserCall(accessToken, formValues, "proxy_admin_viewer");
|
||||||
role: "user",
|
|
||||||
user_email: formValues.user_email,
|
|
||||||
user_id: formValues.user_id,
|
|
||||||
};
|
|
||||||
const response: any = await userUpdateUserCall(accessToken, formValues);
|
|
||||||
console.log(`response for team create call: ${response}`);
|
console.log(`response for team create call: ${response}`);
|
||||||
// Checking if the team exists in the list and updating or adding accordingly
|
// Checking if the team exists in the list and updating or adding accordingly
|
||||||
const foundIndex = admins.findIndex((user) => {
|
const foundIndex = admins.findIndex((user) => {
|
||||||
|
@ -135,18 +292,66 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
console.error("Error creating the key:", error);
|
console.error("Error creating the key:", error);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
const handleAdminCreate = async (formValues: Record<string, any>) => {
|
||||||
|
try {
|
||||||
|
if (accessToken != null && admins != null) {
|
||||||
|
message.info("Making API Call");
|
||||||
|
const user_role: Member = {
|
||||||
|
role: "user",
|
||||||
|
user_email: formValues.user_email,
|
||||||
|
user_id: formValues.user_id,
|
||||||
|
};
|
||||||
|
const response: any = await userUpdateUserCall(accessToken, formValues, "proxy_admin");
|
||||||
|
console.log(`response for team create call: ${response}`);
|
||||||
|
// Checking if the team exists in the list and updating or adding accordingly
|
||||||
|
const foundIndex = admins.findIndex((user) => {
|
||||||
|
console.log(
|
||||||
|
`user.user_id=${user.user_id}; response.user_id=${response.user_id}`
|
||||||
|
);
|
||||||
|
return user.user_id === response.user_id;
|
||||||
|
});
|
||||||
|
console.log(`foundIndex: ${foundIndex}`);
|
||||||
|
if (foundIndex == -1) {
|
||||||
|
console.log(`updates admin with new user`);
|
||||||
|
admins.push(response);
|
||||||
|
// If new user is found, update it
|
||||||
|
setAdmins(admins); // Set the new state
|
||||||
|
}
|
||||||
|
setIsAddAdminModalVisible(false);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error creating the key:", error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSSOUpdate = async (formValues: Record<string, any>) => {
|
||||||
|
if (accessToken == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let payload = {
|
||||||
|
environment_variables: {
|
||||||
|
PROXY_BASE_URL: formValues.proxy_base_url,
|
||||||
|
GOOGLE_CLIENT_ID: formValues.google_client_id,
|
||||||
|
GOOGLE_CLIENT_SECRET: formValues.google_client_secret,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
setCallbacksCall(accessToken, payload);
|
||||||
|
}
|
||||||
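For reference, the payload that handleSSOUpdate builds and hands to setCallbacksCall has the shape sketched below. This is a minimal sketch based only on the hunk above; the values are placeholders and the proxy-side handling of environment_variables is assumed, not shown here.

// Sketch of the SSO settings payload sent via setCallbacksCall(accessToken, payload).
// Field names come straight from handleSSOUpdate above; values are illustrative.
const ssoPayload = {
  environment_variables: {
    PROXY_BASE_URL: "https://litellm.my-company.com",           // formValues.proxy_base_url
    GOOGLE_CLIENT_ID: "1234567890.apps.googleusercontent.com",  // formValues.google_client_id
    GOOGLE_CLIENT_SECRET: "<client-secret>",                    // formValues.google_client_secret
  },
};
// setCallbacksCall(accessToken, ssoPayload);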
console.log(`admins: ${admins?.length}`);
|
console.log(`admins: ${admins?.length}`);
|
||||||
return (
|
return (
|
||||||
<div className="w-full m-2 mt-2 p-8">
|
<div className="w-full m-2 mt-2 p-8">
|
||||||
<Title level={4}>Restricted Access</Title>
|
<Title level={4}>Admin Access </Title>
|
||||||
<Paragraph>
|
<Paragraph>
|
||||||
Add other people to just view spend. They cannot create keys, teams or
|
{
|
||||||
|
showSSOBanner && <a href="https://docs.litellm.ai/docs/proxy/ui#restrict-ui-access">Requires SSO Setup</a>
|
||||||
|
}
|
||||||
|
<br/>
|
||||||
|
<b>Proxy Admin: </b> Can create keys, teams, users, add models, etc. <br/>
|
||||||
|
<b>Proxy Admin Viewer: </b>Can just view spend. They cannot create keys, teams or
|
||||||
grant users access to new models.{" "}
|
grant users access to new models.{" "}
|
||||||
<a href="https://docs.litellm.ai/docs/proxy/ui#restrict-ui-access">
|
|
||||||
Requires SSO Setup
|
|
||||||
</a>
|
|
||||||
</Paragraph>
|
</Paragraph>
|
||||||
<Grid numItems={1} className="gap-2 p-2 w-full">
|
<Grid numItems={1} className="gap-2 p-2 w-full">
|
||||||
|
|
||||||
<Col numColSpan={1}>
|
<Col numColSpan={1}>
|
||||||
<Card className="w-full mx-auto flex-auto overflow-y-auto max-h-[50vh]">
|
<Card className="w-full mx-auto flex-auto overflow-y-auto max-h-[50vh]">
|
||||||
<Table>
|
<Table>
|
||||||
|
@ -154,7 +359,6 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
<TableRow>
|
<TableRow>
|
||||||
<TableHeaderCell>Member Name</TableHeaderCell>
|
<TableHeaderCell>Member Name</TableHeaderCell>
|
||||||
<TableHeaderCell>Role</TableHeaderCell>
|
<TableHeaderCell>Role</TableHeaderCell>
|
||||||
{/* <TableHeaderCell>Action</TableHeaderCell> */}
|
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableHead>
|
</TableHead>
|
||||||
|
|
||||||
|
@ -170,9 +374,18 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
: null}
|
: null}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell>{member["user_role"]}</TableCell>
|
<TableCell>{member["user_role"]}</TableCell>
|
||||||
{/* <TableCell>
|
<TableCell>
|
||||||
<Icon icon={CogIcon} size="sm" />
|
<Icon icon={PencilAltIcon} size="sm" onClick={() => setIsUpdateModalModalVisible(true)}/>
|
||||||
</TableCell> */}
|
<Modal
|
||||||
|
title="Update role"
|
||||||
|
visible={isUpdateMemberModalVisible}
|
||||||
|
width={800}
|
||||||
|
footer={null}
|
||||||
|
onOk={handleMemberUpdateOk}
|
||||||
|
onCancel={handleMemberUpdateCancel}>
|
||||||
|
{modifyMemberForm(handleMemberUpdate, member["user_role"], member["user_id"])}
|
||||||
|
</Modal>
|
||||||
|
</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
))
|
))
|
||||||
: null}
|
: null}
|
||||||
|
@ -181,11 +394,27 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
</Card>
|
</Card>
|
||||||
</Col>
|
</Col>
|
||||||
<Col numColSpan={1}>
|
<Col numColSpan={1}>
|
||||||
|
<div className="flex justify-start">
|
||||||
<Button
|
<Button
|
||||||
className="mx-auto mb-5"
|
className="mr-4 mb-5"
|
||||||
onClick={() => setIsAddMemberModalVisible(true)}
|
onClick={() => setIsAddAdminModalVisible(true)}
|
||||||
>
|
>
|
||||||
+ Add viewer
|
+ Add admin
|
||||||
|
</Button>
|
||||||
|
<Modal
|
||||||
|
title="Add admin"
|
||||||
|
visible={isAddAdminModalVisible}
|
||||||
|
width={800}
|
||||||
|
footer={null}
|
||||||
|
onOk={handleAdminOk}
|
||||||
|
onCancel={handleAdminCancel}>
|
||||||
|
{addMemberForm(handleAdminCreate)}
|
||||||
|
</Modal>
|
||||||
|
<Button
|
||||||
|
className="mb-5"
|
||||||
|
onClick={() => setIsAddMemberModalVisible(true)}
|
||||||
|
>
|
||||||
|
+ Add viewer
|
||||||
</Button>
|
</Button>
|
||||||
<Modal
|
<Modal
|
||||||
title="Add viewer"
|
title="Add viewer"
|
||||||
|
@ -195,35 +424,99 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||||
onOk={handleMemberOk}
|
onOk={handleMemberOk}
|
||||||
onCancel={handleMemberCancel}
|
onCancel={handleMemberCancel}
|
||||||
>
|
>
|
||||||
<Form
|
{addMemberForm(handleMemberCreate)}
|
||||||
form={form}
|
|
||||||
onFinish={handleMemberCreate}
|
|
||||||
labelCol={{ span: 8 }}
|
|
||||||
wrapperCol={{ span: 16 }}
|
|
||||||
labelAlign="left"
|
|
||||||
>
|
|
||||||
<>
|
|
||||||
<Form.Item label="Email" name="user_email" className="mb-4">
|
|
||||||
<Input
|
|
||||||
name="user_email"
|
|
||||||
className="px-3 py-2 border rounded-md w-full"
|
|
||||||
/>
|
|
||||||
</Form.Item>
|
|
||||||
<div className="text-center mb-4">OR</div>
|
|
||||||
<Form.Item label="User ID" name="user_id" className="mb-4">
|
|
||||||
<Input
|
|
||||||
name="user_id"
|
|
||||||
className="px-3 py-2 border rounded-md w-full"
|
|
||||||
/>
|
|
||||||
</Form.Item>
|
|
||||||
</>
|
|
||||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
|
||||||
<Button2 htmlType="submit">Add member</Button2>
|
|
||||||
</div>
|
|
||||||
</Form>
|
|
||||||
</Modal>
|
</Modal>
|
||||||
|
</div>
|
||||||
</Col>
|
</Col>
|
||||||
</Grid>
|
</Grid>
|
||||||
|
<Grid>
|
||||||
|
<Title level={4}>Add SSO</Title>
|
||||||
|
<div className="flex justify-start mb-4">
|
||||||
|
<Button onClick={() => setIsAddSSOModalVisible(true)}>Add SSO</Button>
|
||||||
|
<Modal
|
||||||
|
title="Add SSO"
|
||||||
|
visible={isAddSSOModalVisible}
|
||||||
|
width={800}
|
||||||
|
footer={null}
|
||||||
|
onOk={handleAddSSOOk}
|
||||||
|
onCancel={handleAddSSOCancel}
|
||||||
|
>
|
||||||
|
|
||||||
|
<Form
|
||||||
|
form={form}
|
||||||
|
onFinish={handleShowInstructions}
|
||||||
|
labelCol={{ span: 8 }}
|
||||||
|
wrapperCol={{ span: 16 }}
|
||||||
|
labelAlign="left"
|
||||||
|
>
|
||||||
|
<>
|
||||||
|
<Form.Item
|
||||||
|
label="Admin Email"
|
||||||
|
name="user_email"
|
||||||
|
rules={[{ required: true, message: "Please enter the email of the proxy admin" }]}
|
||||||
|
>
|
||||||
|
<Input />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
label="PROXY BASE URL"
|
||||||
|
name="proxy_base_url"
|
||||||
|
rules={[{ required: true, message: "Please enter the proxy base url" }]}
|
||||||
|
>
|
||||||
|
<Input />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label="GOOGLE CLIENT ID"
|
||||||
|
name="google_client_id"
|
||||||
|
rules={[{ required: true, message: "Please enter the google client id" }]}
|
||||||
|
>
|
||||||
|
<Input.Password />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label="GOOGLE CLIENT SECRET"
|
||||||
|
name="google_client_secret"
|
||||||
|
rules={[{ required: true, message: "Please enter the google client secret" }]}
|
||||||
|
>
|
||||||
|
<Input.Password />
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
|
<Button2 htmlType="submit">Save</Button2>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
|
||||||
|
</Modal>
|
||||||
|
<Modal
|
||||||
|
title="SSO Setup Instructions"
|
||||||
|
visible={isInstructionsModalVisible}
|
||||||
|
width={800}
|
||||||
|
footer={null}
|
||||||
|
onOk={handleInstructionsOk}
|
||||||
|
onCancel={handleInstructionsCancel}
|
||||||
|
>
|
||||||
|
<p>Follow these steps to complete the SSO setup:</p>
|
||||||
|
<Text className="mt-2">
|
||||||
|
1. DO NOT Exit this TAB
|
||||||
|
</Text>
|
||||||
|
<Text className="mt-2">
|
||||||
|
2. Open a new tab, visit your proxy base url
|
||||||
|
</Text>
|
||||||
|
<Text className="mt-2">
|
||||||
|
3. Confirm your SSO is configured correctly and you can login on the new Tab
|
||||||
|
</Text>
|
||||||
|
<Text className="mt-2">
|
||||||
|
4. If Step 3 is successful, you can close this tab
|
||||||
|
</Text>
|
||||||
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
|
<Button2 onClick={handleInstructionsOk}>Done</Button2>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
<Callout title="Login without SSO" color="teal">
|
||||||
|
If you need to login without sso, you can access <a href= {nonSssoUrl} target="_blank"><b>{nonSssoUrl}</b> </a>
If you need to log in without SSO, you can access <a href={nonSssoUrl} target="_blank"><b>{nonSssoUrl}</b></a>
|
||||||
|
</Callout>
|
||||||
|
</Grid>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
@ -46,8 +46,8 @@ const Sidebar: React.FC<SidebarProps> = ({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<Layout style={{ minHeight: "100vh", maxWidth: "100px" }}>
|
<Layout style={{ minHeight: "100vh", maxWidth: "120px" }}>
|
||||||
<Sider width={100}>
|
<Sider width={120}>
|
||||||
<Menu
|
<Menu
|
||||||
mode="inline"
|
mode="inline"
|
||||||
defaultSelectedKeys={defaultSelectedKey ? defaultSelectedKey : ["1"]}
|
defaultSelectedKeys={defaultSelectedKey ? defaultSelectedKey : ["1"]}
|
||||||
|
@ -63,6 +63,11 @@ const Sidebar: React.FC<SidebarProps> = ({
|
||||||
Test Key
|
Test Key
|
||||||
</Text>
|
</Text>
|
||||||
</Menu.Item>
|
</Menu.Item>
|
||||||
|
<Menu.Item key="2" onClick={() => setPage("models")}>
|
||||||
|
<Text>
|
||||||
|
Models
|
||||||
|
</Text>
|
||||||
|
</Menu.Item>
|
||||||
{userRole == "Admin" ? (
|
{userRole == "Admin" ? (
|
||||||
<Menu.Item key="6" onClick={() => setPage("teams")}>
|
<Menu.Item key="6" onClick={() => setPage("teams")}>
|
||||||
<Text>
|
<Text>
|
||||||
|
@ -82,10 +87,10 @@ const Sidebar: React.FC<SidebarProps> = ({
|
||||||
</Text>
|
</Text>
|
||||||
</Menu.Item>
|
</Menu.Item>
|
||||||
) : null}
|
) : null}
|
||||||
<Menu.Item key="2" onClick={() => setPage("models")}>
|
<Menu.Item key="8" onClick={() => setPage("settings")}>
|
||||||
<Text>
|
<Text>
|
||||||
Models
|
Integrations
|
||||||
</Text>
|
</Text>
|
||||||
</Menu.Item>
|
</Menu.Item>
|
||||||
{userRole == "Admin" ? (
|
{userRole == "Admin" ? (
|
||||||
<Menu.Item key="7" onClick={() => setPage("admin-panel")}>
|
<Menu.Item key="7" onClick={() => setPage("admin-panel")}>
|
||||||
|
|
|
@ -13,9 +13,9 @@ import {
|
||||||
Text,
|
Text,
|
||||||
Grid,
|
Grid,
|
||||||
} from "@tremor/react";
|
} from "@tremor/react";
|
||||||
import { TabPanel, TabPanels, TabGroup, TabList, Tab, TextInput } from "@tremor/react";
|
import { TabPanel, TabPanels, TabGroup, TabList, Tab, TextInput, Icon } from "@tremor/react";
|
||||||
import { Select, SelectItem } from "@tremor/react";
|
import { Select, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
|
||||||
import { modelInfoCall, userGetRequesedtModelsCall, modelMetricsCall, modelCreateCall, Model } from "./networking";
|
import { modelInfoCall, userGetRequesedtModelsCall, modelMetricsCall, modelCreateCall, Model, modelCostMap, modelDeleteCall } from "./networking";
|
||||||
import { BarChart } from "@tremor/react";
|
import { BarChart } from "@tremor/react";
|
||||||
import {
|
import {
|
||||||
Button as Button2,
|
Button as Button2,
|
||||||
|
@ -33,7 +33,8 @@ import {
|
||||||
import { Badge, BadgeDelta, Button } from "@tremor/react";
|
import { Badge, BadgeDelta, Button } from "@tremor/react";
|
||||||
import RequestAccess from "./request_model_access";
|
import RequestAccess from "./request_model_access";
|
||||||
import { Typography } from "antd";
|
import { Typography } from "antd";
|
||||||
|
import TextArea from "antd/es/input/TextArea";
|
||||||
|
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon } from "@heroicons/react/outline";
|
||||||
const { Title: Title2, Link } = Typography;
|
const { Title: Title2, Link } = Typography;
|
||||||
|
|
||||||
interface ModelDashboardProps {
|
interface ModelDashboardProps {
|
||||||
|
@ -43,6 +44,26 @@ interface ModelDashboardProps {
|
||||||
userID: string | null;
|
userID: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
|
||||||
|
|
||||||
|
enum Providers {
|
||||||
|
OpenAI = "OpenAI",
|
||||||
|
Azure = "Azure",
|
||||||
|
Anthropic = "Anthropic",
|
||||||
|
Google_AI_Studio = "Gemini (Google AI Studio)",
|
||||||
|
Bedrock = "Amazon Bedrock",
|
||||||
|
OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider_map: Record <string, string> = {
|
||||||
|
"OpenAI": "openai",
|
||||||
|
"Azure": "azure",
|
||||||
|
"Anthropic": "anthropic",
|
||||||
|
"Google_AI_Studio": "gemini",
|
||||||
|
"Bedrock": "bedrock",
|
||||||
|
"OpenAI_Compatible": "openai"
|
||||||
|
};
|
||||||
|
|
||||||
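The Providers enum holds the display names shown in the UI, while provider_map holds the litellm provider strings keyed by the enum member names. A minimal self-contained sketch of the lookup follows; the helper name displayNameToProvider and the trimmed two-entry copies are assumptions for illustration, not part of this diff.

// Trimmed copies of the enum and map from the hunk above, so the example runs standalone.
enum Providers {
  OpenAI = "OpenAI",
  Bedrock = "Amazon Bedrock",
}

const provider_map: Record<string, string> = {
  OpenAI: "openai",
  Bedrock: "bedrock",
};

// Resolve a display name (what the UI shows) to the litellm provider string,
// by finding the enum key whose value matches the display name.
const displayNameToProvider = (displayName: string): string => {
  const enumKey = Object.keys(Providers).find(
    (key) => Providers[key as keyof typeof Providers] === displayName
  );
  return enumKey ? provider_map[enumKey] ?? "openai" : "openai";
};

// displayNameToProvider("Amazon Bedrock") === "bedrock"
// displayNameToProvider("OpenAI") === "openai"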
const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
accessToken,
|
accessToken,
|
||||||
token,
|
token,
|
||||||
|
@ -53,8 +74,12 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
const [modelMetrics, setModelMetrics] = useState<any[]>([]);
|
const [modelMetrics, setModelMetrics] = useState<any[]>([]);
|
||||||
const [pendingRequests, setPendingRequests] = useState<any[]>([]);
|
const [pendingRequests, setPendingRequests] = useState<any[]>([]);
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
|
const [modelMap, setModelMap] = useState<any>(null);
|
||||||
|
|
||||||
const providers = ["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
|
const [providerModels, setProviderModels] = useState<Array<string>>([]); // Explicitly typing providerModels as a string array
|
||||||
|
|
||||||
|
const providers: Providers[] = [Providers.OpenAI, Providers.Azure, Providers.Anthropic, Providers.Google_AI_Studio, Providers.Bedrock, Providers.OpenAI_Compatible]
|
||||||
|
|
||||||
const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
|
const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
@ -95,7 +120,16 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
if (accessToken && token && userRole && userID) {
|
if (accessToken && token && userRole && userID) {
|
||||||
fetchData();
|
fetchData();
|
||||||
}
|
}
|
||||||
}, [accessToken, token, userRole, userID]);
|
|
||||||
|
const fetchModelMap = async () => {
|
||||||
|
const data = await modelCostMap()
|
||||||
|
console.log(`received model cost map data: ${Object.keys(data)}`)
|
||||||
|
setModelMap(data)
|
||||||
|
}
|
||||||
|
if (modelMap == null) {
|
||||||
|
fetchModelMap()
|
||||||
|
}
|
||||||
|
}, [accessToken, token, userRole, userID, modelMap]);
|
||||||
|
|
||||||
if (!modelData) {
|
if (!modelData) {
|
||||||
return <div>Loading...</div>;
|
return <div>Loading...</div>;
|
||||||
|
@ -109,7 +143,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
// loop through model data and edit each row
|
// loop through model data and edit each row
|
||||||
for (let i = 0; i < modelData.data.length; i++) {
|
for (let i = 0; i < modelData.data.length; i++) {
|
||||||
let curr_model = modelData.data[i];
|
let curr_model = modelData.data[i];
|
||||||
let litellm_model_name = curr_model?.litellm_params?.mode
|
let litellm_model_name = curr_model?.litellm_params?.model
|
||||||
let model_info = curr_model?.model_info;
|
let model_info = curr_model?.model_info;
|
||||||
|
|
||||||
let defaultProvider = "openai";
|
let defaultProvider = "openai";
|
||||||
|
@ -117,6 +151,22 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
let input_cost = "Undefined";
|
let input_cost = "Undefined";
|
||||||
let output_cost = "Undefined";
|
let output_cost = "Undefined";
|
||||||
let max_tokens = "Undefined";
|
let max_tokens = "Undefined";
|
||||||
|
let cleanedLitellmParams = {};
|
||||||
|
|
||||||
|
const getProviderFromModel = (model: string) => {
|
||||||
|
/**
|
||||||
|
* Use model map
|
||||||
|
* - check if model in model map
|
||||||
|
* - return its litellm_provider, if so
|
||||||
|
*/
|
||||||
|
console.log(`GET PROVIDER CALLED! - ${modelMap}`)
|
||||||
|
if (modelMap !== null && modelMap !== undefined) {
|
||||||
|
if (typeof modelMap == "object" && model in modelMap) {
|
||||||
|
return modelMap[model]["litellm_provider"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "openai"
|
||||||
|
}
|
||||||
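getProviderFromModel above leans on the public model cost map (loaded by modelCostMap() in networking.tsx), where each entry carries a litellm_provider field. A minimal sketch of the same lookup with an inlined two-entry map; the sample entries are illustrative, not pulled from this diff.

// Illustrative slice of model_prices_and_context_window.json - the real file has many more models and fields.
const sampleModelMap: Record<string, { litellm_provider: string }> = {
  "gpt-3.5-turbo": { litellm_provider: "openai" },
  "claude-3-sonnet-20240229": { litellm_provider: "anthropic" },
};

// Same shape of check as getProviderFromModel: fall back to "openai" when the model is unknown.
const providerFor = (model: string): string =>
  model in sampleModelMap ? sampleModelMap[model].litellm_provider : "openai";

// providerFor("claude-3-sonnet-20240229") === "anthropic"
// providerFor("my-custom-model") === "openai"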
|
|
||||||
// Check if litellm_model_name is null or undefined
|
// Check if litellm_model_name is null or undefined
|
||||||
if (litellm_model_name) {
|
if (litellm_model_name) {
|
||||||
|
@ -127,10 +177,10 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
let firstElement = splitModel[0];
|
let firstElement = splitModel[0];
|
||||||
|
|
||||||
// If there is only one element, default provider to openai
|
// If there is only one element, default provider to openai
|
||||||
provider = splitModel.length === 1 ? defaultProvider : firstElement;
|
provider = splitModel.length === 1 ? getProviderFromModel(litellm_model_name) : firstElement;
|
||||||
} else {
|
} else {
|
||||||
// litellm_model_name is null or undefined, default provider to openai
|
// litellm_model_name is null or undefined, default provider to openai
|
||||||
provider = defaultProvider;
|
provider = "openai"
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_info) {
|
if (model_info) {
|
||||||
|
@ -138,11 +188,22 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
output_cost = model_info?.output_cost_per_token;
|
output_cost = model_info?.output_cost_per_token;
|
||||||
max_tokens = model_info?.max_tokens;
|
max_tokens = model_info?.max_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// let cleanedLitellmParams == litellm_params without model, api_base
|
||||||
|
if (curr_model?.litellm_params) {
|
||||||
|
cleanedLitellmParams = Object.fromEntries(
|
||||||
|
Object.entries(curr_model?.litellm_params).filter(
|
||||||
|
([key]) => key !== "model" && key !== "api_base"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
modelData.data[i].provider = provider;
|
modelData.data[i].provider = provider;
|
||||||
modelData.data[i].input_cost = input_cost;
|
modelData.data[i].input_cost = input_cost;
|
||||||
modelData.data[i].output_cost = output_cost;
|
modelData.data[i].output_cost = output_cost;
|
||||||
modelData.data[i].max_tokens = max_tokens;
|
modelData.data[i].max_tokens = max_tokens;
|
||||||
modelData.data[i].api_base = curr_model?.litellm_params?.api_base;
|
modelData.data[i].api_base = curr_model?.litellm_params?.api_base;
|
||||||
|
modelData.data[i].cleanedLitellmParams = cleanedLitellmParams;
|
||||||
|
|
||||||
all_models_on_proxy.push(curr_model.model_name);
|
all_models_on_proxy.push(curr_model.model_name);
|
||||||
|
|
||||||
|
@ -162,43 +223,115 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleDelete = async (model_id: string) => {
|
||||||
|
await modelDeleteCall(accessToken, model_id)
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
const setProviderModelsFn = (provider: string) => {
|
||||||
|
console.log(`received provider string: ${provider}`)
|
||||||
|
const providerEnumValue = Providers[provider as keyof typeof Providers];
|
||||||
|
console.log(`received providerEnumValue: ${providerEnumValue}`)
|
||||||
|
const mappingResult = provider_map[providerEnumValue]; // Get the corresponding value from the mapping
|
||||||
|
console.log(`mappingResult: ${mappingResult}`)
|
||||||
|
let _providerModels: Array<string> = []
|
||||||
|
if (typeof modelMap === 'object') {
|
||||||
|
Object.entries(modelMap).forEach(([key, value]) => {
|
||||||
|
if (value !== null && typeof value === 'object' && "litellm_provider" in value && value["litellm_provider"] === mappingResult) {
|
||||||
|
_providerModels.push(key);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
setProviderModels(_providerModels)
|
||||||
|
console.log(`providerModels: ${providerModels}`);
|
||||||
|
}
|
||||||
|
|
||||||
const handleSubmit = async (formValues: Record<string, any>) => {
|
const handleSubmit = async (formValues: Record<string, any>) => {
|
||||||
const litellmParamsObj: Record<string, any> = {};
|
try {
|
||||||
const modelInfoObj: Record<string, any> = {};
|
/**
|
||||||
let modelName: string = "";
|
* For multiple litellm model names - create a separate deployment for each
|
||||||
// Iterate through the key-value pairs in formValues
|
* - get the list
|
||||||
for (const [key, value] of Object.entries(formValues)) {
|
* - iterate through it
|
||||||
if (key == "model_name") {
|
* - create a new deployment for each
|
||||||
modelName = value
|
*/
|
||||||
|
|
||||||
|
// get the list of deployments
|
||||||
|
let deployments: Array<string> = Object.values(formValues["model"])
|
||||||
|
console.log(`received deployments: ${deployments}`)
|
||||||
|
console.log(`received type of deployments: ${typeof deployments}`)
|
||||||
|
deployments.forEach(async (litellm_model) => {
|
||||||
|
console.log(`litellm_model: ${litellm_model}`)
|
||||||
|
const litellmParamsObj: Record<string, any> = {};
|
||||||
|
const modelInfoObj: Record<string, any> = {};
|
||||||
|
// Iterate through the key-value pairs in formValues
|
||||||
|
litellmParamsObj["model"] = litellm_model
|
||||||
|
let modelName: string = "";
|
||||||
|
for (const [key, value] of Object.entries(formValues)) {
|
||||||
|
if (key == "model_name") {
|
||||||
|
modelName = modelName + value
|
||||||
|
}
|
||||||
|
else if (key == "custom_llm_provider") {
|
||||||
|
// const providerEnumValue = Providers[value as keyof typeof Providers];
|
||||||
|
// const mappingResult = provider_map[providerEnumValue]; // Get the corresponding value from the mapping
|
||||||
|
// modelName = mappingResult + "/" + modelName
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
else if (key == "model") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if key is "base_model"
|
||||||
|
else if (key === "base_model") {
|
||||||
|
// Add key-value pair to model_info dictionary
|
||||||
|
modelInfoObj[key] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
else if (key == "litellm_extra_params") {
|
||||||
|
console.log("litellm_extra_params:", value);
|
||||||
|
let litellmExtraParams = {};
|
||||||
|
if (value && value != undefined) {
|
||||||
|
try {
|
||||||
|
litellmExtraParams = JSON.parse(value);
|
||||||
|
}
|
||||||
|
catch (error) {
|
||||||
|
message.error("Failed to parse LiteLLM Extra Params: " + error);
|
||||||
|
throw new Error("Failed to parse litellm_extra_params: " + error);
|
||||||
|
}
|
||||||
|
for (const [key, value] of Object.entries(litellmExtraParams)) {
|
||||||
|
litellmParamsObj[key] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if key is any of the specified API related keys
|
||||||
|
else {
|
||||||
|
// Add key-value pair to litellm_params dictionary
|
||||||
|
litellmParamsObj[key] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const new_model: Model = {
|
||||||
|
"model_name": modelName,
|
||||||
|
"litellm_params": litellmParamsObj,
|
||||||
|
"model_info": modelInfoObj
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const response: any = await modelCreateCall(
|
||||||
|
accessToken,
|
||||||
|
new_model
|
||||||
|
);
|
||||||
|
|
||||||
|
console.log(`response for model create call: ${response["data"]}`);
|
||||||
|
});
|
||||||
|
|
||||||
|
form.resetFields();
|
||||||
|
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
message.error("Failed to create model: " + error);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if key is any of the specified API related keys
|
|
||||||
if (key === "api_key" || key === "model" || key === "api_base" || key === "api_version" || key.startsWith("aws_")) {
|
|
||||||
// Add key-value pair to litellm_params dictionary
|
|
||||||
litellmParamsObj[key] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if key is "base_model"
|
|
||||||
if (key === "base_model") {
|
|
||||||
// Add key-value pair to model_info dictionary
|
|
||||||
modelInfoObj[key] = value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const new_model: Model = {
|
|
||||||
"model_name": modelName,
|
|
||||||
"litellm_params": litellmParamsObj,
|
|
||||||
"model_info": modelInfoObj
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const response: any = await modelCreateCall(
|
|
||||||
accessToken,
|
|
||||||
new_model
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log(`response for model create call: ${response["data"]}`);
|
|
||||||
}
|
}
|
||||||
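The reworked handleSubmit above creates one deployment per selected litellm model: every entry in formValues["model"] becomes its own Model payload for modelCreateCall, all sharing the same public model name. A minimal sketch of the payloads produced, assuming two Anthropic models were picked in the MultiSelect; the names and key values are placeholders.

// Shape matches the Model interface in networking.tsx: model_name, litellm_params, model_info.
const deployments = [
  {
    model_name: "claude-prod",                // formValues.model_name (shared public name)
    litellm_params: {
      model: "claude-3-opus-20240229",        // first entry from formValues["model"]
      api_key: "sk-ant-...",                  // placeholder
    },
    model_info: {},
  },
  {
    model_name: "claude-prod",
    litellm_params: {
      model: "claude-3-sonnet-20240229",      // second entry from formValues["model"]
      api_key: "sk-ant-...",
    },
    model_info: {},
  },
];
// Each element is passed to modelCreateCall(accessToken, new_model) in its own call.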
|
|
||||||
const handleOk = () => {
|
const handleOk = () => {
|
||||||
|
@ -206,7 +339,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
.validateFields()
|
.validateFields()
|
||||||
.then((values) => {
|
.then((values) => {
|
||||||
handleSubmit(values);
|
handleSubmit(values);
|
||||||
form.resetFields();
|
// form.resetFields();
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error("Validation failed:", error);
|
console.error("Validation failed:", error);
|
||||||
|
@ -214,7 +347,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
};
|
};
|
||||||
|
|
||||||
console.log(`selectedProvider: ${selectedProvider}`)
|
console.log(`selectedProvider: ${selectedProvider}`)
|
||||||
|
console.log(`providerModels.length: ${providerModels.length}`)
|
||||||
return (
|
return (
|
||||||
<div style={{ width: "100%", height: "100%"}}>
|
<div style={{ width: "100%", height: "100%"}}>
|
||||||
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
|
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
|
||||||
|
@ -244,7 +377,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
<TableHeaderCell>
|
<TableHeaderCell>
|
||||||
Access
|
Extra litellm Params
|
||||||
</TableHeaderCell>
|
</TableHeaderCell>
|
||||||
<TableHeaderCell>Input Price per token ($)</TableHeaderCell>
|
<TableHeaderCell>Input Price per token ($)</TableHeaderCell>
|
||||||
<TableHeaderCell>Output Price per token ($)</TableHeaderCell>
|
<TableHeaderCell>Output Price per token ($)</TableHeaderCell>
|
||||||
|
@ -252,8 +385,8 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableHead>
|
</TableHead>
|
||||||
<TableBody>
|
<TableBody>
|
||||||
{modelData.data.map((model: any) => (
|
{modelData.data.map((model: any, index: number) => (
|
||||||
<TableRow key={model.model_name}>
|
<TableRow key={index}>
|
||||||
<TableCell>
|
<TableCell>
|
||||||
<Text>{model.model_name}</Text>
|
<Text>{model.model_name}</Text>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
|
@ -265,20 +398,15 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
}
|
}
|
||||||
|
|
||||||
<TableCell>
|
<TableCell>
|
||||||
{model.user_access ? (
|
<pre>
|
||||||
<Badge color={"green"}>Yes</Badge>
|
{JSON.stringify(model.cleanedLitellmParams, null, 2)}
|
||||||
) : (
|
</pre>
|
||||||
<RequestAccess
|
|
||||||
userModels={all_models_on_proxy}
|
|
||||||
accessToken={accessToken}
|
|
||||||
userID={userID}
|
|
||||||
></RequestAccess>
|
|
||||||
)}
|
|
||||||
</TableCell>
|
</TableCell>
|
||||||
|
|
||||||
<TableCell>{model.input_cost}</TableCell>
|
<TableCell>{model.input_cost}</TableCell>
|
||||||
<TableCell>{model.output_cost}</TableCell>
|
<TableCell>{model.output_cost}</TableCell>
|
||||||
<TableCell>{model.max_tokens}</TableCell>
|
<TableCell>{model.max_tokens}</TableCell>
|
||||||
|
<TableCell><Icon icon={TrashIcon} size="sm" onClick={() => handleDelete(model.model_info.id)}/></TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
))}
|
))}
|
||||||
</TableBody>
|
</TableBody>
|
||||||
|
@ -331,6 +459,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
key={index}
|
key={index}
|
||||||
value={provider}
|
value={provider}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
|
setProviderModelsFn(provider);
|
||||||
setSelectedProvider(provider);
|
setSelectedProvider(provider);
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
|
@ -344,18 +473,28 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Row>
|
<Row>
|
||||||
<Col span={10}></Col>
|
<Col span={10}></Col>
|
||||||
<Col span={10}><Text className="mb-3 mt-1">Model name your users will pass in. Also used for <Link href="https://docs.litellm.ai/docs/proxy/reliability#step-1---set-deployments-on-config" target="_blank">loadbalancing.</Link></Text></Col>
|
<Col span={10}><Text className="mb-3 mt-1">Model name your users will pass in.</Text></Col>
|
||||||
</Row>
|
</Row>
|
||||||
<Form.Item rules={[{ required: true, message: 'Required' }]} label="LiteLLM Model Name" name="model" tooltip="Actual model name used for making litellm.completion() call." className="mb-0">
|
<Form.Item rules={[{ required: true, message: 'Required' }]} label="LiteLLM Model Name(s)" name="model" tooltip="Actual model name used for making litellm.completion() call." className="mb-0">
|
||||||
<TextInput placeholder="gpt-3.5-turbo-0125"/>
|
{selectedProvider === Providers.Azure ? (
|
||||||
|
<TextInput placeholder="Enter model name" />
|
||||||
|
) : providerModels.length > 0 ? (
|
||||||
|
<MultiSelect value={providerModels}>
|
||||||
|
{providerModels.map((model, index) => (
|
||||||
|
<MultiSelectItem key={index} value={model}>
|
||||||
|
{model}
|
||||||
|
</MultiSelectItem>
|
||||||
|
))}
|
||||||
|
</MultiSelect>
|
||||||
|
) : (
|
||||||
|
<TextInput placeholder="gpt-3.5-turbo-0125" />
|
||||||
|
)}
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Row>
|
<Row>
|
||||||
<Col span={10}></Col>
|
<Col span={10}></Col>
|
||||||
<Col span={10}><Text className="mb-3 mt-1">Actual model name used for making <Link href="https://docs.litellm.ai/docs/providers" target="_blank">litellm.completion() call</Link></Text></Col>
|
<Col span={10}><Text className="mb-3 mt-1">Actual model name used for making<Link href="https://docs.litellm.ai/docs/providers" target="_blank">litellm.completion() call</Link>.We'll<Link href="https://docs.litellm.ai/docs/proxy/reliability#step-1---set-deployments-on-config" target="_blank">loadbalance</Link> models with the same 'public name'</Text></Col></Row>
|
||||||
</Row>
|
|
||||||
|
|
||||||
{
|
{
|
||||||
selectedProvider != "Amazon Bedrock" && <Form.Item
|
selectedProvider != Providers.Bedrock && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="API Key"
|
label="API Key"
|
||||||
name="api_key"
|
name="api_key"
|
||||||
|
@ -364,7 +503,15 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Azure OpenAI" && <Form.Item
|
selectedProvider == Providers.OpenAI && <Form.Item
|
||||||
|
label="Organization ID"
|
||||||
|
name="organization_id"
|
||||||
|
>
|
||||||
|
<TextInput placeholder="[OPTIONAL] my-unique-org"/>
|
||||||
|
</Form.Item>
|
||||||
|
}
|
||||||
|
{
|
||||||
|
(selectedProvider == Providers.Azure || selectedProvider == Providers.OpenAI_Compatible) && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="API Base"
|
label="API Base"
|
||||||
name="api_base"
|
name="api_base"
|
||||||
|
@ -373,7 +520,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Azure OpenAI" && <Form.Item
|
selectedProvider == Providers.Azure && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="API Version"
|
label="API Version"
|
||||||
name="api_version"
|
name="api_version"
|
||||||
|
@ -382,7 +529,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Azure OpenAI" && <Form.Item
|
selectedProvider == Providers.Azure && <Form.Item
|
||||||
label="Base Model"
|
label="Base Model"
|
||||||
name="base_model"
|
name="base_model"
|
||||||
>
|
>
|
||||||
|
@ -391,7 +538,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Amazon Bedrock" && <Form.Item
|
selectedProvider == Providers.Bedrock && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="AWS Access Key ID"
|
label="AWS Access Key ID"
|
||||||
name="aws_access_key_id"
|
name="aws_access_key_id"
|
||||||
|
@ -401,7 +548,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Amazon Bedrock" && <Form.Item
|
selectedProvider == Providers.Bedrock && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="AWS Secret Access Key"
|
label="AWS Secret Access Key"
|
||||||
name="aws_secret_access_key"
|
name="aws_secret_access_key"
|
||||||
|
@ -411,7 +558,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
selectedProvider == "Amazon Bedrock" && <Form.Item
|
selectedProvider == Providers.Bedrock && <Form.Item
|
||||||
rules={[{ required: true, message: 'Required' }]}
|
rules={[{ required: true, message: 'Required' }]}
|
||||||
label="AWS Region Name"
|
label="AWS Region Name"
|
||||||
name="aws_region_name"
|
name="aws_region_name"
|
||||||
|
@ -420,6 +567,22 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||||
<TextInput placeholder="us-east-1"/>
|
<TextInput placeholder="us-east-1"/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
}
|
}
|
||||||
|
<Form.Item label="LiteLLM Params" name="litellm_extra_params" tooltip="Optional litellm params used for making a litellm.completion() call." className="mb-0">
|
||||||
|
<TextArea
|
||||||
|
rows={4}
|
||||||
|
placeholder='{
|
||||||
|
"rpm": 100,
|
||||||
|
"timeout": 0,
|
||||||
|
"stream_timeout": 0
|
||||||
|
}'
|
||||||
|
/>
|
||||||
|
|
||||||
|
</Form.Item>
|
||||||
|
<Row>
|
||||||
|
<Col span={10}></Col>
|
||||||
|
<Col span={10}><Text className="mb-3 mt-1">Pass JSON of litellm supported params <Link href="https://docs.litellm.ai/docs/completion/input" target="_blank">litellm.completion() call</Link></Text></Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
</>
|
</>
|
||||||
<div style={{ textAlign: "center", marginTop: "10px" }}>
|
<div style={{ textAlign: "center", marginTop: "10px" }}>
|
||||||
<Button2 htmlType="submit">Add Model</Button2>
|
<Button2 htmlType="submit">Add Model</Button2>
|
||||||
|
|
|
@ -32,6 +32,7 @@ const Navbar: React.FC<NavbarProps> = ({
|
||||||
}) => {
|
}) => {
|
||||||
console.log("User ID:", userID);
|
console.log("User ID:", userID);
|
||||||
console.log("userEmail:", userEmail);
|
console.log("userEmail:", userEmail);
|
||||||
|
console.log("showSSOBanner:", showSSOBanner);
|
||||||
|
|
||||||
// const userColors = require('./ui_colors.json') || {};
|
// const userColors = require('./ui_colors.json') || {};
|
||||||
const isLocal = process.env.NODE_ENV === "development";
|
const isLocal = process.env.NODE_ENV === "development";
|
||||||
|
@ -67,13 +68,25 @@ const Navbar: React.FC<NavbarProps> = ({
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="text-right mx-4 my-2 absolute top-0 right-0 flex items-center justify-end space-x-2">
|
<div className="text-right mx-4 my-2 absolute top-0 right-0 flex items-center justify-end space-x-2">
|
||||||
{showSSOBanner ? (
|
{showSSOBanner ? (
|
||||||
|
|
||||||
|
<div style={{
|
||||||
|
// border: '1px solid #391085',
|
||||||
|
padding: '6px',
|
||||||
|
borderRadius: '8px', // Added border-radius property
|
||||||
|
}}
|
||||||
|
>
|
||||||
<a
|
<a
|
||||||
href="https://docs.litellm.ai/docs/proxy/ui#setup-ssoauth-for-ui"
|
href="https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
className="mr-2"
|
style={{
|
||||||
|
"fontSize": "14px",
|
||||||
|
"textDecoration": "underline"
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
|
Request hosted proxy
|
||||||
</a>
|
</a>
|
||||||
|
</div>
|
||||||
) : null}
|
) : null}
|
||||||
|
|
||||||
<div style={{
|
<div style={{
|
||||||
|
|
|
@ -12,6 +12,18 @@ export interface Model {
|
||||||
model_info: Object | null;
|
model_info: Object | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const modelCostMap = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch('https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json');
|
||||||
|
const jsonData = await response.json();
|
||||||
|
console.log(`received data: ${jsonData}`)
|
||||||
|
return jsonData
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to get model cost map:", error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
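A short usage sketch for the new modelCostMap helper above. The field names (litellm_provider, input_cost_per_token, max_tokens) are standard keys in model_prices_and_context_window.json; the wrapper function and model name are illustrative assumptions.

// Fetches the public cost map and reads a few fields for one model.
const logModelPricing = async (modelName: string) => {
  const costMap = await modelCostMap(); // defined above: fetches model_prices_and_context_window.json
  const entry = costMap[modelName];
  if (!entry) {
    console.log(`${modelName} not found in cost map`);
    return;
  }
  console.log(`provider: ${entry.litellm_provider}`);
  console.log(`input cost per token: ${entry.input_cost_per_token}`);
  console.log(`max tokens: ${entry.max_tokens}`);
};

// logModelPricing("gpt-3.5-turbo");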
export const modelCreateCall = async (
|
export const modelCreateCall = async (
|
||||||
accessToken: string,
|
accessToken: string,
|
||||||
formValues: Model
|
formValues: Model
|
||||||
|
@ -38,7 +50,42 @@ export const modelCreateCall = async (

    const data = await response.json();
    console.log("API Response:", data);
    message.success("Model created successfully. Wait 60s and refresh.")
    message.success("Model created successfully. Wait 60s and refresh on 'All Models' page");
    return data;
  } catch (error) {
    console.error("Failed to create key:", error);
    throw error;
  }
}

export const modelDeleteCall = async (
  accessToken: string,
  model_id: string,
) => {
  console.log(`model_id in model delete call: ${model_id}`)
  try {
    const url = proxyBaseUrl ? `${proxyBaseUrl}/model/delete` : `/model/delete`;
    const response = await fetch(url, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        "id": model_id,
      }),
    });

    if (!response.ok) {
      const errorData = await response.text();
      message.error("Failed to create key: " + errorData);
      console.error("Error response from the server:", errorData);
      throw new Error("Network response was not ok");
    }

    const data = await response.json();
    console.log("API Response:", data);
    message.success("Model deleted successfully. Restart server to see this.");
    return data;
  } catch (error) {
    console.error("Failed to create key:", error);
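A minimal usage sketch for the new modelDeleteCall helper; the handler name and wiring are illustrative, not part of this diff:

  // Hypothetical call site: a delete button handler in the models dashboard.
  import { modelDeleteCall } from "./networking";

  const onDeleteModelClick = async (accessToken: string, modelID: string) => {
    try {
      await modelDeleteCall(accessToken, modelID); // POSTs { id: modelID } to /model/delete
    } catch (e) {
      // modelDeleteCall already surfaces an antd error message; just log here.
      console.error("model delete failed", e);
    }
  };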
@ -339,6 +386,7 @@ export const modelInfoCall = async (
    }

    const data = await response.json();
    console.log("modelInfoCall:", data);
    //message.info("Received model data");
    return data;
    // Handle success - you might want to update some state or UI based on the created key
@ -1008,22 +1056,25 @@ export const teamMemberAddCall = async (

export const userUpdateUserCall = async (
  accessToken: string,
  formValues: any // Assuming formValues is an object
  formValues: any, // Assuming formValues is an object
  userRole: string | null
) => {
  try {
    console.log("Form Values in userUpdateUserCall:", formValues); // Log the form values before making the API call

    const url = proxyBaseUrl ? `${proxyBaseUrl}/user/update` : `/user/update`;
    let response_body = {...formValues};
    if (userRole !== null) {
      response_body["user_role"] = userRole;
    }
    response_body = JSON.stringify(response_body);
    const response = await fetch(url, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
      body: response_body,
        user_role: "proxy_admin_viewer",
        ...formValues, // Include formValues in the request body
      }),
    });

    if (!response.ok) {
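A quick sketch of how the updated userUpdateUserCall signature might be called; the form fields shown are illustrative, only the "proxy_admin_viewer" role value comes from this diff:

  // Before: the role was hard-coded to "proxy_admin_viewer" inside the helper.
  // Now the caller decides; pass null to leave user_role out of the update body.
  await userUpdateUserCall(accessToken, { user_email: "someone@example.com" }, "proxy_admin_viewer");
  await userUpdateUserCall(accessToken, { max_budget: 50 }, null); // no role change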
@ -1119,3 +1170,85 @@ export const slackBudgetAlertsHealthCheck = async (accessToken: String) => {
  }
};


export const getCallbacksCall = async (
  accessToken: String,
  userID: String,
  userRole: String
) => {
  /**
   * Get all the models user has access to
   */
  try {
    let url = proxyBaseUrl ? `${proxyBaseUrl}/get/config/callbacks` : `/get/config/callbacks`;

    //message.info("Requesting model data");
    const response = await fetch(url, {
      method: "GET",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
    });

    if (!response.ok) {
      const errorData = await response.text();
      message.error(errorData);
      throw new Error("Network response was not ok");
    }

    const data = await response.json();
    //message.info("Received model data");
    return data;
    // Handle success - you might want to update some state or UI based on the created key
  } catch (error) {
    console.error("Failed to get callbacks:", error);
    throw error;
  }
};


export const setCallbacksCall = async (
  accessToken: String,
  formValues: Record<string, any>
) => {
  /**
   * Set callbacks on proxy
   */
  try {
    let url = proxyBaseUrl ? `${proxyBaseUrl}/config/update` : `/config/update`;

    //message.info("Requesting model data");
    const response = await fetch(url, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        ...formValues, // Include formValues in the request body
      }),
    });

    if (!response.ok) {
      const errorData = await response.text();
      message.error(errorData);
      throw new Error("Network response was not ok");
    }

    const data = await response.json();
    //message.info("Received model data");
    return data;
    // Handle success - you might want to update some state or UI based on the created key
  } catch (error) {
    console.error("Failed to set callbacks:", error);
    throw error;
  }
};
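The payload setCallbacksCall forwards to /config/update mirrors the proxy config file structure; for example, the langfuse case built in settings.tsx below sends roughly this shape (key values are placeholders):

  // Example payload matching the langfuse branch in settings.tsx further down.
  const payload = {
    environment_variables: {
      LANGFUSE_PUBLIC_KEY: "pk-...",  // placeholder
      LANGFUSE_SECRET_KEY: "sk-...",  // placeholder
    },
    litellm_settings: {
      success_callback: ["langfuse"],
    },
  };
  await setCallbacksCall(accessToken, payload);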
211 ui/litellm-dashboard/src/components/settings.tsx Normal file

@ -0,0 +1,211 @@
import React, { useState, useEffect } from "react";
import {
  Card,
  Title,
  Subtitle,
  Table,
  TableHead,
  TableRow,
  Badge,
  TableHeaderCell,
  TableCell,
  TableBody,
  Metric,
  Text,
  Grid,
  Button,
  Col,
} from "@tremor/react";
import { getCallbacksCall, setCallbacksCall } from "./networking";
import { Modal, Form, Input, Select, Button as Button2 } from "antd";
import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider";

interface SettingsPageProps {
  accessToken: string | null;
  userRole: string | null;
  userID: string | null;
}

const Settings: React.FC<SettingsPageProps> = ({
  accessToken,
  userRole,
  userID,
}) => {
  const [callbacks, setCallbacks] = useState<string[]>([]);
  const [isModalVisible, setIsModalVisible] = useState(false);
  const [form] = Form.useForm();
  const [selectedCallback, setSelectedCallback] = useState<string | null>(null);

  useEffect(() => {
    if (!accessToken || !userRole || !userID) {
      return;
    }
    getCallbacksCall(accessToken, userID, userRole).then((data) => {
      console.log("callbacks", data);
      let callbacks_data = data.data;
      let callback_names = callbacks_data.success_callback; // ["callback1", "callback2"]
      setCallbacks(callback_names);
    });
  }, [accessToken, userRole, userID]);

  const handleAddCallback = () => {
    console.log("Add callback clicked");
    setIsModalVisible(true);
  };

  const handleCancel = () => {
    setIsModalVisible(false);
    form.resetFields();
    setSelectedCallback(null);
  };

  const handleOk = () => {
    if (!accessToken) {
      return;
    }
    // Handle form submission
    form.validateFields().then((values) => {
      // Call API to add the callback
      console.log("Form values:", values);
      let payload;
      if (values.callback === 'langfuse') {
        payload = {
          environment_variables: {
            LANGFUSE_PUBLIC_KEY: values.langfusePublicKey,
            LANGFUSE_SECRET_KEY: values.langfusePrivateKey
          },
          litellm_settings: {
            success_callback: [values.callback]
          }
        };
        setCallbacksCall(accessToken, payload);

        // add langfuse to callbacks
        setCallbacks(callbacks ? [...callbacks, values.callback] : [values.callback]);
      } else if (values.callback === 'slack') {
        payload = {
          general_settings: {
            alerting: ["slack"],
            alerting_threshold: 300
          },
          environment_variables: {
            SLACK_WEBHOOK_URL: values.slackWebhookUrl
          }
        };
        setCallbacksCall(accessToken, payload);

        // add slack to callbacks
        setCallbacks(callbacks ? [...callbacks, values.callback] : [values.callback]);
      } else {
        payload = {
          error: 'Invalid callback value'
        };
      }
      setIsModalVisible(false);
      form.resetFields();
      setSelectedCallback(null);
    });
  };

  const handleCallbackChange = (value: string) => {
    setSelectedCallback(value);
  };

  return (
    <div className="w-full mx-4">
      <Grid numItems={1} className="gap-2 p-8 h-[75vh] w-full mt-2">
        <Card className="h-[15vh]">
          <Grid numItems={2} className="mt-2">
            <Col>
              <Title>Logging Callbacks</Title>
            </Col>
            <Col>
              <div>
                {!callbacks ? (
                  <Badge color={"red"}>None</Badge>
                ) : callbacks.length === 0 ? (
                  <Badge>None</Badge>
                ) : (
                  callbacks.map((callback, index) => (
                    <Badge key={index} color={"sky"}>
                      {callback}
                    </Badge>
                  ))
                )}
              </div>
            </Col>
          </Grid>
          <Col>
            <Button size="xs" className="mt-2" onClick={handleAddCallback}>
              Add Callback
            </Button>
          </Col>
        </Card>
      </Grid>

      <Modal
        title="Add Callback"
        visible={isModalVisible}
        onOk={handleOk}
        width={800}
        onCancel={handleCancel}
        footer={null}
      >
        <Form form={form} layout="vertical" onFinish={handleOk}>
          <Form.Item
            label="Callback"
            name="callback"
            rules={[{ required: true, message: "Please select a callback" }]}
          >
            <Select onChange={handleCallbackChange}>
              <Select.Option value="langfuse">langfuse</Select.Option>
              <Select.Option value="slack">slack alerting</Select.Option>
            </Select>
          </Form.Item>

          {selectedCallback === 'langfuse' && (
            <>
              <Form.Item
                label="LANGFUSE_PUBLIC_KEY"
                name="langfusePublicKey"
                rules={[
                  { required: true, message: "Please enter the public key" },
                ]}
              >
                <Input.Password />
              </Form.Item>

              <Form.Item
                label="LANGFUSE_PRIVATE_KEY"
                name="langfusePrivateKey"
                rules={[
                  { required: true, message: "Please enter the private key" },
                ]}
              >
                <Input.Password />
              </Form.Item>
            </>
          )}

          {selectedCallback === 'slack' && (
            <Form.Item
              label="SLACK_WEBHOOK_URL"
              name="slackWebhookUrl"
              rules={[
                { required: true, message: "Please enter the Slack webhook URL" },
              ]}
            >
              <Input />
            </Form.Item>
          )}

          <div style={{ textAlign: "right", marginTop: "10px" }}>
            <Button2 htmlType="submit">Save</Button2>
          </div>
        </Form>
      </Modal>
    </div>
  );
};

export default Settings;
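For reference, the slack branch of handleOk above posts a payload of this shape through setCallbacksCall (the webhook URL is a placeholder):

  // Shape of the slack-alerting payload built in handleOk.
  const slackPayload = {
    general_settings: {
      alerting: ["slack"],
      alerting_threshold: 300, // value copied from handleOk above
    },
    environment_variables: {
      SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/XXX/YYY/ZZZ", // placeholder
    },
  };
  await setCallbacksCall(accessToken, slackPayload);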
@ -274,6 +274,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
        userID={userID}
        userRole={userRole}
        accessToken={accessToken}
        userSpend={null}
      />
      <TabGroup>
        <TabList className="mt-2">
@ -18,6 +18,7 @@ type UserSpendData = {
  max_budget?: number | null;
};


interface UserDashboardProps {
  userID: string | null;
  userRole: string | null;
@ -52,6 +53,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({

  const token = searchParams.get("token");
  const [accessToken, setAccessToken] = useState<string | null>(null);
  const [teamSpend, setTeamSpend] = useState<number | null>(null);
  const [userModels, setUserModels] = useState<string[]>([]);
  const [selectedTeam, setSelectedTeam] = useState<any | null>(
    teams ? teams[0] : null
@ -174,8 +176,29 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
      fetchData();
    }
  }

  }, [userID, token, accessToken, keys, userRole]);

  useEffect(() => {
    // This code will run every time selectedTeam changes
    if (keys !== null && selectedTeam !== null && selectedTeam !== undefined) {
      let sum = 0;
      for (const key of keys) {
        if (selectedTeam.hasOwnProperty('team_id') && key.team_id !== null && key.team_id === selectedTeam.team_id) {
          sum += key.spend;
        }
      }
      setTeamSpend(sum);
    } else if (keys !== null) {
      // sum the keys which don't have team-id set (default team)
      let sum = 0
      for (const key of keys) {
        sum += key.spend;
      }
      setTeamSpend(sum);
    }
  }, [selectedTeam]);

  if (userID == null || token == null) {
    // Now you can construct the full URL
    const url = proxyBaseUrl
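A quick illustration of the team-spend summation added above (the key data is made up):

  // Illustrative only: with selectedTeam = { team_id: "team-a" }, the effect above
  // sums only the team-a keys and calls setTeamSpend(≈0.3); "team-b" and null-team
  // keys are ignored. With no selected team, all key spends are summed instead.
  const exampleKeys = [
    { team_id: "team-a", spend: 0.1 },
    { team_id: "team-a", spend: 0.2 },
    { team_id: "team-b", spend: 5.0 },
    { team_id: null, spend: 0.7 },
  ];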
@ -204,7 +227,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
  }

  console.log("inside user dashboard, selected team", selectedTeam);
  console.log(`teamSpend: ${teamSpend}`)
  return (
    <div className="w-full mx-4">
      <Grid numItems={1} className="gap-2 p-8 h-[75vh] w-full mt-2">
@ -213,6 +236,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
          userID={userID}
          userRole={userRole}
          accessToken={accessToken}
          userSpend={teamSpend}
        />

        <ViewKeyTable
@ -31,16 +31,18 @@ interface ViewUserSpendProps {
  userID: string | null;
  userRole: string | null;
  accessToken: string | null;
  userSpend: number | null;
}
const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessToken }) => {
const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessToken, userSpend }) => {
  const [spend, setSpend] = useState(0.0);
  console.log(`userSpend: ${userSpend}`)
  let [spend, setSpend] = useState(userSpend !== null ? userSpend : 0.0);
  const [maxBudget, setMaxBudget] = useState(0.0);
  useEffect(() => {
    const fetchData = async () => {
      if (!accessToken || !userID || !userRole) {
        return;
      }
      if (userRole === "Admin") {
      if (userRole === "Admin" && userSpend == null) {
        try {
          const globalSpend = await getTotalSpendCall(accessToken);
          if (globalSpend) {
@ -64,13 +66,20 @@ const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessT
    fetchData();
  }, [userRole, accessToken]);

  useEffect(() => {
    if (userSpend !== null) {
      setSpend(userSpend)
    }
  }, [userSpend])

  const displayMaxBudget = maxBudget !== null ? `$${maxBudget} limit` : "No limit";

  const roundedSpend = spend !== undefined ? spend.toFixed(5) : null;
  const roundedSpend = spend !== undefined ? spend.toFixed(4) : null;

  console.log(`spend in view user spend: ${spend}`)
  return (
    <>
      <p className="text-tremor-default text-tremor-content dark:text-dark-tremor-content">Total Spend (across all teams)</p>
      <p className="text-tremor-default text-tremor-content dark:text-dark-tremor-content">Total Spend </p>
      <p className="text-3xl text-tremor-content-strong dark:text-dark-tremor-content-strong font-semibold">${roundedSpend}</p>

    </>
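Two small behavior changes in ViewUserSpend worth illustrating with made-up numbers: the new userSpend prop now overrides the fetched global spend (the Admin fetch branch is skipped when userSpend is non-null), and the displayed value is rounded to four decimals instead of five:

  // Illustrative rounding comparison for the toFixed change above.
  const spend = 0.0123456;
  console.log(spend.toFixed(4)); // "0.0123" — previously toFixed(5) rendered "0.01235"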