mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
9110 lines
332 KiB
Python
9110 lines
332 KiB
Python
import ast
|
|
import asyncio
|
|
import copy
|
|
import hashlib
|
|
import importlib
|
|
import inspect
|
|
import os
|
|
import platform
|
|
import random
|
|
import re
|
|
import secrets
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
import warnings
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
|
|
|
|
import requests
|
|
|
|
if TYPE_CHECKING:
|
|
from opentelemetry.trace import Span as _Span
|
|
|
|
Span = _Span
|
|
else:
|
|
Span = Any
|
|
|
|
|
|
def showwarning(message, category, filename, lineno, file=None, line=None):
|
|
traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
|
|
if file is not None:
|
|
file.write(traceback_info)
|
|
|
|
|
|
warnings.showwarning = showwarning
|
|
warnings.filterwarnings("default", category=UserWarning)
|
|
|
|
# Your client code here
|
|
|
|
|
|
messages: list = []
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path - for litellm local dev
|
|
|
|
try:
|
|
import logging
|
|
|
|
import backoff
|
|
import fastapi
|
|
import orjson
|
|
import yaml # type: ignore
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
|
except ImportError as e:
|
|
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
|
|
|
import random
|
|
|
|
list_of_messages = [
|
|
"'The thing I wish you improved is...'",
|
|
"'A feature I really want is...'",
|
|
"'The worst thing about this product is...'",
|
|
"'This product would be better if...'",
|
|
"'I don't like how this works...'",
|
|
"'It would help me if you could add...'",
|
|
"'This feature doesn't meet my needs because...'",
|
|
"'I get frustrated when the product...'",
|
|
]
|
|
|
|
|
|
def generate_feedback_box():
|
|
box_width = 60
|
|
|
|
# Select a random message
|
|
message = random.choice(list_of_messages)
|
|
|
|
print() # noqa
|
|
print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") # noqa
|
|
print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") # noqa
|
|
print("\033[1;37m" + "# {:^59} #\033[0m".format(message)) # noqa
|
|
print( # noqa
|
|
"\033[1;37m"
|
|
+ "# {:^59} #\033[0m".format("https://github.com/BerriAI/litellm/issues/new")
|
|
) # noqa
|
|
print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") # noqa
|
|
print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") # noqa
|
|
print() # noqa
|
|
print(" Thank you for using LiteLLM! - Krrish & Ishaan") # noqa
|
|
print() # noqa
|
|
print() # noqa
|
|
print() # noqa
|
|
print( # noqa
|
|
"\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m"
|
|
) # noqa
|
|
print() # noqa
|
|
print() # noqa
|
|
|
|
|
|
import pydantic
|
|
|
|
import litellm
|
|
from litellm import (
|
|
CancelBatchRequest,
|
|
CreateBatchRequest,
|
|
CreateFileRequest,
|
|
ListBatchRequest,
|
|
RetrieveBatchRequest,
|
|
)
|
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
|
from litellm.caching import DualCache, RedisCache
|
|
from litellm.exceptions import RejectedRequestError
|
|
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
|
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
|
from litellm.proxy._types import *
|
|
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
|
router as analytics_router,
|
|
)
|
|
from litellm.proxy.auth.auth_checks import (
|
|
allowed_routes_check,
|
|
common_checks,
|
|
get_actual_routes,
|
|
get_end_user_object,
|
|
get_org_object,
|
|
get_team_object,
|
|
get_user_object,
|
|
log_to_opentelemetry,
|
|
)
|
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
|
from litellm.proxy.auth.litellm_license import LicenseCheck
|
|
from litellm.proxy.auth.model_checks import (
|
|
get_complete_model_list,
|
|
get_key_models,
|
|
get_team_models,
|
|
)
|
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|
|
|
## Import All Misc routes here ##
|
|
from litellm.proxy.caching_routes import router as caching_router
|
|
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
|
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
|
|
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
|
remove_sensitive_info_from_deployment,
|
|
)
|
|
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
|
from litellm.proxy.health_check import perform_health_check
|
|
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
|
from litellm.proxy.hooks.prompt_injection_detection import (
|
|
_OPTIONAL_PromptInjectionDetection,
|
|
)
|
|
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
|
router as internal_user_router,
|
|
)
|
|
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|
_duration_in_seconds,
|
|
delete_verification_token,
|
|
generate_key_helper_fn,
|
|
)
|
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|
router as key_management_router,
|
|
)
|
|
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
|
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|
initialize_pass_through_endpoints,
|
|
)
|
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
|
load_aws_kms,
|
|
load_aws_secret_manager,
|
|
)
|
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
|
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
|
router as spend_management_router,
|
|
)
|
|
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
|
|
from litellm.proxy.utils import (
|
|
DBClient,
|
|
PrismaClient,
|
|
ProxyLogging,
|
|
_cache_user_row,
|
|
_get_projected_spend_over_limit,
|
|
_is_projected_spend_over_limit,
|
|
_is_valid_team_configs,
|
|
decrypt_value,
|
|
encrypt_value,
|
|
get_error_message_str,
|
|
get_instance_fn,
|
|
hash_token,
|
|
html_form,
|
|
missing_keys_html_form,
|
|
reset_budget,
|
|
send_email,
|
|
update_spend,
|
|
)
|
|
from litellm.router import (
|
|
AssistantsTypedDict,
|
|
Deployment,
|
|
LiteLLM_Params,
|
|
ModelGroupInfo,
|
|
)
|
|
from litellm.router import ModelInfo as RouterModelInfo
|
|
from litellm.router import updateDeployment
|
|
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
|
from litellm.types.router import RouterGeneralSettings
|
|
|
|
try:
|
|
from litellm._version import version
|
|
except:
|
|
version = "0.0.0"
|
|
litellm.suppress_debug_info = True
|
|
import json
|
|
import logging
|
|
from typing import Union
|
|
|
|
from fastapi import (
|
|
Depends,
|
|
FastAPI,
|
|
File,
|
|
Form,
|
|
Header,
|
|
HTTPException,
|
|
Path,
|
|
Request,
|
|
Response,
|
|
UploadFile,
|
|
status,
|
|
)
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.openapi.utils import get_openapi
|
|
from fastapi.responses import (
|
|
FileResponse,
|
|
JSONResponse,
|
|
ORJSONResponse,
|
|
RedirectResponse,
|
|
StreamingResponse,
|
|
)
|
|
from fastapi.routing import APIRouter
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from fastapi.security.api_key import APIKeyHeader
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
# import enterprise folder
|
|
try:
|
|
# when using litellm cli
|
|
import litellm.proxy.enterprise as enterprise
|
|
except Exception as e:
|
|
# when using litellm docker image
|
|
try:
|
|
import enterprise # type: ignore
|
|
except Exception as e:
|
|
pass
|
|
|
|
_license_check = LicenseCheck()
|
|
premium_user: bool = _license_check.is_premium()
|
|
ui_link = f"/ui/"
|
|
ui_message = (
|
|
f"👉 [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
|
|
)
|
|
custom_swagger_message = f"[**Customize Swagger Docs**](https://docs.litellm.ai/docs/proxy/enterprise#swagger-docs---custom-routes--branding)"
|
|
|
|
### CUSTOM BRANDING [ENTERPRISE FEATURE] ###
|
|
_docs_url = None if os.getenv("NO_DOCS", "False") == "True" else "/"
|
|
_title = os.getenv("DOCS_TITLE", "LiteLLM API") if premium_user else "LiteLLM API"
|
|
_description = (
|
|
os.getenv(
|
|
"DOCS_DESCRIPTION",
|
|
f"Proxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}",
|
|
)
|
|
if premium_user
|
|
else f"Proxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}"
|
|
)
|
|
|
|
app = FastAPI(
|
|
docs_url=_docs_url,
|
|
title=_title,
|
|
description=_description,
|
|
version=version,
|
|
root_path=os.environ.get(
|
|
"SERVER_ROOT_PATH", ""
|
|
), # check if user passed root path, FastAPI defaults this value to ""
|
|
)
|
|
|
|
|
|
### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
|
|
# Custom OpenAPI schema generator to include only selected routes
|
|
def custom_openapi():
|
|
if app.openapi_schema:
|
|
return app.openapi_schema
|
|
openapi_schema = get_openapi(
|
|
title=app.title,
|
|
version=app.version,
|
|
description=app.description,
|
|
routes=app.routes,
|
|
)
|
|
# Filter routes to include only specific ones
|
|
openai_routes = LiteLLMRoutes.openai_routes.value
|
|
paths_to_include: dict = {}
|
|
for route in openai_routes:
|
|
paths_to_include[route] = openapi_schema["paths"][route]
|
|
openapi_schema["paths"] = paths_to_include
|
|
app.openapi_schema = openapi_schema
|
|
return app.openapi_schema
|
|
|
|
|
|
if os.getenv("DOCS_FILTERED", "False") == "True" and premium_user:
|
|
app.openapi = custom_openapi # type: ignore
|
|
|
|
|
|
class UserAPIKeyCacheTTLEnum(enum.Enum):
|
|
in_memory_cache_ttl = 60 # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`
|
|
|
|
|
|
@app.exception_handler(ProxyException)
|
|
async def openai_exception_handler(request: Request, exc: ProxyException):
|
|
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
|
return JSONResponse(
|
|
status_code=(
|
|
int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
|
|
),
|
|
content={
|
|
"error": {
|
|
"message": exc.message,
|
|
"type": exc.type,
|
|
"param": exc.param,
|
|
"code": exc.code,
|
|
}
|
|
},
|
|
)
|
|
|
|
|
|
router = APIRouter()
|
|
origins = ["*"]
|
|
|
|
# get current directory
|
|
try:
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
ui_path = os.path.join(current_dir, "_experimental", "out")
|
|
app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui")
|
|
# Iterate through files in the UI directory
|
|
for filename in os.listdir(ui_path):
|
|
if filename.endswith(".html") and filename != "index.html":
|
|
# Create a folder with the same name as the HTML file
|
|
folder_name = os.path.splitext(filename)[0]
|
|
folder_path = os.path.join(ui_path, folder_name)
|
|
os.makedirs(folder_path, exist_ok=True)
|
|
|
|
# Move the HTML file into the folder and rename it to 'index.html'
|
|
src = os.path.join(ui_path, filename)
|
|
dst = os.path.join(folder_path, "index.html")
|
|
os.rename(src, dst)
|
|
except:
|
|
pass
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=origins,
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
from typing import Dict
|
|
|
|
user_api_base = None
|
|
user_model = None
|
|
user_debug = False
|
|
user_max_tokens = None
|
|
user_request_timeout = None
|
|
user_temperature = None
|
|
user_telemetry = True
|
|
user_config = None
|
|
user_headers = None
|
|
user_config_file_path = f"config_{int(time.time())}.yaml"
|
|
local_logging = True # writes logs to a local api_log.json file for debugging
|
|
experimental = False
|
|
#### GLOBAL VARIABLES ####
|
|
llm_router: Optional[litellm.Router] = None
|
|
llm_model_list: Optional[list] = None
|
|
general_settings: dict = {}
|
|
log_file = "api_log.json"
|
|
worker_config = None
|
|
master_key = None
|
|
otel_logging = False
|
|
prisma_client: Optional[PrismaClient] = None
|
|
custom_db_client: Optional[DBClient] = None
|
|
user_api_key_cache = DualCache(
|
|
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
|
|
)
|
|
redis_usage_cache: Optional[RedisCache] = (
|
|
None # redis cache used for tracking spend, tpm/rpm limits
|
|
)
|
|
user_custom_auth = None
|
|
user_custom_key_generate = None
|
|
use_background_health_checks = None
|
|
use_queue = False
|
|
health_check_interval = None
|
|
health_check_results = {}
|
|
queue: List = []
|
|
litellm_proxy_budget_name = "litellm-proxy-budget"
|
|
litellm_proxy_admin_name = "default_user_id"
|
|
ui_access_mode: Literal["admin", "all"] = "all"
|
|
proxy_budget_rescheduler_min_time = 597
|
|
proxy_budget_rescheduler_max_time = 605
|
|
proxy_batch_write_at = 10 # in seconds
|
|
litellm_master_key_hash = None
|
|
disable_spend_logs = False
|
|
jwt_handler = JWTHandler()
|
|
prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
|
|
store_model_in_db: bool = False
|
|
open_telemetry_logger = None
|
|
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
|
### REDIS QUEUE ###
|
|
async_result = None
|
|
celery_app_conn = None
|
|
celery_fn = None # Redis Queue for handling requests
|
|
### DB WRITER ###
|
|
db_writer_client: Optional[HTTPHandler] = None
|
|
### logger ###
|
|
|
|
|
|
def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
|
|
try:
|
|
return pydantic_obj.model_dump() # type: ignore
|
|
except:
|
|
# if using pydantic v1
|
|
return pydantic_obj.dict()
|
|
|
|
|
|
def get_custom_headers(
|
|
*,
|
|
user_api_key_dict: UserAPIKeyAuth,
|
|
model_id: Optional[str] = None,
|
|
cache_key: Optional[str] = None,
|
|
api_base: Optional[str] = None,
|
|
version: Optional[str] = None,
|
|
model_region: Optional[str] = None,
|
|
response_cost: Optional[Union[float, str]] = None,
|
|
fastest_response_batch_completion: Optional[bool] = None,
|
|
**kwargs,
|
|
) -> dict:
|
|
exclude_values = {"", None}
|
|
headers = {
|
|
"x-litellm-model-id": model_id,
|
|
"x-litellm-cache-key": cache_key,
|
|
"x-litellm-model-api-base": api_base,
|
|
"x-litellm-version": version,
|
|
"x-litellm-model-region": model_region,
|
|
"x-litellm-response-cost": str(response_cost),
|
|
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
|
|
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
|
|
"x-litellm-fastest_response_batch_completion": (
|
|
str(fastest_response_batch_completion)
|
|
if fastest_response_batch_completion is not None
|
|
else None
|
|
),
|
|
**{k: str(v) for k, v in kwargs.items()},
|
|
}
|
|
try:
|
|
return {
|
|
key: value for key, value in headers.items() if value not in exclude_values
|
|
}
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
|
|
return {}
|
|
|
|
|
|
async def check_request_disconnection(request: Request, llm_api_call_task):
|
|
"""
|
|
Asynchronously checks if the request is disconnected at regular intervals.
|
|
If the request is disconnected
|
|
- cancel the litellm.router task
|
|
- raises an HTTPException with status code 499 and detail "Client disconnected the request".
|
|
|
|
Parameters:
|
|
- request: Request: The request object to check for disconnection.
|
|
Returns:
|
|
- None
|
|
"""
|
|
|
|
# only run this function for 10 mins -> if these don't get cancelled -> we don't want the server to have many while loops
|
|
start_time = time.time()
|
|
while time.time() - start_time < 600:
|
|
await asyncio.sleep(1)
|
|
if await request.is_disconnected():
|
|
|
|
# cancel the LLM API Call task if any passed - this is passed from individual providers
|
|
# Example OpenAI, Azure, VertexAI etc
|
|
llm_api_call_task.cancel()
|
|
|
|
raise HTTPException(
|
|
status_code=499,
|
|
detail="Client disconnected the request",
|
|
)
|
|
|
|
|
|
def prisma_setup(database_url: Optional[str]):
|
|
global prisma_client, proxy_logging_obj, user_api_key_cache
|
|
|
|
if database_url is not None:
|
|
try:
|
|
prisma_client = PrismaClient(
|
|
database_url=database_url, proxy_logging_obj=proxy_logging_obj
|
|
)
|
|
except Exception as e:
|
|
raise e
|
|
|
|
|
|
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|
if use_azure_key_vault is False:
|
|
return
|
|
|
|
try:
|
|
from azure.identity import ClientSecretCredential
|
|
from azure.keyvault.secrets import SecretClient
|
|
|
|
# Set your Azure Key Vault URI
|
|
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
|
|
|
|
# Set your Azure AD application/client ID, client secret, and tenant ID
|
|
client_id = os.getenv("AZURE_CLIENT_ID", None)
|
|
client_secret = os.getenv("AZURE_CLIENT_SECRET", None)
|
|
tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
|
|
|
if (
|
|
KVUri is not None
|
|
and client_id is not None
|
|
and client_secret is not None
|
|
and tenant_id is not None
|
|
):
|
|
# Initialize the ClientSecretCredential
|
|
credential = ClientSecretCredential(
|
|
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
|
|
)
|
|
|
|
# Create the SecretClient using the credential
|
|
client = SecretClient(vault_url=KVUri, credential=credential)
|
|
|
|
litellm.secret_manager_client = client
|
|
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
|
|
else:
|
|
raise Exception(
|
|
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
"Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`"
|
|
)
|
|
|
|
|
|
def cost_tracking():
|
|
global prisma_client, custom_db_client
|
|
if prisma_client is not None or custom_db_client is not None:
|
|
if isinstance(litellm.success_callback, list):
|
|
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
|
if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
|
|
litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
|
|
|
|
|
|
async def _PROXY_failure_handler(
|
|
kwargs, # kwargs to completion
|
|
completion_response: litellm.ModelResponse, # response from completion
|
|
start_time=None,
|
|
end_time=None, # start/end time for completion
|
|
):
|
|
global prisma_client
|
|
if prisma_client is not None:
|
|
verbose_proxy_logger.debug(
|
|
"inside _PROXY_failure_handler kwargs=", extra=kwargs
|
|
)
|
|
|
|
_exception = kwargs.get("exception")
|
|
_exception_type = _exception.__class__.__name__
|
|
_model = kwargs.get("model", None)
|
|
|
|
_optional_params = kwargs.get("optional_params", {})
|
|
_optional_params = copy.deepcopy(_optional_params)
|
|
|
|
for k, v in _optional_params.items():
|
|
v = str(v)
|
|
v = v[:100]
|
|
|
|
_status_code = "500"
|
|
try:
|
|
_status_code = str(_exception.status_code)
|
|
except:
|
|
# Don't let this fail logging the exception to the dB
|
|
pass
|
|
|
|
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
|
_metadata = _litellm_params.get("metadata", {}) or {}
|
|
_model_id = _metadata.get("model_info", {}).get("id", "")
|
|
_model_group = _metadata.get("model_group", "")
|
|
api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params)
|
|
_exception_string = str(_exception)
|
|
|
|
error_log = LiteLLM_ErrorLogs(
|
|
request_id=str(uuid.uuid4()),
|
|
model_group=_model_group,
|
|
model_id=_model_id,
|
|
litellm_model_name=kwargs.get("model"),
|
|
request_kwargs=_optional_params,
|
|
api_base=api_base,
|
|
exception_type=_exception_type,
|
|
status_code=_status_code,
|
|
exception_string=_exception_string,
|
|
startTime=kwargs.get("start_time"),
|
|
endTime=kwargs.get("end_time"),
|
|
)
|
|
|
|
# helper function to convert to dict on pydantic v2 & v1
|
|
error_log_dict = _get_pydantic_json_dict(error_log)
|
|
error_log_dict["request_kwargs"] = json.dumps(error_log_dict["request_kwargs"])
|
|
|
|
await prisma_client.db.litellm_errorlogs.create(
|
|
data=error_log_dict # type: ignore
|
|
)
|
|
|
|
pass
|
|
|
|
|
|
async def _PROXY_track_cost_callback(
|
|
kwargs, # kwargs to completion
|
|
completion_response: litellm.ModelResponse, # response from completion
|
|
start_time=None,
|
|
end_time=None, # start/end time for completion
|
|
):
|
|
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
|
|
global prisma_client, custom_db_client
|
|
try:
|
|
# check if it has collected an entire stream response
|
|
verbose_proxy_logger.debug("Proxy: In track_cost_callback for: %s", kwargs)
|
|
verbose_proxy_logger.debug(
|
|
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
|
)
|
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
|
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
|
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
|
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
|
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
|
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
|
key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
|
|
end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
|
|
"user_api_end_user_max_budget", None
|
|
)
|
|
if kwargs.get("response_cost", None) is not None:
|
|
response_cost = kwargs["response_cost"]
|
|
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
|
"user_api_key", None
|
|
)
|
|
|
|
if kwargs.get("cache_hit", False) == True:
|
|
response_cost = 0.0
|
|
verbose_proxy_logger.info(
|
|
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
|
)
|
|
|
|
verbose_proxy_logger.debug(
|
|
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
|
)
|
|
if user_api_key is not None or user_id is not None or team_id is not None:
|
|
## UPDATE DATABASE
|
|
await update_database(
|
|
token=user_api_key,
|
|
response_cost=response_cost,
|
|
user_id=user_id,
|
|
end_user_id=end_user_id,
|
|
team_id=team_id,
|
|
kwargs=kwargs,
|
|
completion_response=completion_response,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
org_id=org_id,
|
|
)
|
|
|
|
await update_cache(
|
|
token=user_api_key,
|
|
user_id=user_id,
|
|
end_user_id=end_user_id,
|
|
response_cost=response_cost,
|
|
team_id=team_id,
|
|
)
|
|
|
|
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
|
token=user_api_key,
|
|
key_alias=key_alias,
|
|
end_user_id=end_user_id,
|
|
response_cost=response_cost,
|
|
max_budget=end_user_max_budget,
|
|
)
|
|
else:
|
|
raise Exception(
|
|
"User API key and team id and user id missing from custom callback."
|
|
)
|
|
else:
|
|
if kwargs["stream"] != True or (
|
|
kwargs["stream"] == True and "complete_streaming_response" in kwargs
|
|
):
|
|
raise Exception(
|
|
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
|
)
|
|
except Exception as e:
|
|
error_msg = f"error in tracking cost callback - {traceback.format_exc()}"
|
|
model = kwargs.get("model", "")
|
|
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
|
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
|
|
asyncio.create_task(
|
|
proxy_logging_obj.failed_tracking_alert(
|
|
error_message=error_msg,
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug("error in tracking cost callback - %s", e)
|
|
|
|
|
|
def error_tracking():
|
|
global prisma_client, custom_db_client
|
|
if prisma_client is not None or custom_db_client is not None:
|
|
if isinstance(litellm.failure_callback, list):
|
|
verbose_proxy_logger.debug("setting litellm failure callback to track cost")
|
|
if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore
|
|
litellm.failure_callback.append(_PROXY_failure_handler) # type: ignore
|
|
|
|
|
|
def _set_spend_logs_payload(
|
|
payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
|
|
):
|
|
if prisma_client is not None and spend_logs_url is not None:
|
|
if isinstance(payload["startTime"], datetime):
|
|
payload["startTime"] = payload["startTime"].isoformat()
|
|
if isinstance(payload["endTime"], datetime):
|
|
payload["endTime"] = payload["endTime"].isoformat()
|
|
prisma_client.spend_log_transactions.append(payload)
|
|
elif prisma_client is not None:
|
|
prisma_client.spend_log_transactions.append(payload)
|
|
return prisma_client
|
|
|
|
|
|
async def update_database(
|
|
token,
|
|
response_cost,
|
|
user_id=None,
|
|
end_user_id=None,
|
|
team_id=None,
|
|
kwargs=None,
|
|
completion_response=None,
|
|
start_time=None,
|
|
end_time=None,
|
|
org_id=None,
|
|
):
|
|
try:
|
|
global prisma_client
|
|
verbose_proxy_logger.info(
|
|
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
|
)
|
|
if token is not None and isinstance(token, str) and token.startswith("sk-"):
|
|
hashed_token = hash_token(token=token)
|
|
else:
|
|
hashed_token = token
|
|
|
|
### UPDATE USER SPEND ###
|
|
async def _update_user_db():
|
|
"""
|
|
- Update that user's row
|
|
- Update litellm-proxy-budget row (global proxy spend)
|
|
"""
|
|
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
|
existing_token_obj = await user_api_key_cache.async_get_cache(
|
|
key=hashed_token
|
|
)
|
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
|
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
|
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
|
data_list = []
|
|
try:
|
|
if prisma_client is not None: # update
|
|
user_ids = [user_id]
|
|
if (
|
|
litellm.max_budget > 0
|
|
): # track global proxy budget, if user set max budget
|
|
user_ids.append(litellm_proxy_budget_name)
|
|
### KEY CHANGE ###
|
|
for _id in user_ids:
|
|
if _id is not None:
|
|
prisma_client.user_list_transactons[_id] = (
|
|
response_cost
|
|
+ prisma_client.user_list_transactons.get(_id, 0)
|
|
)
|
|
if end_user_id is not None:
|
|
prisma_client.end_user_list_transactons[end_user_id] = (
|
|
response_cost
|
|
+ prisma_client.end_user_list_transactons.get(
|
|
end_user_id, 0
|
|
)
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.info(
|
|
"\033[91m"
|
|
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
|
)
|
|
|
|
### UPDATE KEY SPEND ###
|
|
async def _update_key_db():
|
|
try:
|
|
verbose_proxy_logger.debug(
|
|
f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
|
|
)
|
|
if hashed_token is None:
|
|
return
|
|
if prisma_client is not None:
|
|
prisma_client.key_list_transactons[hashed_token] = (
|
|
response_cost
|
|
+ prisma_client.key_list_transactons.get(hashed_token, 0)
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
f"Update Key DB Call failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
)
|
|
raise e
|
|
|
|
### UPDATE SPEND LOGS ###
|
|
async def _insert_spend_log_to_db():
|
|
try:
|
|
global prisma_client
|
|
if prisma_client is not None:
|
|
# Helper to generate payload to log
|
|
payload = get_logging_payload(
|
|
kwargs=kwargs,
|
|
response_obj=completion_response,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
end_user_id=end_user_id,
|
|
)
|
|
|
|
payload["spend"] = response_cost
|
|
prisma_client = _set_spend_logs_payload(
|
|
payload=payload,
|
|
spend_logs_url=os.getenv("SPEND_LOGS_URL"),
|
|
prisma_client=prisma_client,
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
)
|
|
raise e
|
|
|
|
### UPDATE TEAM SPEND ###
|
|
async def _update_team_db():
|
|
try:
|
|
verbose_proxy_logger.debug(
|
|
f"adding spend to team db. Response cost: {response_cost}. team_id: {team_id}."
|
|
)
|
|
if team_id is None:
|
|
verbose_proxy_logger.debug(
|
|
"track_cost_callback: team_id is None. Not tracking spend for team"
|
|
)
|
|
return
|
|
if prisma_client is not None:
|
|
prisma_client.team_list_transactons[team_id] = (
|
|
response_cost
|
|
+ prisma_client.team_list_transactons.get(team_id, 0)
|
|
)
|
|
|
|
try:
|
|
# Track spend of the team member within this team
|
|
# key is "team_id::<value>::user_id::<value>"
|
|
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
|
prisma_client.team_member_list_transactons[team_member_key] = (
|
|
response_cost
|
|
+ prisma_client.team_member_list_transactons.get(
|
|
team_member_key, 0
|
|
)
|
|
)
|
|
except:
|
|
pass
|
|
except Exception as e:
|
|
verbose_proxy_logger.info(
|
|
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
)
|
|
raise e
|
|
|
|
### UPDATE ORG SPEND ###
|
|
async def _update_org_db():
|
|
try:
|
|
verbose_proxy_logger.debug(
|
|
"adding spend to org db. Response cost: {}. org_id: {}.".format(
|
|
response_cost, org_id
|
|
)
|
|
)
|
|
if org_id is None:
|
|
verbose_proxy_logger.debug(
|
|
"track_cost_callback: org_id is None. Not tracking spend for org"
|
|
)
|
|
return
|
|
if prisma_client is not None:
|
|
prisma_client.org_list_transactons[org_id] = (
|
|
response_cost
|
|
+ prisma_client.org_list_transactons.get(org_id, 0)
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.info(
|
|
f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
)
|
|
raise e
|
|
|
|
asyncio.create_task(_update_user_db())
|
|
asyncio.create_task(_update_key_db())
|
|
asyncio.create_task(_update_team_db())
|
|
asyncio.create_task(_update_org_db())
|
|
# asyncio.create_task(_insert_spend_log_to_db())
|
|
if disable_spend_logs == False:
|
|
await _insert_spend_log_to_db()
|
|
|
|
verbose_proxy_logger.debug("Runs spend update on all tables")
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
f"Error updating Prisma database: {traceback.format_exc()}"
|
|
)
|
|
|
|
|
|
async def update_cache(
|
|
token: Optional[str],
|
|
user_id: Optional[str],
|
|
end_user_id: Optional[str],
|
|
team_id: Optional[str],
|
|
response_cost: Optional[float],
|
|
):
|
|
"""
|
|
Use this to update the cache with new user spend.
|
|
|
|
Put any alerting logic in here.
|
|
"""
|
|
|
|
### UPDATE KEY SPEND ###
|
|
async def _update_key_cache(token: str, response_cost: float):
|
|
# Fetch the existing cost for the given token
|
|
if isinstance(token, str) and token.startswith("sk-"):
|
|
hashed_token = hash_token(token=token)
|
|
else:
|
|
hashed_token = token
|
|
verbose_proxy_logger.debug("_update_key_cache: hashed_token=%s", hashed_token)
|
|
existing_spend_obj: LiteLLM_VerificationTokenView = await user_api_key_cache.async_get_cache(key=hashed_token) # type: ignore
|
|
verbose_proxy_logger.debug(
|
|
f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"_update_key_cache: existing spend: {existing_spend_obj}"
|
|
)
|
|
if existing_spend_obj is None:
|
|
return
|
|
else:
|
|
existing_spend = existing_spend_obj.spend
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
new_spend = existing_spend + response_cost
|
|
|
|
## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
|
|
soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
|
|
if (
|
|
existing_spend_obj.soft_budget_cooldown == False
|
|
and existing_spend_obj.litellm_budget_table is not None
|
|
and (
|
|
_is_projected_spend_over_limit(
|
|
current_spend=new_spend,
|
|
soft_budget_limit=existing_spend_obj.litellm_budget_table[
|
|
"soft_budget"
|
|
],
|
|
)
|
|
== True
|
|
)
|
|
):
|
|
projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
|
|
current_spend=new_spend,
|
|
soft_budget_limit=existing_spend_obj.litellm_budget_table.get(
|
|
"soft_budget", None
|
|
),
|
|
) # type: ignore
|
|
soft_limit = existing_spend_obj.litellm_budget_table.get(
|
|
"soft_budget", float("inf")
|
|
)
|
|
call_info = CallInfo(
|
|
token=existing_spend_obj.token or "",
|
|
spend=new_spend,
|
|
key_alias=existing_spend_obj.key_alias,
|
|
max_budget=soft_limit,
|
|
user_id=existing_spend_obj.user_id,
|
|
projected_spend=projected_spend,
|
|
projected_exceeded_date=projected_exceeded_date,
|
|
)
|
|
# alert user
|
|
asyncio.create_task(
|
|
proxy_logging_obj.budget_alerts(
|
|
type="projected_limit_exceeded",
|
|
user_info=call_info,
|
|
)
|
|
)
|
|
# set cooldown on alert
|
|
soft_budget_cooldown = True
|
|
|
|
if (
|
|
existing_spend_obj is not None
|
|
and getattr(existing_spend_obj, "team_spend", None) is not None
|
|
):
|
|
existing_team_spend = existing_spend_obj.team_spend or 0
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
|
|
|
if (
|
|
existing_spend_obj is not None
|
|
and getattr(existing_spend_obj, "team_member_spend", None) is not None
|
|
):
|
|
existing_team_member_spend = existing_spend_obj.team_member_spend or 0
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
existing_spend_obj.team_member_spend = (
|
|
existing_team_member_spend + response_cost
|
|
)
|
|
|
|
# Update the cost column for the given token
|
|
existing_spend_obj.spend = new_spend
|
|
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
|
|
|
### UPDATE USER SPEND ###
|
|
async def _update_user_cache():
|
|
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
|
user_ids = [user_id]
|
|
try:
|
|
for _id in user_ids:
|
|
# Fetch the existing cost for the given user
|
|
if _id is None:
|
|
continue
|
|
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
|
if existing_spend_obj is None:
|
|
# do nothing if there is no cache value
|
|
return
|
|
verbose_proxy_logger.debug(
|
|
f"_update_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
|
)
|
|
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend = existing_spend_obj["spend"]
|
|
else:
|
|
existing_spend = existing_spend_obj.spend
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
new_spend = existing_spend + response_cost
|
|
|
|
# Update the cost column for the given user
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend_obj["spend"] = new_spend
|
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
|
else:
|
|
existing_spend_obj.spend = new_spend
|
|
user_api_key_cache.set_cache(
|
|
key=_id, value=existing_spend_obj.json()
|
|
)
|
|
## UPDATE GLOBAL PROXY ##
|
|
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
|
key="{}:spend".format(litellm_proxy_admin_name)
|
|
)
|
|
if global_proxy_spend is None:
|
|
# do nothing if not in cache
|
|
return
|
|
elif response_cost is not None and global_proxy_spend is not None:
|
|
increment = global_proxy_spend + response_cost
|
|
await user_api_key_cache.async_set_cache(
|
|
key="{}:spend".format(litellm_proxy_admin_name), value=increment
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
|
)
|
|
|
|
### UPDATE END-USER SPEND ###
|
|
async def _update_end_user_cache():
|
|
if end_user_id is None or response_cost is None:
|
|
return
|
|
|
|
_id = "end_user_id:{}".format(end_user_id)
|
|
try:
|
|
# Fetch the existing cost for the given user
|
|
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
|
if existing_spend_obj is None:
|
|
# if user does not exist in LiteLLM_UserTable, create a new user
|
|
# do nothing if end-user not in api key cache
|
|
return
|
|
verbose_proxy_logger.debug(
|
|
f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
|
)
|
|
if existing_spend_obj is None:
|
|
existing_spend = 0
|
|
else:
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend = existing_spend_obj["spend"]
|
|
else:
|
|
existing_spend = existing_spend_obj.spend
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
new_spend = existing_spend + response_cost
|
|
|
|
# Update the cost column for the given user
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend_obj["spend"] = new_spend
|
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
|
else:
|
|
existing_spend_obj.spend = new_spend
|
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
|
)
|
|
|
|
### UPDATE TEAM SPEND ###
|
|
async def _update_team_cache():
|
|
if team_id is None or response_cost is None:
|
|
return
|
|
|
|
_id = "team_id:{}".format(team_id)
|
|
try:
|
|
# Fetch the existing cost for the given user
|
|
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
|
|
await user_api_key_cache.async_get_cache(key=_id)
|
|
)
|
|
if existing_spend_obj is None:
|
|
# do nothing if team not in api key cache
|
|
return
|
|
verbose_proxy_logger.debug(
|
|
f"_update_team_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
|
)
|
|
if existing_spend_obj is None:
|
|
existing_spend: Optional[float] = 0.0
|
|
else:
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend = existing_spend_obj["spend"]
|
|
else:
|
|
existing_spend = existing_spend_obj.spend
|
|
|
|
if existing_spend is None:
|
|
existing_spend = 0.0
|
|
# Calculate the new cost by adding the existing cost and response_cost
|
|
new_spend = existing_spend + response_cost
|
|
|
|
# Update the cost column for the given user
|
|
if isinstance(existing_spend_obj, dict):
|
|
existing_spend_obj["spend"] = new_spend
|
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
|
else:
|
|
existing_spend_obj.spend = new_spend
|
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
|
)
|
|
|
|
if token is not None and response_cost is not None:
|
|
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
|
|
|
|
if user_id is not None:
|
|
asyncio.create_task(_update_user_cache())
|
|
|
|
if end_user_id is not None:
|
|
asyncio.create_task(_update_end_user_cache())
|
|
|
|
if team_id is not None:
|
|
asyncio.create_task(_update_team_cache())
|
|
|
|
|
|
def run_ollama_serve():
|
|
try:
|
|
command = ["ollama", "serve"]
|
|
|
|
with open(os.devnull, "w") as devnull:
|
|
process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
f"""
|
|
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
|
"""
|
|
)
|
|
|
|
|
|
async def _run_background_health_check():
|
|
"""
|
|
Periodically run health checks in the background on the endpoints.
|
|
|
|
Update health_check_results, based on this.
|
|
"""
|
|
global health_check_results, llm_model_list, health_check_interval
|
|
|
|
# make 1 deep copy of llm_model_list -> use this for all background health checks
|
|
_llm_model_list = copy.deepcopy(llm_model_list)
|
|
|
|
while True:
|
|
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
|
|
model_list=_llm_model_list
|
|
)
|
|
|
|
# Update the global variable with the health check results
|
|
health_check_results["healthy_endpoints"] = healthy_endpoints
|
|
health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
|
|
health_check_results["healthy_count"] = len(healthy_endpoints)
|
|
health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
|
|
|
|
await asyncio.sleep(health_check_interval)
|
|
|
|
|
|
class ProxyConfig:
|
|
"""
|
|
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
pass
|
|
|
|
def is_yaml(self, config_file_path: str) -> bool:
|
|
if not os.path.isfile(config_file_path):
|
|
return False
|
|
|
|
_, file_extension = os.path.splitext(config_file_path)
|
|
return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml"
|
|
|
|
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
|
|
global prisma_client, user_config_file_path
|
|
|
|
file_path = config_file_path or user_config_file_path
|
|
if config_file_path is not None:
|
|
user_config_file_path = config_file_path
|
|
# Load existing config
|
|
## Yaml
|
|
if os.path.exists(f"{file_path}"):
|
|
with open(f"{file_path}", "r") as config_file:
|
|
config = yaml.safe_load(config_file)
|
|
else:
|
|
config = {
|
|
"model_list": [],
|
|
"general_settings": {},
|
|
"router_settings": {},
|
|
"litellm_settings": {},
|
|
}
|
|
|
|
## DB
|
|
if prisma_client is not None and (
|
|
general_settings.get("store_model_in_db", False) == True
|
|
):
|
|
_tasks = []
|
|
keys = [
|
|
"general_settings",
|
|
"router_settings",
|
|
"litellm_settings",
|
|
"environment_variables",
|
|
]
|
|
for k in keys:
|
|
response = prisma_client.get_generic_data(
|
|
key="param_name", value=k, table_name="config"
|
|
)
|
|
_tasks.append(response)
|
|
|
|
responses = await asyncio.gather(*_tasks)
|
|
for response in responses:
|
|
if response is not None:
|
|
param_name = getattr(response, "param_name", None)
|
|
param_value = getattr(response, "param_value", None)
|
|
if param_name is not None and param_value is not None:
|
|
# check if param_name is already in the config
|
|
if param_name in config:
|
|
if isinstance(config[param_name], dict):
|
|
config[param_name].update(param_value)
|
|
else:
|
|
config[param_name] = param_value
|
|
else:
|
|
# if it's not in the config - then add it
|
|
config[param_name] = param_value
|
|
|
|
return config
|
|
|
|
async def save_config(self, new_config: dict):
|
|
global prisma_client, general_settings, user_config_file_path
|
|
# Load existing config
|
|
## DB - writes valid config to db
|
|
"""
|
|
- Do not write restricted params like 'api_key' to the database
|
|
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
|
|
"""
|
|
if prisma_client is not None and (
|
|
general_settings.get("store_model_in_db", False) == True
|
|
):
|
|
# if using - db for config - models are in ModelTable
|
|
new_config.pop("model_list", None)
|
|
await prisma_client.insert_data(data=new_config, table_name="config")
|
|
else:
|
|
# Save the updated config - if user is not using a dB
|
|
## YAML
|
|
with open(f"{user_config_file_path}", "w") as config_file:
|
|
yaml.dump(new_config, config_file, default_flow_style=False)
|
|
|
|
async def load_team_config(self, team_id: str):
|
|
"""
|
|
- for a given team id
|
|
- return the relevant completion() call params
|
|
"""
|
|
# load existing config
|
|
config = await self.get_config()
|
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
|
litellm_settings = config.get("litellm_settings", {})
|
|
all_teams_config = litellm_settings.get("default_team_settings", None)
|
|
team_config: dict = {}
|
|
if all_teams_config is None:
|
|
return team_config
|
|
for team in all_teams_config:
|
|
if "team_id" not in team:
|
|
raise Exception(f"team_id missing from team: {team}")
|
|
if team_id == team["team_id"]:
|
|
team_config = team
|
|
break
|
|
for k, v in team_config.items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
team_config[k] = litellm.get_secret(v)
|
|
return team_config
|
|
|
|
def _init_cache(
|
|
self,
|
|
cache_params: dict,
|
|
):
|
|
global redis_usage_cache
|
|
from litellm import Cache
|
|
|
|
if "default_in_memory_ttl" in cache_params:
|
|
litellm.default_in_memory_ttl = cache_params["default_in_memory_ttl"]
|
|
|
|
if "default_redis_ttl" in cache_params:
|
|
litellm.default_redis_ttl = cache_params["default_in_redis_ttl"]
|
|
|
|
litellm.cache = Cache(**cache_params)
|
|
|
|
if litellm.cache is not None and isinstance(litellm.cache.cache, RedisCache):
|
|
## INIT PROXY REDIS USAGE CLIENT ##
|
|
redis_usage_cache = litellm.cache.cache
|
|
|
|
async def load_config(
|
|
self, router: Optional[litellm.Router], config_file_path: str
|
|
):
|
|
"""
|
|
Load config values into proxy global state
|
|
"""
|
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger
|
|
|
|
# Load existing config
|
|
config = await self.get_config(config_file_path=config_file_path)
|
|
## PRINT YAML FOR CONFIRMING IT WORKS
|
|
printed_yaml = copy.deepcopy(config)
|
|
printed_yaml.pop("environment_variables", None)
|
|
|
|
verbose_proxy_logger.debug(
|
|
f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
|
|
)
|
|
|
|
## ENVIRONMENT VARIABLES
|
|
environment_variables = config.get("environment_variables", None)
|
|
if environment_variables:
|
|
for key, value in environment_variables.items():
|
|
os.environ[key] = value
|
|
|
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
|
litellm_settings = config.get("litellm_settings", None)
|
|
if litellm_settings is None:
|
|
litellm_settings = {}
|
|
if litellm_settings:
|
|
# ANSI escape code for blue text
|
|
blue_color_code = "\033[94m"
|
|
reset_color_code = "\033[0m"
|
|
for key, value in litellm_settings.items():
|
|
if key == "cache" and value == True:
|
|
print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
|
|
from litellm.caching import Cache
|
|
|
|
cache_params = {}
|
|
if "cache_params" in litellm_settings:
|
|
cache_params_in_config = litellm_settings["cache_params"]
|
|
# overwrie cache_params with cache_params_in_config
|
|
cache_params.update(cache_params_in_config)
|
|
|
|
cache_type = cache_params.get("type", "redis")
|
|
|
|
verbose_proxy_logger.debug("passed cache type=%s", cache_type)
|
|
|
|
if (
|
|
cache_type == "redis" or cache_type == "redis-semantic"
|
|
) and len(cache_params.keys()) == 0:
|
|
cache_host = litellm.get_secret("REDIS_HOST", None)
|
|
cache_port = litellm.get_secret("REDIS_PORT", None)
|
|
cache_password = None
|
|
cache_params.update(
|
|
{
|
|
"type": cache_type,
|
|
"host": cache_host,
|
|
"port": cache_port,
|
|
}
|
|
)
|
|
|
|
if litellm.get_secret("REDIS_PASSWORD", None) is not None:
|
|
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
|
cache_params.update(
|
|
{
|
|
"password": cache_password,
|
|
}
|
|
)
|
|
|
|
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
|
verbose_proxy_logger.debug(
|
|
"%sCache Type:%s %s",
|
|
blue_color_code,
|
|
reset_color_code,
|
|
cache_type,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
"%sCache Host:%s %s",
|
|
blue_color_code,
|
|
reset_color_code,
|
|
cache_host,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
"%sCache Port:%s %s",
|
|
blue_color_code,
|
|
reset_color_code,
|
|
cache_port,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
"%sCache Password:%s %s",
|
|
blue_color_code,
|
|
reset_color_code,
|
|
cache_password,
|
|
)
|
|
if cache_type == "redis-semantic":
|
|
# by default this should always be async
|
|
cache_params.update({"redis_semantic_cache_use_async": True})
|
|
|
|
# users can pass os.environ/ variables on the proxy - we should read them from the env
|
|
for key, value in cache_params.items():
|
|
if type(value) is str and value.startswith("os.environ/"):
|
|
cache_params[key] = litellm.get_secret(value)
|
|
|
|
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
|
self._init_cache(cache_params=cache_params)
|
|
if litellm.cache is not None:
|
|
verbose_proxy_logger.debug( # noqa
|
|
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
|
|
)
|
|
elif key == "cache" and value == False:
|
|
pass
|
|
elif key == "guardrails":
|
|
if premium_user is not True:
|
|
raise ValueError(
|
|
"Trying to use `guardrails` on config.yaml "
|
|
+ CommonProxyErrors.not_premium_user.value
|
|
)
|
|
|
|
initialize_guardrails(
|
|
guardrails_config=value,
|
|
premium_user=premium_user,
|
|
config_file_path=config_file_path,
|
|
litellm_settings=litellm_settings,
|
|
)
|
|
elif key == "callbacks":
|
|
|
|
initialize_callbacks_on_proxy(
|
|
value=value,
|
|
premium_user=premium_user,
|
|
config_file_path=config_file_path,
|
|
litellm_settings=litellm_settings,
|
|
)
|
|
|
|
elif key == "post_call_rules":
|
|
litellm.post_call_rules = [
|
|
get_instance_fn(value=value, config_file_path=config_file_path)
|
|
]
|
|
verbose_proxy_logger.debug(
|
|
f"litellm.post_call_rules: {litellm.post_call_rules}"
|
|
)
|
|
elif key == "success_callback":
|
|
litellm.success_callback = []
|
|
|
|
# initialize success callbacks
|
|
for callback in value:
|
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
|
if "." in callback:
|
|
litellm.success_callback.append(
|
|
get_instance_fn(value=callback)
|
|
)
|
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
|
else:
|
|
litellm.success_callback.append(callback)
|
|
if "prometheus" in callback:
|
|
verbose_proxy_logger.debug(
|
|
"Starting Prometheus Metrics on /metrics"
|
|
)
|
|
from prometheus_client import make_asgi_app
|
|
|
|
# Add prometheus asgi middleware to route /metrics requests
|
|
metrics_app = make_asgi_app()
|
|
app.mount("/metrics", metrics_app)
|
|
print( # noqa
|
|
f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
|
|
) # noqa
|
|
elif key == "failure_callback":
|
|
litellm.failure_callback = []
|
|
|
|
# initialize success callbacks
|
|
for callback in value:
|
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
|
if "." in callback:
|
|
litellm.failure_callback.append(
|
|
get_instance_fn(value=callback)
|
|
)
|
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
|
else:
|
|
litellm.failure_callback.append(callback)
|
|
print( # noqa
|
|
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
|
|
) # noqa
|
|
elif key == "cache_params":
|
|
# this is set in the cache branch
|
|
# see usage here: https://docs.litellm.ai/docs/proxy/caching
|
|
pass
|
|
elif key == "default_team_settings":
|
|
for idx, team_setting in enumerate(
|
|
value
|
|
): # run through pydantic validation
|
|
try:
|
|
TeamDefaultSettings(**team_setting)
|
|
except:
|
|
raise Exception(
|
|
f"team_id missing from default_team_settings at index={idx}\npassed in value={team_setting}"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
|
|
)
|
|
setattr(litellm, key, value)
|
|
elif key == "upperbound_key_generate_params":
|
|
if value is not None and isinstance(value, dict):
|
|
for _k, _v in value.items():
|
|
if isinstance(_v, str) and _v.startswith("os.environ/"):
|
|
value[_k] = litellm.get_secret(_v)
|
|
litellm.upperbound_key_generate_params = (
|
|
LiteLLM_UpperboundKeyGenerateParams(**value)
|
|
)
|
|
else:
|
|
raise Exception(
|
|
f"Invalid value set for upperbound_key_generate_params - value={value}"
|
|
)
|
|
else:
|
|
verbose_proxy_logger.debug(
|
|
f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
|
|
)
|
|
setattr(litellm, key, value)
|
|
|
|
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
|
|
general_settings = config.get("general_settings", {})
|
|
if general_settings is None:
|
|
general_settings = {}
|
|
if general_settings:
|
|
### LOAD SECRET MANAGER ###
|
|
key_management_system = general_settings.get("key_management_system", None)
|
|
if key_management_system is not None:
|
|
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
|
|
### LOAD FROM AZURE KEY VAULT ###
|
|
load_from_azure_key_vault(use_azure_key_vault=True)
|
|
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
|
### LOAD FROM GOOGLE KMS ###
|
|
load_google_kms(use_google_kms=True)
|
|
elif (
|
|
key_management_system
|
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
|
):
|
|
### LOAD FROM AWS SECRET MANAGER ###
|
|
load_aws_secret_manager(use_aws_secret_manager=True)
|
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
|
load_aws_kms(use_aws_kms=True)
|
|
else:
|
|
raise ValueError("Invalid Key Management System selected")
|
|
key_management_settings = general_settings.get(
|
|
"key_management_settings", None
|
|
)
|
|
if key_management_settings is not None:
|
|
litellm._key_management_settings = KeyManagementSettings(
|
|
**key_management_settings
|
|
)
|
|
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
|
|
use_google_kms = general_settings.get("use_google_kms", False)
|
|
load_google_kms(use_google_kms=use_google_kms)
|
|
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
|
|
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
|
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
|
### ALERTING ###
|
|
|
|
proxy_logging_obj.update_values(
|
|
alerting=general_settings.get("alerting", None),
|
|
alerting_threshold=general_settings.get("alerting_threshold", 600),
|
|
alert_types=general_settings.get("alert_types", None),
|
|
alerting_args=general_settings.get("alerting_args", None),
|
|
redis_cache=redis_usage_cache,
|
|
)
|
|
### CONNECT TO DATABASE ###
|
|
database_url = general_settings.get("database_url", None)
|
|
if database_url and database_url.startswith("os.environ/"):
|
|
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
|
|
database_url = litellm.get_secret(database_url)
|
|
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
|
|
### MASTER KEY ###
|
|
master_key = general_settings.get(
|
|
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
|
)
|
|
|
|
if master_key and master_key.startswith("os.environ/"):
|
|
master_key = litellm.get_secret(master_key)
|
|
if not isinstance(master_key, str):
|
|
raise Exception(
|
|
"Master key must be a string. Current type - {}".format(
|
|
type(master_key)
|
|
)
|
|
)
|
|
|
|
if master_key is not None and isinstance(master_key, str):
|
|
litellm_master_key_hash = hash_token(master_key)
|
|
### USER API KEY CACHE IN-MEMORY TTL ###
|
|
user_api_key_cache_ttl = general_settings.get(
|
|
"user_api_key_cache_ttl", None
|
|
)
|
|
if user_api_key_cache_ttl is not None:
|
|
user_api_key_cache.update_cache_ttl(
|
|
default_in_memory_ttl=float(user_api_key_cache_ttl),
|
|
default_redis_ttl=None, # user_api_key_cache is an in-memory cache
|
|
)
|
|
### STORE MODEL IN DB ### feature flag for `/model/new`
|
|
store_model_in_db = general_settings.get("store_model_in_db", False)
|
|
if store_model_in_db is None:
|
|
store_model_in_db = False
|
|
### CUSTOM API KEY AUTH ###
|
|
## pass filepath
|
|
custom_auth = general_settings.get("custom_auth", None)
|
|
if custom_auth is not None:
|
|
user_custom_auth = get_instance_fn(
|
|
value=custom_auth, config_file_path=config_file_path
|
|
)
|
|
|
|
custom_key_generate = general_settings.get("custom_key_generate", None)
|
|
if custom_key_generate is not None:
|
|
user_custom_key_generate = get_instance_fn(
|
|
value=custom_key_generate, config_file_path=config_file_path
|
|
)
|
|
## pass through endpoints
|
|
if general_settings.get("pass_through_endpoints", None) is not None:
|
|
await initialize_pass_through_endpoints(
|
|
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
|
)
|
|
## dynamodb
|
|
database_type = general_settings.get("database_type", None)
|
|
if database_type is not None and (
|
|
database_type == "dynamo_db" or database_type == "dynamodb"
|
|
):
|
|
database_args = general_settings.get("database_args", None)
|
|
### LOAD FROM os.environ/ ###
|
|
for k, v in database_args.items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
database_args[k] = litellm.get_secret(v)
|
|
if isinstance(k, str) and k == "aws_web_identity_token":
|
|
value = database_args[k]
|
|
verbose_proxy_logger.debug(
|
|
f"Loading AWS Web Identity Token from file: {value}"
|
|
)
|
|
if os.path.exists(value):
|
|
with open(value, "r") as file:
|
|
token_content = file.read()
|
|
database_args[k] = token_content
|
|
else:
|
|
verbose_proxy_logger.info(
|
|
f"DynamoDB Loading - {value} is not a valid file path"
|
|
)
|
|
verbose_proxy_logger.debug("database_args: %s", database_args)
|
|
custom_db_client = DBClient(
|
|
custom_db_args=database_args, custom_db_type=database_type
|
|
)
|
|
## ADMIN UI ACCESS ##
|
|
ui_access_mode = general_settings.get(
|
|
"ui_access_mode", "all"
|
|
) # can be either ["admin_only" or "all"]
|
|
## BUDGET RESCHEDULER ##
|
|
proxy_budget_rescheduler_min_time = general_settings.get(
|
|
"proxy_budget_rescheduler_min_time", proxy_budget_rescheduler_min_time
|
|
)
|
|
proxy_budget_rescheduler_max_time = general_settings.get(
|
|
"proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
|
|
)
|
|
## BATCH WRITER ##
|
|
proxy_batch_write_at = general_settings.get(
|
|
"proxy_batch_write_at", proxy_batch_write_at
|
|
)
|
|
## DISABLE SPEND LOGS ## - gives a perf improvement
|
|
disable_spend_logs = general_settings.get(
|
|
"disable_spend_logs", disable_spend_logs
|
|
)
|
|
### BACKGROUND HEALTH CHECKS ###
|
|
# Enable background health checks
|
|
use_background_health_checks = general_settings.get(
|
|
"background_health_checks", False
|
|
)
|
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
|
|
|
## check if user has set a premium feature in general_settings
|
|
if (
|
|
general_settings.get("enforced_params") is not None
|
|
and premium_user is not True
|
|
):
|
|
raise ValueError(
|
|
"Trying to use `enforced_params`"
|
|
+ CommonProxyErrors.not_premium_user.value
|
|
)
|
|
|
|
router_params: dict = {
|
|
"cache_responses": litellm.cache
|
|
!= None, # cache if user passed in cache values
|
|
}
|
|
## MODEL LIST
|
|
model_list = config.get("model_list", None)
|
|
if model_list:
|
|
router_params["model_list"] = model_list
|
|
print( # noqa
|
|
f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
|
|
) # noqa
|
|
for model in model_list:
|
|
### LOAD FROM os.environ/ ###
|
|
for k, v in model["litellm_params"].items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
model["litellm_params"][k] = litellm.get_secret(v)
|
|
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
|
|
litellm_model_name = model["litellm_params"]["model"]
|
|
litellm_model_api_base = model["litellm_params"].get("api_base", None)
|
|
if "ollama" in litellm_model_name and litellm_model_api_base is None:
|
|
run_ollama_serve()
|
|
|
|
## ASSISTANT SETTINGS
|
|
assistants_config: Optional[AssistantsTypedDict] = None
|
|
assistant_settings = config.get("assistant_settings", None)
|
|
if assistant_settings:
|
|
for k, v in assistant_settings["litellm_params"].items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
_v = v.replace("os.environ/", "")
|
|
v = os.getenv(_v)
|
|
assistant_settings["litellm_params"][k] = v
|
|
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
|
|
|
|
## ROUTER SETTINGS (e.g. routing_strategy, ...)
|
|
router_settings = config.get("router_settings", None)
|
|
if router_settings and isinstance(router_settings, dict):
|
|
arg_spec = inspect.getfullargspec(litellm.Router)
|
|
# model list already set
|
|
exclude_args = {
|
|
"self",
|
|
"model_list",
|
|
}
|
|
|
|
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
|
|
|
for k, v in router_settings.items():
|
|
if k in available_args:
|
|
router_params[k] = v
|
|
router = litellm.Router(
|
|
**router_params,
|
|
assistants_config=assistants_config,
|
|
router_general_settings=RouterGeneralSettings(
|
|
async_only_mode=True # only init async clients
|
|
),
|
|
) # type:ignore
|
|
return router, router.get_model_list(), general_settings
|
|
|
|
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
|
"""
|
|
Common logic across add + delete router models
|
|
Parameters:
|
|
- deployment
|
|
- db_model -> flag for differentiating model stored in db vs. config -> used on UI
|
|
|
|
Return model info w/ id
|
|
"""
|
|
_id: Optional[str] = getattr(model, "model_id", None)
|
|
if _id is not None:
|
|
model.model_info["id"] = _id
|
|
model.model_info["db_model"] = True
|
|
|
|
if premium_user is True:
|
|
# seeing "created_at", "updated_at", "created_by", "updated_by" is a LiteLLM Enterprise Feature
|
|
model.model_info["created_at"] = getattr(model, "created_at", None)
|
|
model.model_info["updated_at"] = getattr(model, "updated_at", None)
|
|
model.model_info["created_by"] = getattr(model, "created_by", None)
|
|
model.model_info["updated_by"] = getattr(model, "updated_by", None)
|
|
|
|
if model.model_info is not None and isinstance(model.model_info, dict):
|
|
if "id" not in model.model_info:
|
|
model.model_info["id"] = model.model_id
|
|
if "db_model" in model.model_info and model.model_info["db_model"] == False:
|
|
model.model_info["db_model"] = db_model
|
|
_model_info = RouterModelInfo(**model.model_info)
|
|
|
|
else:
|
|
_model_info = RouterModelInfo(id=model.model_id, db_model=db_model)
|
|
return _model_info
|
|
|
|
async def _delete_deployment(self, db_models: list) -> int:
|
|
"""
|
|
(Helper function of add deployment) -> combined to reduce prisma db calls
|
|
|
|
- Create all up list of model id's (db + config)
|
|
- Compare all up list to router model id's
|
|
- Remove any that are missing
|
|
|
|
Return:
|
|
- int - returns number of deleted deployments
|
|
"""
|
|
global user_config_file_path, llm_router
|
|
combined_id_list = []
|
|
if llm_router is None:
|
|
return 0
|
|
|
|
## DB MODELS ##
|
|
for m in db_models:
|
|
model_info = self.get_model_info_with_id(model=m)
|
|
if model_info.id is not None:
|
|
combined_id_list.append(model_info.id)
|
|
|
|
## CONFIG MODELS ##
|
|
config = await self.get_config(config_file_path=user_config_file_path)
|
|
model_list = config.get("model_list", None)
|
|
if model_list:
|
|
for model in model_list:
|
|
### LOAD FROM os.environ/ ###
|
|
for k, v in model["litellm_params"].items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
model["litellm_params"][k] = litellm.get_secret(v)
|
|
|
|
## check if they have model-id's ##
|
|
model_id = model.get("model_info", {}).get("id", None)
|
|
if model_id is None:
|
|
## else - generate stable id's ##
|
|
model_id = llm_router._generate_model_id(
|
|
model_group=model["model_name"],
|
|
litellm_params=model["litellm_params"],
|
|
)
|
|
combined_id_list.append(model_id) # ADD CONFIG MODEL TO COMBINED LIST
|
|
|
|
router_model_ids = llm_router.get_model_ids()
|
|
# Check for model IDs in llm_router not present in combined_id_list and delete them
|
|
|
|
deleted_deployments = 0
|
|
for model_id in router_model_ids:
|
|
if model_id not in combined_id_list:
|
|
is_deleted = llm_router.delete_deployment(id=model_id)
|
|
if is_deleted is not None:
|
|
deleted_deployments += 1
|
|
return deleted_deployments
|
|
|
|
def _add_deployment(self, db_models: list) -> int:
|
|
"""
|
|
Iterate through db models
|
|
|
|
for any not in router - add them.
|
|
|
|
Return - number of deployments added
|
|
"""
|
|
import base64
|
|
|
|
if master_key is None or not isinstance(master_key, str):
|
|
raise Exception(
|
|
f"Master key is not initialized or formatted. master_key={master_key}"
|
|
)
|
|
|
|
if llm_router is None:
|
|
return 0
|
|
|
|
added_models = 0
|
|
## ADD MODEL LOGIC
|
|
for m in db_models:
|
|
_litellm_params = m.litellm_params
|
|
if isinstance(_litellm_params, dict):
|
|
# decrypt values
|
|
for k, v in _litellm_params.items():
|
|
if isinstance(v, str):
|
|
# decode base64
|
|
try:
|
|
decoded_b64 = base64.b64decode(v)
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"Error decoding value - {}".format(v)
|
|
)
|
|
continue
|
|
# decrypt value
|
|
_value = decrypt_value(value=decoded_b64, master_key=master_key)
|
|
# sanity check if string > size 0
|
|
if len(_value) > 0:
|
|
_litellm_params[k] = _value
|
|
_litellm_params = LiteLLM_Params(**_litellm_params)
|
|
else:
|
|
verbose_proxy_logger.error(
|
|
f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
|
|
)
|
|
continue # skip to next model
|
|
_model_info = self.get_model_info_with_id(
|
|
model=m, db_model=True
|
|
) ## 👈 FLAG = True for db_models
|
|
|
|
added = llm_router.upsert_deployment(
|
|
deployment=Deployment(
|
|
model_name=m.model_name,
|
|
litellm_params=_litellm_params,
|
|
model_info=_model_info,
|
|
)
|
|
)
|
|
|
|
if added is not None:
|
|
added_models += 1
|
|
return added_models
|
|
|
|
async def _update_llm_router(
|
|
self,
|
|
new_models: list,
|
|
proxy_logging_obj: ProxyLogging,
|
|
):
|
|
global llm_router, llm_model_list, master_key, general_settings
|
|
import base64
|
|
|
|
try:
|
|
if llm_router is None and master_key is not None:
|
|
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
|
|
|
|
_model_list: list = []
|
|
for m in new_models:
|
|
_litellm_params = m.litellm_params
|
|
if isinstance(_litellm_params, dict):
|
|
# decrypt values
|
|
for k, v in _litellm_params.items():
|
|
if isinstance(v, str):
|
|
# decode base64
|
|
decoded_b64 = base64.b64decode(v)
|
|
# decrypt value
|
|
_litellm_params[k] = decrypt_value(
|
|
value=decoded_b64, master_key=master_key # type: ignore
|
|
)
|
|
_litellm_params = LiteLLM_Params(**_litellm_params)
|
|
else:
|
|
verbose_proxy_logger.error(
|
|
f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
|
|
)
|
|
continue # skip to next model
|
|
|
|
_model_info = self.get_model_info_with_id(model=m)
|
|
_model_list.append(
|
|
Deployment(
|
|
model_name=m.model_name,
|
|
litellm_params=_litellm_params,
|
|
model_info=_model_info,
|
|
).to_json(exclude_none=True)
|
|
)
|
|
if len(_model_list) > 0:
|
|
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
|
|
llm_router = litellm.Router(
|
|
model_list=_model_list,
|
|
router_general_settings=RouterGeneralSettings(
|
|
async_only_mode=True # only init async clients
|
|
),
|
|
)
|
|
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
|
|
else:
|
|
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
|
|
## DELETE MODEL LOGIC
|
|
await self._delete_deployment(db_models=new_models)
|
|
|
|
## ADD MODEL LOGIC
|
|
self._add_deployment(db_models=new_models)
|
|
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
f"Error adding/deleting model to llm_router: {str(e)}"
|
|
)
|
|
|
|
if llm_router is not None:
|
|
llm_model_list = llm_router.get_model_list()
|
|
|
|
# check if user set any callbacks in Config Table
|
|
config_data = await proxy_config.get_config()
|
|
litellm_settings = config_data.get("litellm_settings", {}) or {}
|
|
success_callbacks = litellm_settings.get("success_callback", None)
|
|
failure_callbacks = litellm_settings.get("failure_callback", None)
|
|
|
|
if success_callbacks is not None and isinstance(success_callbacks, list):
|
|
for success_callback in success_callbacks:
|
|
if success_callback not in litellm.success_callback:
|
|
litellm.success_callback.append(success_callback)
|
|
|
|
# Add failure callbacks from DB to litellm
|
|
if failure_callbacks is not None and isinstance(failure_callbacks, list):
|
|
for failure_callback in failure_callbacks:
|
|
if failure_callback not in litellm.failure_callback:
|
|
litellm.failure_callback.append(failure_callback)
|
|
# we need to set env variables too
|
|
environment_variables = config_data.get("environment_variables", {})
|
|
for k, v in environment_variables.items():
|
|
try:
|
|
if v is not None:
|
|
decoded_b64 = base64.b64decode(v)
|
|
value = decrypt_value(value=decoded_b64, master_key=master_key) # type: ignore
|
|
os.environ[k] = value
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"Error setting env variable: %s - %s", k, str(e)
|
|
)
|
|
|
|
# router settings
|
|
if llm_router is not None and prisma_client is not None:
|
|
db_router_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "router_settings"}
|
|
)
|
|
if (
|
|
db_router_settings is not None
|
|
and db_router_settings.param_value is not None
|
|
):
|
|
_router_settings = db_router_settings.param_value
|
|
llm_router.update_settings(**_router_settings)
|
|
|
|
## ALERTING ## [TODO] move this to the _update_general_settings() block
|
|
_general_settings = config_data.get("general_settings", {})
|
|
if "alerting" in _general_settings:
|
|
if (
|
|
general_settings is not None
|
|
and general_settings.get("alerting", None) is not None
|
|
and isinstance(general_settings["alerting"], list)
|
|
and _general_settings.get("alerting", None) is not None
|
|
and isinstance(_general_settings["alerting"], list)
|
|
):
|
|
for alert in _general_settings["alerting"]:
|
|
if alert not in general_settings["alerting"]:
|
|
general_settings["alerting"].append(alert)
|
|
|
|
proxy_logging_obj.alerting = general_settings["alerting"]
|
|
proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
|
|
"alerting"
|
|
]
|
|
elif general_settings is None:
|
|
general_settings = {}
|
|
general_settings["alerting"] = _general_settings["alerting"]
|
|
proxy_logging_obj.alerting = general_settings["alerting"]
|
|
proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
|
|
"alerting"
|
|
]
|
|
|
|
if "alert_types" in _general_settings:
|
|
general_settings["alert_types"] = _general_settings["alert_types"]
|
|
proxy_logging_obj.alert_types = general_settings["alert_types"]
|
|
proxy_logging_obj.slack_alerting_instance.update_values(
|
|
alert_types=general_settings["alert_types"], llm_router=llm_router
|
|
)
|
|
|
|
if "alert_to_webhook_url" in _general_settings:
|
|
general_settings["alert_to_webhook_url"] = _general_settings[
|
|
"alert_to_webhook_url"
|
|
]
|
|
proxy_logging_obj.slack_alerting_instance.update_values(
|
|
alert_to_webhook_url=general_settings["alert_to_webhook_url"],
|
|
llm_router=llm_router,
|
|
)
|
|
|
|
async def _update_general_settings(self, db_general_settings: Optional[Json]):
|
|
"""
|
|
Pull from DB, read general settings value
|
|
"""
|
|
global general_settings
|
|
if db_general_settings is None:
|
|
return
|
|
_general_settings = dict(db_general_settings)
|
|
## MAX PARALLEL REQUESTS ##
|
|
if "max_parallel_requests" in _general_settings:
|
|
general_settings["max_parallel_requests"] = _general_settings[
|
|
"max_parallel_requests"
|
|
]
|
|
|
|
if "global_max_parallel_requests" in _general_settings:
|
|
general_settings["global_max_parallel_requests"] = _general_settings[
|
|
"global_max_parallel_requests"
|
|
]
|
|
|
|
## ALERTING ARGS ##
|
|
if "alerting_args" in _general_settings:
|
|
general_settings["alerting_args"] = _general_settings["alerting_args"]
|
|
proxy_logging_obj.slack_alerting_instance.update_values(
|
|
alerting_args=general_settings["alerting_args"],
|
|
)
|
|
|
|
async def add_deployment(
|
|
self,
|
|
prisma_client: PrismaClient,
|
|
proxy_logging_obj: ProxyLogging,
|
|
):
|
|
"""
|
|
- Check db for new models
|
|
- Check if model id's in router already
|
|
- If not, add to router
|
|
"""
|
|
global llm_router, llm_model_list, master_key, general_settings
|
|
|
|
try:
|
|
if master_key is None or not isinstance(master_key, str):
|
|
raise ValueError(
|
|
f"Master key is not initialized or formatted. master_key={master_key}"
|
|
)
|
|
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
|
# update llm router
|
|
await self._update_llm_router(
|
|
new_models=new_models, proxy_logging_obj=proxy_logging_obj
|
|
)
|
|
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
|
|
# update general settings
|
|
if db_general_settings is not None:
|
|
await self._update_general_settings(
|
|
db_general_settings=db_general_settings.param_value,
|
|
)
|
|
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
|
|
)
|
|
|
|
|
|
proxy_config = ProxyConfig()
|
|
|
|
|
|
def save_worker_config(**data):
|
|
import json
|
|
|
|
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
|
|
|
|
|
async def initialize(
|
|
model=None,
|
|
alias=None,
|
|
api_base=None,
|
|
api_version=None,
|
|
debug=False,
|
|
detailed_debug=False,
|
|
temperature=None,
|
|
max_tokens=None,
|
|
request_timeout=600,
|
|
max_budget=None,
|
|
telemetry=False,
|
|
drop_params=True,
|
|
add_function_to_prompt=True,
|
|
headers=None,
|
|
save=False,
|
|
use_queue=False,
|
|
config=None,
|
|
):
|
|
global user_model, user_api_base, user_debug, user_detailed_debug, user_user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
|
|
generate_feedback_box()
|
|
user_model = model
|
|
user_debug = debug
|
|
if debug == True: # this needs to be first, so users can see Router init debugg
|
|
import logging
|
|
|
|
from litellm._logging import (
|
|
verbose_logger,
|
|
verbose_proxy_logger,
|
|
verbose_router_logger,
|
|
)
|
|
|
|
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
|
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
|
|
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
|
|
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
|
|
if detailed_debug == True:
|
|
import logging
|
|
|
|
from litellm._logging import (
|
|
verbose_logger,
|
|
verbose_proxy_logger,
|
|
verbose_router_logger,
|
|
)
|
|
|
|
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
|
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
|
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
|
elif debug == False and detailed_debug == False:
|
|
# users can control proxy debugging using env variable = 'LITELLM_LOG'
|
|
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
|
|
if litellm_log_setting != None:
|
|
if litellm_log_setting.upper() == "INFO":
|
|
import logging
|
|
|
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
|
|
|
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
|
|
|
verbose_router_logger.setLevel(
|
|
level=logging.INFO
|
|
) # set router logs to info
|
|
verbose_proxy_logger.setLevel(
|
|
level=logging.INFO
|
|
) # set proxy logs to info
|
|
elif litellm_log_setting.upper() == "DEBUG":
|
|
import logging
|
|
|
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
|
|
|
verbose_router_logger.setLevel(
|
|
level=logging.DEBUG
|
|
) # set router logs to info
|
|
verbose_proxy_logger.setLevel(
|
|
level=logging.DEBUG
|
|
) # set proxy logs to debug
|
|
dynamic_config = {"general": {}, user_model: {}}
|
|
if config:
|
|
(
|
|
llm_router,
|
|
llm_model_list,
|
|
general_settings,
|
|
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
|
|
if headers: # model-specific param
|
|
user_headers = headers
|
|
dynamic_config[user_model]["headers"] = headers
|
|
if api_base: # model-specific param
|
|
user_api_base = api_base
|
|
dynamic_config[user_model]["api_base"] = api_base
|
|
if api_version:
|
|
os.environ["AZURE_API_VERSION"] = (
|
|
api_version # set this for azure - litellm can read this from the env
|
|
)
|
|
if max_tokens: # model-specific param
|
|
user_max_tokens = max_tokens
|
|
dynamic_config[user_model]["max_tokens"] = max_tokens
|
|
if temperature: # model-specific param
|
|
user_temperature = temperature
|
|
dynamic_config[user_model]["temperature"] = temperature
|
|
if request_timeout:
|
|
user_request_timeout = request_timeout
|
|
dynamic_config[user_model]["request_timeout"] = request_timeout
|
|
if alias: # model-specific param
|
|
dynamic_config[user_model]["alias"] = alias
|
|
if drop_params == True: # litellm-specific param
|
|
litellm.drop_params = True
|
|
dynamic_config["general"]["drop_params"] = True
|
|
if add_function_to_prompt == True: # litellm-specific param
|
|
litellm.add_function_to_prompt = True
|
|
dynamic_config["general"]["add_function_to_prompt"] = True
|
|
if max_budget: # litellm-specific param
|
|
litellm.max_budget = max_budget
|
|
dynamic_config["general"]["max_budget"] = max_budget
|
|
if experimental:
|
|
pass
|
|
user_telemetry = telemetry
|
|
|
|
|
|
# for streaming
|
|
def data_generator(response):
|
|
verbose_proxy_logger.debug("inside generator")
|
|
for chunk in response:
|
|
verbose_proxy_logger.debug("returned chunk: %s", chunk)
|
|
try:
|
|
yield f"data: {json.dumps(chunk.dict())}\n\n"
|
|
except:
|
|
yield f"data: {json.dumps(chunk)}\n\n"
|
|
|
|
|
|
async def async_assistants_data_generator(
|
|
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
|
):
|
|
verbose_proxy_logger.debug("inside generator")
|
|
try:
|
|
start_time = time.time()
|
|
async with response as chunk:
|
|
|
|
### CALL HOOKS ### - modify outgoing data
|
|
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
|
user_api_key_dict=user_api_key_dict, response=chunk
|
|
)
|
|
|
|
# chunk = chunk.model_dump_json(exclude_none=True)
|
|
async for c in chunk:
|
|
c = c.model_dump_json(exclude_none=True)
|
|
try:
|
|
yield f"data: {c}\n\n"
|
|
except Exception as e:
|
|
yield f"data: {str(e)}\n\n"
|
|
|
|
# Streaming is done, yield the [DONE] chunk
|
|
done_message = "[DONE]"
|
|
yield f"data: {done_message}\n\n"
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.async_assistants_data_generator(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict,
|
|
original_exception=e,
|
|
request_data=request_data,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
|
)
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
if isinstance(e, HTTPException):
|
|
raise e
|
|
else:
|
|
error_traceback = traceback.format_exc()
|
|
error_msg = f"{str(e)}\n\n{error_traceback}"
|
|
|
|
proxy_exception = ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
|
yield f"data: {error_returned}\n\n"
|
|
|
|
|
|
async def async_data_generator(
|
|
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
|
):
|
|
verbose_proxy_logger.debug("inside generator")
|
|
try:
|
|
start_time = time.time()
|
|
async for chunk in response:
|
|
|
|
### CALL HOOKS ### - modify outgoing data
|
|
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
|
user_api_key_dict=user_api_key_dict, response=chunk
|
|
)
|
|
|
|
chunk = chunk.model_dump_json(exclude_none=True)
|
|
try:
|
|
yield f"data: {chunk}\n\n"
|
|
except Exception as e:
|
|
yield f"data: {str(e)}\n\n"
|
|
|
|
# Streaming is done, yield the [DONE] chunk
|
|
done_message = "[DONE]"
|
|
yield f"data: {done_message}\n\n"
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format(
|
|
str(e), traceback.format_exc()
|
|
)
|
|
)
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict,
|
|
original_exception=e,
|
|
request_data=request_data,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
|
)
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
if isinstance(e, HTTPException):
|
|
raise e
|
|
else:
|
|
error_traceback = traceback.format_exc()
|
|
error_msg = f"{str(e)}\n\n{error_traceback}"
|
|
|
|
proxy_exception = ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
|
yield f"data: {error_returned}\n\n"
|
|
|
|
|
|
def select_data_generator(
|
|
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
|
):
|
|
return async_data_generator(
|
|
response=response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=request_data,
|
|
)
|
|
|
|
|
|
def get_litellm_model_info(model: dict = {}):
|
|
model_info = model.get("model_info", {})
|
|
model_to_lookup = model.get("litellm_params", {}).get("model", None)
|
|
try:
|
|
if "azure" in model_to_lookup:
|
|
model_to_lookup = model_info.get("base_model", None)
|
|
litellm_model_info = litellm.get_model_info(model_to_lookup)
|
|
return litellm_model_info
|
|
except:
|
|
# this should not block returning on /model/info
|
|
# if litellm does not have info on the model it should return {}
|
|
return {}
|
|
|
|
|
|
def on_backoff(details):
|
|
# The 'tries' key in the details dictionary contains the number of completed tries
|
|
verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
|
|
|
|
|
|
def giveup(e):
|
|
result = not (
|
|
isinstance(e, ProxyException)
|
|
and getattr(e, "message", None) is not None
|
|
and isinstance(e.message, str)
|
|
and "Max parallel request limit reached" in e.message
|
|
)
|
|
if result:
|
|
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
|
|
return result
|
|
|
|
|
|
@router.on_event("startup")
|
|
async def startup_event():
|
|
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
|
|
import json
|
|
|
|
### LOAD MASTER KEY ###
|
|
# check if master key set in environment - load from there
|
|
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
|
# check if DATABASE_URL in environment - load from there
|
|
if prisma_client is None:
|
|
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
|
|
|
|
### LOAD CONFIG ###
|
|
worker_config = litellm.get_secret("WORKER_CONFIG")
|
|
verbose_proxy_logger.debug("worker_config: %s", worker_config)
|
|
# check if it's a valid file path
|
|
if os.path.isfile(worker_config):
|
|
if proxy_config.is_yaml(config_file_path=worker_config):
|
|
(
|
|
llm_router,
|
|
llm_model_list,
|
|
general_settings,
|
|
) = await proxy_config.load_config(
|
|
router=llm_router, config_file_path=worker_config
|
|
)
|
|
else:
|
|
await initialize(**worker_config)
|
|
else:
|
|
# if not, assume it's a json string
|
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
|
await initialize(**worker_config)
|
|
|
|
## COST TRACKING ##
|
|
cost_tracking()
|
|
|
|
## Error Tracking ##
|
|
error_tracking()
|
|
|
|
## UPDATE SLACK ALERTING ##
|
|
proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
|
|
|
|
db_writer_client = HTTPHandler()
|
|
|
|
## UPDATE INTERNAL USAGE CACHE ##
|
|
proxy_logging_obj.update_values(
|
|
redis_cache=redis_usage_cache
|
|
) # used by parallel request limiter for rate limiting keys across instances
|
|
|
|
proxy_logging_obj._init_litellm_callbacks(
|
|
llm_router=llm_router
|
|
) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
|
|
|
if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
|
|
asyncio.create_task(
|
|
proxy_logging_obj.slack_alerting_instance._run_scheduled_daily_report(
|
|
llm_router=llm_router
|
|
)
|
|
) # RUN DAILY REPORT (if scheduled)
|
|
|
|
## JWT AUTH ##
|
|
if general_settings.get("litellm_jwtauth", None) is not None:
|
|
for k, v in general_settings["litellm_jwtauth"].items():
|
|
if isinstance(v, str) and v.startswith("os.environ/"):
|
|
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
|
|
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
|
|
else:
|
|
litellm_jwtauth = LiteLLM_JWTAuth()
|
|
jwt_handler.update_environment(
|
|
prisma_client=prisma_client,
|
|
user_api_key_cache=user_api_key_cache,
|
|
litellm_jwtauth=litellm_jwtauth,
|
|
)
|
|
|
|
if use_background_health_checks:
|
|
asyncio.create_task(
|
|
_run_background_health_check()
|
|
) # start the background health check coroutine.
|
|
|
|
if prompt_injection_detection_obj is not None:
|
|
prompt_injection_detection_obj.update_environment(router=llm_router)
|
|
|
|
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
|
|
if prisma_client is not None:
|
|
await prisma_client.connect()
|
|
|
|
verbose_proxy_logger.debug("custom_db_client client - %s", custom_db_client)
|
|
if custom_db_client is not None:
|
|
verbose_proxy_logger.debug("custom_db_client: connecting %s", custom_db_client)
|
|
await custom_db_client.connect()
|
|
|
|
if prisma_client is not None and master_key is not None:
|
|
# add master key to db
|
|
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
|
litellm_proxy_admin_name = os.getenv(
|
|
"PROXY_ADMIN_ID", litellm_proxy_admin_name
|
|
)
|
|
asyncio.create_task(
|
|
generate_key_helper_fn(
|
|
request_type="user",
|
|
duration=None,
|
|
models=[],
|
|
aliases={},
|
|
config={},
|
|
spend=0,
|
|
token=master_key,
|
|
user_id=litellm_proxy_admin_name,
|
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
|
query_type="update_data",
|
|
update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN},
|
|
)
|
|
)
|
|
|
|
if prisma_client is not None and litellm.max_budget > 0:
|
|
if litellm.budget_duration is None:
|
|
raise Exception(
|
|
"budget_duration not set on Proxy. budget_duration is required to use max_budget."
|
|
)
|
|
|
|
# add proxy budget to db in the user table
|
|
asyncio.create_task(
|
|
generate_key_helper_fn(
|
|
request_type="user",
|
|
user_id=litellm_proxy_budget_name,
|
|
duration=None,
|
|
models=[],
|
|
aliases={},
|
|
config={},
|
|
spend=0,
|
|
max_budget=litellm.max_budget,
|
|
budget_duration=litellm.budget_duration,
|
|
query_type="update_data",
|
|
update_key_values={
|
|
"max_budget": litellm.max_budget,
|
|
"budget_duration": litellm.budget_duration,
|
|
},
|
|
)
|
|
)
|
|
|
|
if custom_db_client is not None and master_key is not None:
|
|
# add master key to db
|
|
await generate_key_helper_fn(
|
|
request_type="key",
|
|
duration=None,
|
|
models=[],
|
|
aliases={},
|
|
config={},
|
|
spend=0,
|
|
token=master_key,
|
|
)
|
|
|
|
### CHECK IF VIEW EXISTS ###
|
|
if prisma_client is not None:
|
|
create_view_response = await prisma_client.check_view_exists()
|
|
|
|
### START BATCH WRITING DB + CHECKING NEW MODELS###
|
|
if prisma_client is not None:
|
|
scheduler = AsyncIOScheduler()
|
|
interval = random.randint(
|
|
proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
|
|
) # random interval, so multiple workers avoid resetting budget at the same time
|
|
batch_writing_interval = random.randint(
|
|
proxy_batch_write_at - 3, proxy_batch_write_at + 3
|
|
) # random interval, so multiple workers avoid batch writing at the same time
|
|
|
|
### RESET BUDGET ###
|
|
if general_settings.get("disable_reset_budget", False) == False:
|
|
scheduler.add_job(
|
|
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
|
)
|
|
|
|
### UPDATE SPEND ###
|
|
scheduler.add_job(
|
|
update_spend,
|
|
"interval",
|
|
seconds=batch_writing_interval,
|
|
args=[prisma_client, db_writer_client, proxy_logging_obj],
|
|
)
|
|
|
|
### ADD NEW MODELS ###
|
|
store_model_in_db = (
|
|
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
|
|
or store_model_in_db
|
|
) # type: ignore
|
|
if store_model_in_db == True:
|
|
scheduler.add_job(
|
|
proxy_config.add_deployment,
|
|
"interval",
|
|
seconds=10,
|
|
args=[prisma_client, proxy_logging_obj],
|
|
)
|
|
|
|
# this will load all existing models on proxy startup
|
|
await proxy_config.add_deployment(
|
|
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
|
)
|
|
|
|
if (
|
|
proxy_logging_obj is not None
|
|
and proxy_logging_obj.slack_alerting_instance is not None
|
|
and prisma_client is not None
|
|
):
|
|
print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa
|
|
### Schedule weekly/monhtly spend reports ###
|
|
scheduler.add_job(
|
|
proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report,
|
|
"cron",
|
|
day_of_week="mon",
|
|
)
|
|
|
|
scheduler.add_job(
|
|
proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report,
|
|
"cron",
|
|
day=1,
|
|
)
|
|
|
|
scheduler.start()
|
|
|
|
|
|
#### API ENDPOINTS ####
|
|
@router.get(
|
|
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
|
)
|
|
@router.get(
|
|
"/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
|
) # if project requires model list
|
|
def model_list(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global llm_model_list, general_settings
|
|
all_models = []
|
|
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
|
if llm_model_list is None:
|
|
proxy_model_list = []
|
|
else:
|
|
proxy_model_list = [m["model_name"] for m in llm_model_list]
|
|
key_models = get_key_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
team_models = get_team_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
all_models = get_complete_model_list(
|
|
key_models=key_models,
|
|
team_models=team_models,
|
|
proxy_model_list=proxy_model_list,
|
|
user_model=user_model,
|
|
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
|
)
|
|
return dict(
|
|
data=[
|
|
{
|
|
"id": model,
|
|
"object": "model",
|
|
"created": 1677610602,
|
|
"owned_by": "openai",
|
|
}
|
|
for model in all_models
|
|
],
|
|
object="list",
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/chat/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["chat/completions"],
|
|
)
|
|
@router.post(
|
|
"/chat/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["chat/completions"],
|
|
)
|
|
@router.post(
|
|
"/engines/{model:path}/chat/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["chat/completions"],
|
|
)
|
|
@router.post(
|
|
"/openai/deployments/{model:path}/chat/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["chat/completions"],
|
|
) # azure compatible endpoint
|
|
@backoff.on_exception(
|
|
backoff.expo,
|
|
Exception, # base exception to catch for the backoff
|
|
max_tries=litellm.num_retries or 3, # maximum number of retries
|
|
max_time=litellm.request_timeout or 60, # maximum total time to retry for
|
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
|
giveup=giveup,
|
|
logger=verbose_proxy_logger,
|
|
)
|
|
async def chat_completion(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
model: Optional[str] = None,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
|
|
|
data = {}
|
|
try:
|
|
body = await request.body()
|
|
body_str = body.decode()
|
|
try:
|
|
data = ast.literal_eval(body_str)
|
|
except:
|
|
data = json.loads(body_str)
|
|
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
data["model"] = (
|
|
general_settings.get("completion_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or model # for azure deployments
|
|
or data["model"] # default passed in http request
|
|
)
|
|
|
|
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
|
# override with user settings, these are params passed via cli
|
|
if user_temperature:
|
|
data["temperature"] = user_temperature
|
|
if user_request_timeout:
|
|
data["request_timeout"] = user_request_timeout
|
|
if user_max_tokens:
|
|
data["max_tokens"] = user_max_tokens
|
|
if user_api_base:
|
|
data["api_base"] = user_api_base
|
|
|
|
### MODEL ALIAS MAPPING ###
|
|
# check if model name in model alias map
|
|
# get the actual model name
|
|
if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
|
|
data["model"] = litellm.model_alias_map[data["model"]]
|
|
|
|
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
|
|
)
|
|
|
|
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
|
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
|
data["litellm_call_id"] = str(uuid.uuid4())
|
|
logging_obj, data = litellm.utils.function_setup(
|
|
original_function="acompletion",
|
|
rules_obj=litellm.utils.Rules(),
|
|
start_time=datetime.now(),
|
|
**data,
|
|
)
|
|
|
|
data["litellm_logging_obj"] = logging_obj
|
|
|
|
tasks = []
|
|
tasks.append(
|
|
proxy_logging_obj.during_call_hook(
|
|
data=data,
|
|
user_api_key_dict=user_api_key_dict,
|
|
call_type="completion",
|
|
)
|
|
)
|
|
|
|
### ROUTE THE REQUEST ###
|
|
# Do not change this - it should be a constant time fetch - ALWAYS
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
tasks.append(litellm.acompletion(**data))
|
|
elif "," in data["model"] and llm_router is not None:
|
|
if (
|
|
data.get("fastest_response", None) is not None
|
|
and data["fastest_response"] == True
|
|
):
|
|
tasks.append(llm_router.abatch_completion_fastest_response(**data))
|
|
else:
|
|
_models_csv_string = data.pop("model")
|
|
_models = [model.strip() for model in _models_csv_string.split(",")]
|
|
tasks.append(llm_router.abatch_completion(models=_models, **data))
|
|
elif "user_config" in data:
|
|
# initialize a new router instance. make request using this Router
|
|
router_config = data.pop("user_config")
|
|
user_router = litellm.Router(**router_config)
|
|
tasks.append(user_router.acompletion(**data))
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
tasks.append(llm_router.acompletion(**data))
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
|
): # model in router model list
|
|
tasks.append(llm_router.acompletion(**data))
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
tasks.append(llm_router.acompletion(**data))
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
tasks.append(llm_router.acompletion(**data, specific_deployment=True))
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
tasks.append(llm_router.acompletion(**data))
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
tasks.append(litellm.acompletion(**data))
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "chat_completion: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
# wait for call to end
|
|
llm_responses = asyncio.gather(
|
|
*tasks
|
|
) # run the moderation check in parallel to the actual llm api call
|
|
|
|
responses = await llm_responses
|
|
|
|
response = responses[1]
|
|
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
fastest_response_batch_completion = hidden_params.get(
|
|
"fastest_response_batch_completion", None
|
|
)
|
|
|
|
# Post Call Processing
|
|
if llm_router is not None:
|
|
data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
if (
|
|
"stream" in data and data["stream"] == True
|
|
): # use generate_responses to stream responses
|
|
custom_headers = get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
fastest_response_batch_completion=fastest_response_batch_completion,
|
|
)
|
|
selected_data_generator = select_data_generator(
|
|
response=response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=data,
|
|
)
|
|
return StreamingResponse(
|
|
selected_data_generator,
|
|
media_type="text/event-stream",
|
|
headers=custom_headers,
|
|
)
|
|
|
|
### CALL HOOKS ### - modify outgoing data
|
|
response = await proxy_logging_obj.post_call_success_hook(
|
|
user_api_key_dict=user_api_key_dict, response=response
|
|
)
|
|
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
fastest_response_batch_completion=fastest_response_batch_completion,
|
|
**additional_headers,
|
|
)
|
|
)
|
|
|
|
return response
|
|
except RejectedRequestError as e:
|
|
_data = e.request_data
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict,
|
|
original_exception=e,
|
|
request_data=_data,
|
|
)
|
|
_chat_response = litellm.ModelResponse()
|
|
_chat_response.choices[0].message.content = e.message # type: ignore
|
|
|
|
if data.get("stream", None) is not None and data["stream"] == True:
|
|
_iterator = litellm.utils.ModelResponseIterator(
|
|
model_response=_chat_response, convert_to_delta=True
|
|
)
|
|
_streaming_response = litellm.CustomStreamWrapper(
|
|
completion_stream=_iterator,
|
|
model=data.get("model", ""),
|
|
custom_llm_provider="cached_response",
|
|
logging_obj=data.get("litellm_logging_obj", None),
|
|
)
|
|
selected_data_generator = select_data_generator(
|
|
response=_streaming_response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=_data,
|
|
)
|
|
|
|
return StreamingResponse(
|
|
selected_data_generator,
|
|
media_type="text/event-stream",
|
|
)
|
|
_usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
|
_chat_response.usage = _usage # type: ignore
|
|
return _chat_response
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
|
|
get_error_message_str(e=e), traceback.format_exc()
|
|
)
|
|
)
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
litellm_debug_info = getattr(e, "litellm_debug_info", "")
|
|
verbose_proxy_logger.debug(
|
|
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
|
|
e,
|
|
litellm_debug_info,
|
|
)
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", str(e)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
|
|
)
|
|
@router.post(
|
|
"/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
|
|
)
|
|
@router.post(
|
|
"/engines/{model:path}/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["completions"],
|
|
)
|
|
@router.post(
|
|
"/openai/deployments/{model:path}/completions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["completions"],
|
|
)
|
|
async def completion(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
model: Optional[str] = None,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
|
data = {}
|
|
try:
|
|
body = await request.body()
|
|
body_str = body.decode()
|
|
try:
|
|
data = ast.literal_eval(body_str)
|
|
except:
|
|
data = json.loads(body_str)
|
|
|
|
data["model"] = (
|
|
general_settings.get("completion_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or model # for azure deployments
|
|
or data["model"] # default passed in http request
|
|
)
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# override with user settings, these are params passed via cli
|
|
if user_temperature:
|
|
data["temperature"] = user_temperature
|
|
if user_request_timeout:
|
|
data["request_timeout"] = user_request_timeout
|
|
if user_max_tokens:
|
|
data["max_tokens"] = user_max_tokens
|
|
if user_api_base:
|
|
data["api_base"] = user_api_base
|
|
|
|
### MODEL ALIAS MAPPING ###
|
|
# check if model name in model alias map
|
|
# get the actual model name
|
|
if data["model"] in litellm.model_alias_map:
|
|
data["model"] = litellm.model_alias_map[data["model"]]
|
|
|
|
### CALL HOOKS ### - modify incoming data before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
|
|
)
|
|
|
|
### ROUTE THE REQUESTs ###
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
llm_response = asyncio.create_task(
|
|
llm_router.atext_completion(**data, specific_deployment=True)
|
|
)
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
|
): # model in router model list
|
|
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
llm_response = asyncio.create_task(litellm.atext_completion(**data))
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "completion: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
# Await the llm_response task
|
|
response = await llm_response
|
|
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
verbose_proxy_logger.debug("final response: %s", response)
|
|
if (
|
|
"stream" in data and data["stream"] == True
|
|
): # use generate_responses to stream responses
|
|
custom_headers = get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
)
|
|
selected_data_generator = select_data_generator(
|
|
response=response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=data,
|
|
)
|
|
|
|
return StreamingResponse(
|
|
selected_data_generator,
|
|
media_type="text/event-stream",
|
|
headers=custom_headers,
|
|
)
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
)
|
|
)
|
|
|
|
return response
|
|
except RejectedRequestError as e:
|
|
_data = e.request_data
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict,
|
|
original_exception=e,
|
|
request_data=_data,
|
|
)
|
|
if _data.get("stream", None) is not None and _data["stream"] == True:
|
|
_chat_response = litellm.ModelResponse()
|
|
_usage = litellm.Usage(
|
|
prompt_tokens=0,
|
|
completion_tokens=0,
|
|
total_tokens=0,
|
|
)
|
|
_chat_response.usage = _usage # type: ignore
|
|
_chat_response.choices[0].message.content = e.message # type: ignore
|
|
_iterator = litellm.utils.ModelResponseIterator(
|
|
model_response=_chat_response, convert_to_delta=True
|
|
)
|
|
_streaming_response = litellm.TextCompletionStreamWrapper(
|
|
completion_stream=_iterator,
|
|
model=_data.get("model", ""),
|
|
)
|
|
|
|
selected_data_generator = select_data_generator(
|
|
response=_streaming_response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=data,
|
|
)
|
|
|
|
return StreamingResponse(
|
|
selected_data_generator,
|
|
media_type="text/event-stream",
|
|
headers={},
|
|
)
|
|
else:
|
|
_response = litellm.TextCompletionResponse()
|
|
_response.choices[0].text = e.message
|
|
return _response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/embeddings",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["embeddings"],
|
|
)
|
|
@router.post(
|
|
"/embeddings",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["embeddings"],
|
|
)
|
|
@router.post(
|
|
"/openai/deployments/{model:path}/embeddings",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["embeddings"],
|
|
) # azure compatible endpoint
|
|
async def embeddings(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
model: Optional[str] = None,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global proxy_logging_obj
|
|
data: Any = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
data["model"] = (
|
|
general_settings.get("embedding_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or model # for azure deployments
|
|
or data["model"] # default passed in http request
|
|
)
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
### MODEL ALIAS MAPPING ###
|
|
# check if model name in model alias map
|
|
# get the actual model name
|
|
if data["model"] in litellm.model_alias_map:
|
|
data["model"] = litellm.model_alias_map[data["model"]]
|
|
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
if (
|
|
"input" in data
|
|
and isinstance(data["input"], list)
|
|
and isinstance(data["input"][0], list)
|
|
and isinstance(data["input"][0][0], int)
|
|
): # check if array of tokens passed in
|
|
# check if non-openai/azure model called - e.g. for langchain integration
|
|
if llm_model_list is not None and data["model"] in router_model_names:
|
|
for m in llm_model_list:
|
|
if m["model_name"] == data["model"] and (
|
|
m["litellm_params"]["model"] in litellm.open_ai_embedding_models
|
|
or m["litellm_params"]["model"].startswith("azure/")
|
|
):
|
|
pass
|
|
else:
|
|
# non-openai/azure embedding model called with token input
|
|
input_list = []
|
|
for i in data["input"]:
|
|
input_list.append(
|
|
litellm.decode(model="gpt-3.5-turbo", tokens=i)
|
|
)
|
|
data["input"] = input_list
|
|
break
|
|
|
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook(
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
|
|
)
|
|
|
|
## ROUTE TO CORRECT ENDPOINT ##
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
response = await litellm.aembedding(**data)
|
|
elif "user_config" in data:
|
|
# initialize a new router instance. make request using this Router
|
|
router_config = data.pop("user_config")
|
|
user_router = litellm.Router(**router_config)
|
|
response = await user_router.aembedding(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
response = await llm_router.aembedding(**data)
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
response = await llm_router.aembedding(
|
|
**data
|
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aembedding(**data, specific_deployment=True)
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aembedding(**data)
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aembedding(**data)
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
response = await litellm.aembedding(**data)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "embeddings: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
litellm_debug_info = getattr(e, "litellm_debug_info", "")
|
|
verbose_proxy_logger.debug(
|
|
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
|
|
e,
|
|
litellm_debug_info,
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
message = get_error_message_str(e)
|
|
raise ProxyException(
|
|
message=message,
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/images/generations",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["images"],
|
|
)
|
|
@router.post(
|
|
"/images/generations",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["images"],
|
|
)
|
|
async def image_generation(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global proxy_logging_obj
|
|
data = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
data["model"] = (
|
|
general_settings.get("image_generation_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or data["model"] # default passed in http request
|
|
)
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
### MODEL ALIAS MAPPING ###
|
|
# check if model name in model alias map
|
|
# get the actual model name
|
|
if data["model"] in litellm.model_alias_map:
|
|
data["model"] = litellm.model_alias_map[data["model"]]
|
|
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook(
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
|
|
)
|
|
|
|
## ROUTE TO CORRECT ENDPOINT ##
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
response = await litellm.aimage_generation(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
response = await llm_router.aimage_generation(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aimage_generation(
|
|
**data, specific_deployment=True
|
|
)
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
response = await llm_router.aimage_generation(
|
|
**data
|
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aimage_generation(**data)
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
response = await litellm.aimage_generation(**data)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "image_generation: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/audio/speech",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["audio"],
|
|
)
|
|
@router.post(
|
|
"/audio/speech",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["audio"],
|
|
)
|
|
async def audio_speech(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Same params as:
|
|
|
|
https://platform.openai.com/docs/api-reference/audio/createSpeech
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
|
data["user"] = user_api_key_dict.user_id
|
|
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook(
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
|
|
)
|
|
|
|
## ROUTE TO CORRECT ENDPOINT ##
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
response = await litellm.aspeech(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
response = await llm_router.aspeech(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aspeech(**data, specific_deployment=True)
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
response = await llm_router.aspeech(
|
|
**data
|
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.aspeech(**data)
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
response = await litellm.aspeech(**data)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "audio_speech: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
|
|
# Printing each chunk size
|
|
async def generate(_response: HttpxBinaryResponseContent):
|
|
_generator = await _response.aiter_bytes(chunk_size=1024)
|
|
async for chunk in _generator:
|
|
yield chunk
|
|
|
|
custom_headers = get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
fastest_response_batch_completion=None,
|
|
)
|
|
|
|
selected_data_generator = select_data_generator(
|
|
response=response,
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data=data,
|
|
)
|
|
return StreamingResponse(
|
|
generate(response), media_type="audio/mpeg", headers=custom_headers
|
|
)
|
|
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.audio_speech(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
raise e
|
|
|
|
|
|
@router.post(
|
|
"/v1/audio/transcriptions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["audio"],
|
|
)
|
|
@router.post(
|
|
"/audio/transcriptions",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["audio"],
|
|
)
|
|
async def audio_transcriptions(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
file: UploadFile = File(...),
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Same params as:
|
|
|
|
https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=curl
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
form_data = await request.form()
|
|
data = {key: value for key, value in form_data.items() if key != "file"}
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
|
data["user"] = user_api_key_dict.user_id
|
|
|
|
data["model"] = (
|
|
general_settings.get("moderation_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or data["model"] # default passed in http request
|
|
)
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
assert (
|
|
file.filename is not None
|
|
) # make sure filename passed in (needed for type)
|
|
|
|
_original_filename = file.filename
|
|
file_extension = os.path.splitext(file.filename)[1]
|
|
# rename the file to a random hash file name -> we eventuall remove the file and don't want to remove any local files
|
|
file.filename = f"tmp-request" + str(uuid.uuid4()) + file_extension
|
|
|
|
# IMP - Asserts that we've renamed the uploaded file, since we run os.remove(file.filename), we should rename the original file
|
|
assert file.filename != _original_filename
|
|
|
|
with open(file.filename, "wb+") as f:
|
|
f.write(await file.read())
|
|
try:
|
|
data["file"] = open(file.filename, "rb")
|
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook(
|
|
user_api_key_dict=user_api_key_dict,
|
|
data=data,
|
|
call_type="audio_transcription",
|
|
)
|
|
|
|
## ROUTE TO CORRECT ENDPOINT ##
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
response = await litellm.atranscription(**data)
|
|
elif (
|
|
llm_router is not None and data["model"] in router_model_names
|
|
): # model in router model list
|
|
response = await llm_router.atranscription(**data)
|
|
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.atranscription(
|
|
**data, specific_deployment=True
|
|
)
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data["model"] in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
response = await llm_router.atranscription(
|
|
**data
|
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
elif (
|
|
llm_router is not None
|
|
and data["model"] not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.atranscription(**data)
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
response = await litellm.atranscription(**data)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail={
|
|
"error": "audio_transcriptions: Invalid model name passed in model="
|
|
+ data.get("model", "")
|
|
},
|
|
)
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
finally:
|
|
os.remove(file.filename) # Delete the saved file
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
response_cost = hidden_params.get("response_cost", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
response_cost=response_cost,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
######################################################################
|
|
|
|
# /v1/assistant Endpoints
|
|
|
|
|
|
######################################################################
|
|
|
|
|
|
@router.get(
|
|
"/v1/assistants",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.get(
|
|
"/assistants",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def get_assistants(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Returns a list of assistants.
|
|
|
|
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/listAssistants
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.aget_assistants(**data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.get_assistants(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/threads",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.post(
|
|
"/threads",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def create_threads(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Create a thread.
|
|
|
|
API Reference - https://platform.openai.com/docs/api-reference/threads/createThread
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.acreate_thread(**data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.create_threads(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/v1/threads/{thread_id}",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.get(
|
|
"/threads/{thread_id}",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def get_thread(
|
|
request: Request,
|
|
thread_id: str,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Retrieves a thread.
|
|
|
|
API Reference - https://platform.openai.com/docs/api-reference/threads/getThread
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.aget_thread(thread_id=thread_id, **data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.get_thread(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/threads/{thread_id}/messages",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.post(
|
|
"/threads/{thread_id}/messages",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def add_messages(
|
|
request: Request,
|
|
thread_id: str,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Create a message.
|
|
|
|
API Reference - https://platform.openai.com/docs/api-reference/messages/createMessage
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.a_add_message(thread_id=thread_id, **data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.add_messages(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/v1/threads/{thread_id}/messages",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.get(
|
|
"/threads/{thread_id}/messages",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def get_messages(
|
|
request: Request,
|
|
thread_id: str,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Returns a list of messages for a given thread.
|
|
|
|
API Reference - https://platform.openai.com/docs/api-reference/messages/listMessages
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.aget_messages(thread_id=thread_id, **data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.get_messages(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/threads/{thread_id}/runs",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
@router.post(
|
|
"/threads/{thread_id}/runs",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["assistants"],
|
|
)
|
|
async def run_thread(
|
|
request: Request,
|
|
thread_id: str,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Create a run.
|
|
|
|
API Reference: https://platform.openai.com/docs/api-reference/runs/createRun
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
response = await llm_router.arun_thread(thread_id=thread_id, **data)
|
|
|
|
if (
|
|
"stream" in data and data["stream"] == True
|
|
): # use generate_responses to stream responses
|
|
return StreamingResponse(
|
|
async_assistants_data_generator(
|
|
user_api_key_dict=user_api_key_dict,
|
|
response=response,
|
|
request_data=data,
|
|
),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.run_thread(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
######################################################################
|
|
|
|
# /v1/batches Endpoints
|
|
|
|
|
|
######################################################################
|
|
@router.post(
|
|
"/v1/batches",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["batch"],
|
|
)
|
|
@router.post(
|
|
"/batches",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["batch"],
|
|
)
|
|
async def create_batch(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Create large batches of API requests for asynchronous processing.
|
|
This is the equivalent of POST https://api.openai.com/v1/batch
|
|
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch
|
|
|
|
Example Curl
|
|
```
|
|
curl http://localhost:4000/v1/batches \
|
|
-H "Authorization: Bearer sk-1234" \
|
|
-H "Content-Type: application/json" \
|
|
-d '{
|
|
"input_file_id": "file-abc123",
|
|
"endpoint": "/v1/chat/completions",
|
|
"completion_window": "24h"
|
|
}'
|
|
```
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
form_data = await request.form()
|
|
data = {key: value for key, value in form_data.items() if key != "file"}
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
_create_batch_data = CreateBatchRequest(**data)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
response = await litellm.acreate_batch(
|
|
custom_llm_provider="openai", **_create_batch_data
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/v1/batches{batch_id}",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["batch"],
|
|
)
|
|
@router.get(
|
|
"/batches{batch_id}",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["batch"],
|
|
)
|
|
async def retrieve_batch(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
batch_id: str = Path(
|
|
title="Batch ID to retrieve", description="The ID of the batch to retrieve"
|
|
),
|
|
):
|
|
"""
|
|
Retrieves a batch.
|
|
This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
|
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve
|
|
|
|
Example Curl
|
|
```
|
|
curl http://localhost:4000/v1/batches/batch_abc123 \
|
|
-H "Authorization: Bearer sk-1234" \
|
|
-H "Content-Type: application/json" \
|
|
|
|
```
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
form_data = await request.form()
|
|
data = {key: value for key, value in form_data.items() if key != "file"}
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
_retrieve_batch_request = RetrieveBatchRequest(
|
|
batch_id=batch_id,
|
|
)
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
response = await litellm.aretrieve_batch(
|
|
custom_llm_provider="openai", **_retrieve_batch_request
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_traceback = traceback.format_exc()
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
######################################################################
|
|
|
|
# END OF /v1/batches Endpoints Implementation
|
|
|
|
######################################################################
|
|
|
|
######################################################################
|
|
|
|
# /v1/files Endpoints
|
|
|
|
|
|
######################################################################
|
|
@router.post(
|
|
"/v1/files",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["files"],
|
|
)
|
|
@router.post(
|
|
"/files",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
tags=["files"],
|
|
)
|
|
async def create_file(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Upload a file that can be used across - Assistants API, Batch API
|
|
This is the equivalent of POST https://api.openai.com/v1/files
|
|
|
|
Supports Identical Params as: https://platform.openai.com/docs/api-reference/files/create
|
|
|
|
Example Curl
|
|
```
|
|
curl https://api.openai.com/v1/files \
|
|
-H "Authorization: Bearer sk-1234" \
|
|
-F purpose="batch" \
|
|
-F file="@mydata.jsonl"
|
|
|
|
```
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
form_data = await request.form()
|
|
data = {key: value for key, value in form_data.items() if key != "file"}
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
_create_file_request = CreateFileRequest()
|
|
|
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
response = await litellm.acreate_file(
|
|
custom_llm_provider="openai", **_create_file_request
|
|
)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e.detail)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/v1/moderations",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["moderations"],
|
|
)
|
|
@router.post(
|
|
"/moderations",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_class=ORJSONResponse,
|
|
tags=["moderations"],
|
|
)
|
|
async def moderations(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
The moderations endpoint is a tool you can use to check whether content complies with an LLM Providers policies.
|
|
|
|
Quick Start
|
|
```
|
|
curl --location 'http://0.0.0.0:4000/moderations' \
|
|
--header 'Content-Type: application/json' \
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
|
```
|
|
"""
|
|
global proxy_logging_obj
|
|
data: Dict = {}
|
|
try:
|
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
|
body = await request.body()
|
|
data = orjson.loads(body)
|
|
|
|
# Include original request and headers in the data
|
|
data = await add_litellm_data_to_request(
|
|
data=data,
|
|
request=request,
|
|
general_settings=general_settings,
|
|
user_api_key_dict=user_api_key_dict,
|
|
version=version,
|
|
proxy_config=proxy_config,
|
|
)
|
|
|
|
data["model"] = (
|
|
general_settings.get("moderation_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or data.get("model") # default passed in http request
|
|
)
|
|
if user_model:
|
|
data["model"] = user_model
|
|
|
|
router_model_names = llm_router.model_names if llm_router is not None else []
|
|
|
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
|
data = await proxy_logging_obj.pre_call_hook(
|
|
user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
|
|
)
|
|
|
|
start_time = time.time()
|
|
|
|
## ROUTE TO CORRECT ENDPOINT ##
|
|
# skip router if user passed their key
|
|
if "api_key" in data:
|
|
response = await litellm.amoderation(**data)
|
|
elif (
|
|
llm_router is not None and data.get("model") in router_model_names
|
|
): # model in router model list
|
|
response = await llm_router.amoderation(**data)
|
|
elif (
|
|
llm_router is not None and data.get("model") in llm_router.deployment_names
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.amoderation(**data, specific_deployment=True)
|
|
elif (
|
|
llm_router is not None
|
|
and llm_router.model_group_alias is not None
|
|
and data.get("model") in llm_router.model_group_alias
|
|
): # model set in model_group_alias
|
|
response = await llm_router.amoderation(
|
|
**data
|
|
) # ensure this goes the llm_router, router will do the correct alias mapping
|
|
elif (
|
|
llm_router is not None
|
|
and data.get("model") not in router_model_names
|
|
and llm_router.default_deployment is not None
|
|
): # model in router deployments, calling a specific deployment on the router
|
|
response = await llm_router.amoderation(**data)
|
|
elif user_model is not None: # `litellm --model <your-model-name>`
|
|
response = await litellm.amoderation(**data)
|
|
else:
|
|
# /moderations does not need a "model" passed
|
|
# see https://platform.openai.com/docs/api-reference/moderations
|
|
response = await litellm.amoderation(**data)
|
|
|
|
### ALERTING ###
|
|
asyncio.create_task(
|
|
proxy_logging_obj.update_request_status(
|
|
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
|
)
|
|
)
|
|
|
|
### RESPONSE HEADERS ###
|
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
|
model_id = hidden_params.get("model_id", None) or ""
|
|
cache_key = hidden_params.get("cache_key", None) or ""
|
|
api_base = hidden_params.get("api_base", None) or ""
|
|
|
|
fastapi_response.headers.update(
|
|
get_custom_headers(
|
|
user_api_key_dict=user_api_key_dict,
|
|
model_id=model_id,
|
|
cache_key=cache_key,
|
|
api_base=api_base,
|
|
version=version,
|
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
|
)
|
|
)
|
|
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "message", str(e)),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
else:
|
|
error_msg = f"{str(e)}"
|
|
raise ProxyException(
|
|
message=getattr(e, "message", error_msg),
|
|
type=getattr(e, "type", "None"),
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", 500),
|
|
)
|
|
|
|
|
|
#### DEV UTILS ####
|
|
|
|
# @router.get(
|
|
# "/utils/available_routes",
|
|
# tags=["llm utils"],
|
|
# dependencies=[Depends(user_api_key_auth)],
|
|
# )
|
|
# async def get_available_routes(user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
|
|
|
|
|
@router.post(
|
|
"/utils/token_counter",
|
|
tags=["llm utils"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=TokenCountResponse,
|
|
)
|
|
async def token_counter(request: TokenCountRequest):
|
|
""" """
|
|
from litellm import token_counter
|
|
|
|
global llm_router
|
|
|
|
prompt = request.prompt
|
|
messages = request.messages
|
|
if prompt is None and messages is None:
|
|
raise HTTPException(
|
|
status_code=400, detail="prompt or messages must be provided"
|
|
)
|
|
|
|
deployment = None
|
|
litellm_model_name = None
|
|
if llm_router is not None:
|
|
# get 1 deployment corresponding to the model
|
|
for _model in llm_router.model_list:
|
|
if _model["model_name"] == request.model:
|
|
deployment = _model
|
|
break
|
|
if deployment is not None:
|
|
litellm_model_name = deployment.get("litellm_params", {}).get("model")
|
|
# remove the custom_llm_provider_prefix in the litellm_model_name
|
|
if "/" in litellm_model_name:
|
|
litellm_model_name = litellm_model_name.split("/", 1)[1]
|
|
|
|
model_to_use = (
|
|
litellm_model_name or request.model
|
|
) # use litellm model name, if it's not avalable then fallback to request.model
|
|
_tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
|
|
tokenizer_used = str(_tokenizer_used["type"])
|
|
total_tokens = token_counter(
|
|
model=model_to_use,
|
|
text=prompt,
|
|
messages=messages,
|
|
)
|
|
return TokenCountResponse(
|
|
total_tokens=total_tokens,
|
|
request_model=request.model,
|
|
model_used=model_to_use,
|
|
tokenizer_type=tokenizer_used,
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/utils/supported_openai_params",
|
|
tags=["llm utils"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def supported_openai_params(model: str):
|
|
"""
|
|
Returns supported openai params for a given litellm model name
|
|
|
|
e.g. `gpt-4` vs `gpt-3.5-turbo`
|
|
|
|
Example curl:
|
|
```
|
|
curl -X GET --location 'http://localhost:4000/utils/supported_openai_params?model=gpt-3.5-turbo-16k' \
|
|
--header 'Authorization: Bearer sk-1234'
|
|
```
|
|
"""
|
|
try:
|
|
model, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
|
|
return {
|
|
"supported_openai_params": litellm.get_supported_openai_params(
|
|
model=model, custom_llm_provider=custom_llm_provider
|
|
)
|
|
}
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=400, detail={"error": "Could not map model={}".format(model)}
|
|
)
|
|
|
|
|
|
#### END-USER MANAGEMENT ####
|
|
|
|
|
|
@router.post(
|
|
"/end_user/block",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
@router.post(
|
|
"/customer/block",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def block_user(data: BlockUsers):
|
|
"""
|
|
[BETA] Reject calls with this end-user id
|
|
|
|
(any /chat/completion call with this user={end-user-id} param, will be rejected.)
|
|
|
|
```
|
|
curl -X POST "http://0.0.0.0:8000/user/block"
|
|
-H "Authorization: Bearer sk-1234"
|
|
-D '{
|
|
"user_ids": [<user_id>, ...]
|
|
}'
|
|
```
|
|
"""
|
|
try:
|
|
records = []
|
|
if prisma_client is not None:
|
|
for id in data.user_ids:
|
|
record = await prisma_client.db.litellm_endusertable.upsert(
|
|
where={"user_id": id}, # type: ignore
|
|
data={
|
|
"create": {"user_id": id, "blocked": True}, # type: ignore
|
|
"update": {"blocked": True},
|
|
},
|
|
)
|
|
records.append(record)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": "Postgres DB Not connected"},
|
|
)
|
|
|
|
return {"blocked_users": records}
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(f"An error occurred - {str(e)}")
|
|
raise HTTPException(status_code=500, detail={"error": str(e)})
|
|
|
|
|
|
@router.post(
|
|
"/end_user/unblock",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
@router.post(
|
|
"/customer/unblock",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def unblock_user(data: BlockUsers):
|
|
"""
|
|
[BETA] Unblock calls with this user id
|
|
|
|
Example
|
|
```
|
|
curl -X POST "http://0.0.0.0:8000/user/unblock"
|
|
-H "Authorization: Bearer sk-1234"
|
|
-D '{
|
|
"user_ids": [<user_id>, ...]
|
|
}'
|
|
```
|
|
"""
|
|
from enterprise.enterprise_hooks.blocked_user_list import (
|
|
_ENTERPRISE_BlockedUserList,
|
|
)
|
|
|
|
if (
|
|
not any(isinstance(x, _ENTERPRISE_BlockedUserList) for x in litellm.callbacks)
|
|
or litellm.blocked_user_list is None
|
|
):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "Blocked user check was never set. This call has no effect."
|
|
},
|
|
)
|
|
|
|
if isinstance(litellm.blocked_user_list, list):
|
|
for id in data.user_ids:
|
|
litellm.blocked_user_list.remove(id)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "`blocked_user_list` must be set as a list. Filepaths can't be updated."
|
|
},
|
|
)
|
|
|
|
return {"blocked_users": litellm.blocked_user_list}
|
|
|
|
|
|
@router.post(
|
|
"/end_user/new",
|
|
tags=["Customer Management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
@router.post(
|
|
"/customer/new",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def new_end_user(
|
|
data: NewCustomerRequest,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Allow creating a new Customer
|
|
NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints
|
|
|
|
- Allow specifying allowed regions
|
|
- Allow specifying default model
|
|
|
|
Example curl:
|
|
```
|
|
curl --location 'http://0.0.0.0:4000/customer/new' \
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
--header 'Content-Type: application/json' \
|
|
--data '{
|
|
"user_id" : "ishaan-jaff-3",
|
|
"allowed_region": "eu",
|
|
"budget_id": "free_tier",
|
|
"default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
|
|
}'
|
|
|
|
# return end-user object
|
|
```
|
|
"""
|
|
global prisma_client, llm_router
|
|
"""
|
|
Validation:
|
|
- check if default model exists
|
|
- create budget object if not already created
|
|
|
|
- Add user to end user table
|
|
|
|
Return
|
|
- end-user object
|
|
- currently allowed models
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
try:
|
|
|
|
## VALIDATION ##
|
|
if data.default_model is not None:
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=422,
|
|
detail={"error": CommonProxyErrors.no_llm_router.value},
|
|
)
|
|
elif data.default_model not in llm_router.get_model_names():
|
|
raise HTTPException(
|
|
status_code=422,
|
|
detail={
|
|
"error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
|
|
data.default_model, set(llm_router.get_model_names())
|
|
)
|
|
},
|
|
)
|
|
|
|
new_end_user_obj: Dict = {}
|
|
|
|
## CREATE BUDGET ## if set
|
|
if data.max_budget is not None:
|
|
budget_record = await prisma_client.db.litellm_budgettable.create(
|
|
data={
|
|
"max_budget": data.max_budget,
|
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
}
|
|
)
|
|
|
|
new_end_user_obj["budget_id"] = budget_record.budget_id
|
|
elif data.budget_id is not None:
|
|
new_end_user_obj["budget_id"] = data.budget_id
|
|
|
|
_user_data = data.dict(exclude_none=True)
|
|
|
|
for k, v in _user_data.items():
|
|
if k != "max_budget" and k != "budget_id":
|
|
new_end_user_obj[k] = v
|
|
|
|
## WRITE TO DB ##
|
|
end_user_record = await prisma_client.db.litellm_endusertable.create(
|
|
data=new_end_user_obj # type: ignore
|
|
)
|
|
|
|
return end_user_record
|
|
except Exception as e:
|
|
if "Unique constraint failed on the fields: (`user_id`)" in str(e):
|
|
raise ProxyException(
|
|
message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
|
|
type="bad_request",
|
|
code=400,
|
|
param="user_id",
|
|
)
|
|
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Internal Server Error, " + str(e),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/customer/info",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=LiteLLM_EndUserTable,
|
|
)
|
|
@router.get(
|
|
"/end_user/info",
|
|
tags=["Customer Management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def end_user_info(
|
|
end_user_id: str = fastapi.Query(
|
|
description="End User ID in the request parameters"
|
|
),
|
|
):
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
user_info = await prisma_client.db.litellm_endusertable.find_first(
|
|
where={"user_id": end_user_id}
|
|
)
|
|
|
|
if user_info is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "End User Id={} does not exist in db".format(end_user_id)},
|
|
)
|
|
return user_info.model_dump(exclude_none=True)
|
|
|
|
|
|
@router.post(
|
|
"/customer/update",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
@router.post(
|
|
"/end_user/update",
|
|
tags=["Customer Management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def update_end_user(
|
|
data: UpdateCustomerRequest,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Example curl
|
|
|
|
```
|
|
curl --location 'http://0.0.0.0:4000/customer/update' \
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
--header 'Content-Type: application/json' \
|
|
--data '{
|
|
"user_id": "test-litellm-user-4",
|
|
"budget_id": "paid_tier"
|
|
}'
|
|
|
|
See below for all params
|
|
```
|
|
"""
|
|
|
|
global prisma_client
|
|
try:
|
|
data_json: dict = data.json()
|
|
# get the row from db
|
|
if prisma_client is None:
|
|
raise Exception("Not connected to DB!")
|
|
|
|
# get non default values for key
|
|
non_default_values = {}
|
|
for k, v in data_json.items():
|
|
if v is not None and v not in (
|
|
[],
|
|
{},
|
|
0,
|
|
): # models default to [], spend defaults to 0, we should not reset these values
|
|
non_default_values[k] = v
|
|
|
|
## ADD USER, IF NEW ##
|
|
verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
|
|
if data.user_id is not None and len(data.user_id) > 0:
|
|
non_default_values["user_id"] = data.user_id # type: ignore
|
|
verbose_proxy_logger.debug("In update customer, user_id condition block.")
|
|
response = await prisma_client.db.litellm_endusertable.update(
|
|
where={"user_id": data.user_id}, data=non_default_values # type: ignore
|
|
)
|
|
if response is None:
|
|
raise ValueError(
|
|
f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"received response from updating prisma client. response={response}"
|
|
)
|
|
return response
|
|
else:
|
|
raise ValueError(f"user_id is required, passed user_id = {data.user_id}")
|
|
|
|
# update based on remaining passed in values
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.update_end_user(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Internal Server Error, " + str(e),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
pass
|
|
|
|
|
|
@router.post(
|
|
"/customer/delete",
|
|
tags=["Customer Management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
@router.post(
|
|
"/end_user/delete",
|
|
tags=["Customer Management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def delete_end_user(
|
|
data: DeleteCustomerRequest,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Example curl
|
|
|
|
```
|
|
curl --location 'http://0.0.0.0:4000/customer/delete' \
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
--header 'Content-Type: application/json' \
|
|
--data '{
|
|
"user_ids" :["ishaan-jaff-5"]
|
|
}'
|
|
|
|
See below for all params
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
try:
|
|
if prisma_client is None:
|
|
raise Exception("Not connected to DB!")
|
|
|
|
verbose_proxy_logger.debug("/customer/delete: Received data = %s", data)
|
|
if (
|
|
data.user_ids is not None
|
|
and isinstance(data.user_ids, list)
|
|
and len(data.user_ids) > 0
|
|
):
|
|
response = await prisma_client.db.litellm_endusertable.delete_many(
|
|
where={"user_id": {"in": data.user_ids}}
|
|
)
|
|
if response is None:
|
|
raise ValueError(
|
|
f"Failed deleting customer data. User ID does not exist passed user_id={data.user_ids}"
|
|
)
|
|
if response != len(data.user_ids):
|
|
raise ValueError(
|
|
f"Failed deleting all customer data. User ID does not exist passed user_id={data.user_ids}. Deleted {response} customers, passed {len(data.user_ids)} customers"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"received response from updating prisma client. response={response}"
|
|
)
|
|
return {
|
|
"deleted_customers": response,
|
|
"message": "Successfully deleted customers with ids: "
|
|
+ str(data.user_ids),
|
|
}
|
|
else:
|
|
raise ValueError(f"user_id is required, passed user_id = {data.user_ids}")
|
|
|
|
# update based on remaining passed in values
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.delete_end_user(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Internal Server Error, " + str(e),
|
|
type="internal_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
pass
|
|
|
|
|
|
async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
|
|
if premium_user is not True:
|
|
return
|
|
|
|
if litellm.store_audit_logs is not True:
|
|
return
|
|
if prisma_client is None:
|
|
raise Exception("prisma_client is None, no DB connected")
|
|
|
|
verbose_proxy_logger.debug("creating audit log for %s", request_data)
|
|
|
|
if isinstance(request_data.updated_values, dict):
|
|
request_data.updated_values = json.dumps(request_data.updated_values)
|
|
|
|
if isinstance(request_data.before_value, dict):
|
|
request_data.before_value = json.dumps(request_data.before_value)
|
|
|
|
_request_data = request_data.dict(exclude_none=True)
|
|
|
|
try:
|
|
await prisma_client.db.litellm_auditlog.create(
|
|
data={
|
|
**_request_data, # type: ignore
|
|
}
|
|
)
|
|
except Exception as e:
|
|
# [Non-Blocking Exception. Do not allow blocking LLM API call]
|
|
verbose_proxy_logger.error(f"Failed Creating audit log {e}")
|
|
|
|
return
|
|
|
|
|
|
#### ORGANIZATION MANAGEMENT ####
|
|
|
|
|
|
@router.post(
|
|
"/organization/new",
|
|
tags=["organization management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=NewOrganizationResponse,
|
|
)
|
|
async def new_organization(
|
|
data: NewOrganizationRequest,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Allow orgs to own teams
|
|
|
|
Set org level budgets + model access.
|
|
|
|
Only admins can create orgs.
|
|
|
|
# Parameters
|
|
|
|
- `organization_alias`: *str* = The name of the organization.
|
|
- `models`: *List* = The models the organization has access to.
|
|
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
|
|
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
|
|
- `max_budget`: *Optional[float]* = Max budget for org
|
|
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
|
|
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
|
|
- `model_max_budget`: *Optional[dict]* = Max budget for a specific model
|
|
- `budget_duration`: *Optional[str]* = Frequency of reseting org budget
|
|
|
|
Case 1: Create new org **without** a budget_id
|
|
|
|
```bash
|
|
curl --location 'http://0.0.0.0:4000/organization/new' \
|
|
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
|
|
--header 'Content-Type: application/json' \
|
|
|
|
--data '{
|
|
"organization_alias": "my-secret-org",
|
|
"models": ["model1", "model2"],
|
|
"max_budget": 100
|
|
}'
|
|
|
|
|
|
```
|
|
|
|
Case 2: Create new org **with** a budget_id
|
|
|
|
```bash
|
|
curl --location 'http://0.0.0.0:4000/organization/new' \
|
|
|
|
--header 'Authorization: Bearer sk-1234' \
|
|
|
|
--header 'Content-Type: application/json' \
|
|
|
|
--data '{
|
|
"organization_alias": "my-secret-org",
|
|
"models": ["model1", "model2"],
|
|
"budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
|
|
}'
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
|
|
|
if (
|
|
user_api_key_dict.user_role is None
|
|
or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
|
|
):
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail={
|
|
"error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
|
|
},
|
|
)
|
|
|
|
if data.budget_id is None:
|
|
"""
|
|
Every organization needs a budget attached.
|
|
|
|
If none provided, create one based on provided values
|
|
"""
|
|
budget_params = LiteLLM_BudgetTable.model_fields.keys()
|
|
|
|
# Only include Budget Params when creating an entry in litellm_budgettable
|
|
_json_data = data.json(exclude_none=True)
|
|
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
|
|
budget_row = LiteLLM_BudgetTable(**_budget_data)
|
|
|
|
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
|
|
|
|
_budget = await prisma_client.db.litellm_budgettable.create(
|
|
data={
|
|
**new_budget, # type: ignore
|
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
}
|
|
) # type: ignore
|
|
|
|
data.budget_id = _budget.budget_id
|
|
|
|
"""
|
|
Ensure only models that user has access to, are given to org
|
|
"""
|
|
if len(user_api_key_dict.models) == 0: # user has access to all models
|
|
pass
|
|
else:
|
|
if len(data.models) == 0:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": f"User not allowed to give access to all models. Select models you want org to have access to."
|
|
},
|
|
)
|
|
for m in data.models:
|
|
if m not in user_api_key_dict.models:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
|
|
},
|
|
)
|
|
organization_row = LiteLLM_OrganizationTable(
|
|
**data.json(exclude_none=True),
|
|
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
)
|
|
new_organization_row = prisma_client.jsonify_object(
|
|
organization_row.json(exclude_none=True)
|
|
)
|
|
response = await prisma_client.db.litellm_organizationtable.create(
|
|
data={
|
|
**new_organization_row, # type: ignore
|
|
}
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@router.post(
|
|
"/organization/update",
|
|
tags=["organization management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def update_organization():
|
|
"""[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
|
|
pass
|
|
|
|
|
|
@router.post(
|
|
"/organization/delete",
|
|
tags=["organization management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def delete_organization():
|
|
"""[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues"""
|
|
pass
|
|
|
|
|
|
@router.post(
|
|
"/organization/info",
|
|
tags=["organization management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def info_organization(data: OrganizationRequest):
|
|
"""
|
|
Get the org specific information
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
|
|
|
if len(data.organizations) == 0:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": f"Specify list of organization id's to query. Passed in={data.organizations}"
|
|
},
|
|
)
|
|
response = await prisma_client.db.litellm_organizationtable.find_many(
|
|
where={"organization_id": {"in": data.organizations}},
|
|
include={"litellm_budget_table": True},
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
#### BUDGET TABLE MANAGEMENT ####
|
|
|
|
|
|
@router.post(
|
|
"/budget/new",
|
|
tags=["budget management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def new_budget(
|
|
budget_obj: BudgetNew,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Create a new budget object. Can apply this to teams, orgs, end-users, keys.
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
response = await prisma_client.db.litellm_budgettable.create(
|
|
data={
|
|
**budget_obj.model_dump(exclude_none=True), # type: ignore
|
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
} # type: ignore
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@router.post(
|
|
"/budget/info",
|
|
tags=["budget management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def info_budget(data: BudgetRequest):
|
|
"""
|
|
Get the budget id specific information
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
|
|
|
if len(data.budgets) == 0:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": f"Specify list of budget id's to query. Passed in={data.budgets}"
|
|
},
|
|
)
|
|
response = await prisma_client.db.litellm_budgettable.find_many(
|
|
where={"budget_id": {"in": data.budgets}},
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@router.get(
|
|
"/budget/settings",
|
|
tags=["budget management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def budget_settings(
|
|
budget_id: str,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Get list of configurable params + current value for a budget item + description of each field
|
|
|
|
Used on Admin UI.
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
## get budget item from db
|
|
db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
|
|
where={"budget_id": budget_id}
|
|
)
|
|
|
|
if db_budget_row is not None:
|
|
db_budget_row_dict = db_budget_row.model_dump(exclude_none=True)
|
|
else:
|
|
db_budget_row_dict = {}
|
|
|
|
allowed_args = {
|
|
"max_parallel_requests": {"type": "Integer"},
|
|
"tpm_limit": {"type": "Integer"},
|
|
"rpm_limit": {"type": "Integer"},
|
|
"budget_duration": {"type": "String"},
|
|
"max_budget": {"type": "Float"},
|
|
"soft_budget": {"type": "Float"},
|
|
}
|
|
|
|
return_val = []
|
|
|
|
for field_name, field_info in BudgetNew.model_fields.items():
|
|
if field_name in allowed_args:
|
|
|
|
_stored_in_db = True
|
|
|
|
_response_obj = ConfigList(
|
|
field_name=field_name,
|
|
field_type=allowed_args[field_name]["type"],
|
|
field_description=field_info.description or "",
|
|
field_value=db_budget_row_dict.get(field_name, None),
|
|
stored_in_db=_stored_in_db,
|
|
field_default_value=field_info.default,
|
|
)
|
|
return_val.append(_response_obj)
|
|
|
|
return return_val
|
|
|
|
|
|
@router.get(
|
|
"/budget/list",
|
|
tags=["budget management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def list_budget(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""List all the created budgets in proxy db. Used on Admin UI."""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
response = await prisma_client.db.litellm_budgettable.find_many()
|
|
|
|
return response
|
|
|
|
|
|
@router.post(
|
|
"/budget/delete",
|
|
tags=["budget management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def delete_budget(
|
|
data: BudgetDeleteRequest,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""Delete budget"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
response = await prisma_client.db.litellm_budgettable.delete(
|
|
where={"budget_id": data.id}
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
#### MODEL MANAGEMENT ####
|
|
|
|
|
|
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
|
@router.post(
|
|
"/model/new",
|
|
description="Allows adding new models to the model list in the config.yaml",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def add_new_model(
|
|
model_params: Deployment,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
|
|
try:
|
|
import base64
|
|
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
},
|
|
)
|
|
|
|
model_response = None
|
|
# update DB
|
|
if store_model_in_db == True:
|
|
"""
|
|
- store model_list in db
|
|
- store keys separately
|
|
"""
|
|
# encrypt litellm params #
|
|
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
|
|
_orignal_litellm_model_name = model_params.litellm_params.model
|
|
for k, v in _litellm_params_dict.items():
|
|
if isinstance(v, str):
|
|
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
|
|
model_params.litellm_params[k] = base64.b64encode(
|
|
encrypted_value
|
|
).decode("utf-8")
|
|
_data: dict = {
|
|
"model_id": model_params.model_info.id,
|
|
"model_name": model_params.model_name,
|
|
"litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True), # type: ignore
|
|
"model_info": model_params.model_info.model_dump_json( # type: ignore
|
|
exclude_none=True
|
|
),
|
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
}
|
|
if model_params.model_info.id is not None:
|
|
_data["model_id"] = model_params.model_info.id
|
|
model_response = await prisma_client.db.litellm_proxymodeltable.create(
|
|
data=_data # type: ignore
|
|
)
|
|
|
|
await proxy_config.add_deployment(
|
|
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
|
)
|
|
try:
|
|
# don't let failed slack alert block the /model/new response
|
|
_alerting = general_settings.get("alerting", []) or []
|
|
if "slack" in _alerting:
|
|
# send notification - new model added
|
|
await proxy_logging_obj.slack_alerting_instance.model_added_alert(
|
|
model_name=model_params.model_name,
|
|
litellm_model_name=_orignal_litellm_model_name,
|
|
passed_model_info=model_params.model_info,
|
|
)
|
|
except:
|
|
pass
|
|
|
|
else:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
|
|
},
|
|
)
|
|
|
|
return model_response
|
|
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
#### MODEL MANAGEMENT ####
|
|
@router.post(
|
|
"/model/update",
|
|
description="Edit existing model params",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def update_model(
|
|
model_params: updateDeployment,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
|
|
try:
|
|
import base64
|
|
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
},
|
|
)
|
|
# update DB
|
|
if store_model_in_db == True:
|
|
_model_id = None
|
|
_model_info = getattr(model_params, "model_info", None)
|
|
if _model_info is None:
|
|
raise Exception("model_info not provided")
|
|
|
|
_model_id = _model_info.id
|
|
if _model_id is None:
|
|
raise Exception("model_info.id not provided")
|
|
_existing_litellm_params = (
|
|
await prisma_client.db.litellm_proxymodeltable.find_unique(
|
|
where={"model_id": _model_id}
|
|
)
|
|
)
|
|
if _existing_litellm_params is None:
|
|
if (
|
|
llm_router is not None
|
|
and llm_router.get_deployment(model_id=_model_id) is not None
|
|
):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
|
|
},
|
|
)
|
|
raise Exception("model not found")
|
|
_existing_litellm_params_dict = dict(
|
|
_existing_litellm_params.litellm_params
|
|
)
|
|
|
|
if model_params.litellm_params is None:
|
|
raise Exception("litellm_params not provided")
|
|
|
|
_new_litellm_params_dict = model_params.litellm_params.dict(
|
|
exclude_none=True
|
|
)
|
|
|
|
### ENCRYPT PARAMS ###
|
|
for k, v in _new_litellm_params_dict.items():
|
|
if isinstance(v, str):
|
|
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
|
|
model_params.litellm_params[k] = base64.b64encode(
|
|
encrypted_value
|
|
).decode("utf-8")
|
|
|
|
### MERGE WITH EXISTING DATA ###
|
|
merged_dictionary = {}
|
|
_mp = model_params.litellm_params.dict()
|
|
|
|
for key, value in _mp.items():
|
|
if value is not None:
|
|
merged_dictionary[key] = value
|
|
elif (
|
|
key in _existing_litellm_params_dict
|
|
and _existing_litellm_params_dict[key] is not None
|
|
):
|
|
merged_dictionary[key] = _existing_litellm_params_dict[key]
|
|
else:
|
|
pass
|
|
|
|
_data: dict = {
|
|
"litellm_params": json.dumps(merged_dictionary), # type: ignore
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
}
|
|
model_response = await prisma_client.db.litellm_proxymodeltable.update(
|
|
where={"model_id": _model_id},
|
|
data=_data, # type: ignore
|
|
)
|
|
|
|
return model_response
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/v2/model/info",
|
|
description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def model_info_v2(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
model: Optional[str] = fastapi.Query(
|
|
None, description="Specify the model name (optional)"
|
|
),
|
|
debug: Optional[bool] = False,
|
|
):
|
|
"""
|
|
BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
|
|
"""
|
|
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
|
|
|
|
if llm_model_list is None or not isinstance(llm_model_list, list):
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": f"No model list passed, models={llm_model_list}. You can add a model through the config.yaml or on the LiteLLM Admin UI."
|
|
},
|
|
)
|
|
|
|
# Load existing config
|
|
config = await proxy_config.get_config()
|
|
|
|
all_models = copy.deepcopy(llm_model_list)
|
|
if user_model is not None:
|
|
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
|
|
all_models += [user_model]
|
|
|
|
# check all models user has access to in user_api_key_dict
|
|
user_models = []
|
|
if len(user_api_key_dict.models) > 0:
|
|
user_models = user_api_key_dict.models
|
|
|
|
if model is not None:
|
|
all_models = [m for m in all_models if m["model_name"] == model]
|
|
|
|
# fill in model info based on config.yaml and litellm model_prices_and_context_window.json
|
|
for _model in all_models:
|
|
# provided model_info in config.yaml
|
|
model_info = _model.get("model_info", {})
|
|
if debug is True:
|
|
_openai_client = "None"
|
|
if llm_router is not None:
|
|
_openai_client = (
|
|
llm_router._get_client(
|
|
deployment=_model, kwargs={}, client_type="async"
|
|
)
|
|
or "None"
|
|
)
|
|
else:
|
|
_openai_client = "llm_router_is_None"
|
|
openai_client = str(_openai_client)
|
|
_model["openai_client"] = openai_client
|
|
|
|
# read litellm model_prices_and_context_window.json to get the following:
|
|
# input_cost_per_token, output_cost_per_token, max_tokens
|
|
litellm_model_info = get_litellm_model_info(model=_model)
|
|
|
|
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
|
|
if litellm_model_info == {}:
|
|
# use litellm_param model_name to get model_info
|
|
litellm_params = _model.get("litellm_params", {})
|
|
litellm_model = litellm_params.get("model", None)
|
|
try:
|
|
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
|
except Exception:
|
|
litellm_model_info = {}
|
|
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
|
|
if litellm_model_info == {}:
|
|
# use litellm_param model_name to get model_info
|
|
litellm_params = _model.get("litellm_params", {})
|
|
litellm_model = litellm_params.get("model", None)
|
|
split_model = litellm_model.split("/")
|
|
if len(split_model) > 0:
|
|
litellm_model = split_model[-1]
|
|
try:
|
|
litellm_model_info = litellm.get_model_info(
|
|
model=litellm_model, custom_llm_provider=split_model[0]
|
|
)
|
|
except Exception:
|
|
litellm_model_info = {}
|
|
for k, v in litellm_model_info.items():
|
|
if k not in model_info:
|
|
model_info[k] = v
|
|
_model["model_info"] = model_info
|
|
# don't return the api key / vertex credentials
|
|
# don't return the llm credentials
|
|
_model["litellm_params"].pop("api_key", None)
|
|
_model["litellm_params"].pop("vertex_credentials", None)
|
|
_model["litellm_params"].pop("aws_access_key_id", None)
|
|
_model["litellm_params"].pop("aws_secret_access_key", None)
|
|
|
|
verbose_proxy_logger.debug("all_models: %s", all_models)
|
|
return {"data": all_models}
|
|
|
|
|
|
@router.get(
|
|
"/model/streaming_metrics",
|
|
description="View time to first token for models in spend logs",
|
|
tags=["model management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_streaming_metrics(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
_selected_model_group: Optional[str] = None,
|
|
startTime: Optional[datetime] = None,
|
|
endTime: Optional[datetime] = None,
|
|
):
|
|
global prisma_client, llm_router
|
|
if prisma_client is None:
|
|
raise ProxyException(
|
|
message=CommonProxyErrors.db_not_connected_error.value,
|
|
type="internal_error",
|
|
param="None",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
|
|
startTime = startTime or datetime.now() - timedelta(days=7) # show over past week
|
|
endTime = endTime or datetime.now()
|
|
|
|
is_same_day = startTime.date() == endTime.date()
|
|
if is_same_day:
|
|
sql_query = """
|
|
SELECT
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
"startTime",
|
|
request_id,
|
|
EXTRACT(epoch FROM ("completionStartTime" - "startTime")) AS time_to_first_token
|
|
FROM
|
|
"LiteLLM_SpendLogs"
|
|
WHERE
|
|
"model_group" = $1 AND "cache_hit" != 'True'
|
|
AND "completionStartTime" IS NOT NULL
|
|
AND "completionStartTime" != "endTime"
|
|
AND DATE("startTime") = DATE($2::timestamp)
|
|
GROUP BY
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
request_id
|
|
ORDER BY
|
|
time_to_first_token DESC;
|
|
"""
|
|
else:
|
|
sql_query = """
|
|
SELECT
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
DATE_TRUNC('day', "startTime")::DATE AS day,
|
|
AVG(EXTRACT(epoch FROM ("completionStartTime" - "startTime"))) AS time_to_first_token
|
|
FROM
|
|
"LiteLLM_SpendLogs"
|
|
WHERE
|
|
"startTime" BETWEEN $2::timestamp AND $3::timestamp
|
|
AND "model_group" = $1 AND "cache_hit" != 'True'
|
|
AND "completionStartTime" IS NOT NULL
|
|
AND "completionStartTime" != "endTime"
|
|
GROUP BY
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
day
|
|
ORDER BY
|
|
time_to_first_token DESC;
|
|
"""
|
|
|
|
_all_api_bases = set()
|
|
db_response = await prisma_client.db.query_raw(
|
|
sql_query, _selected_model_group, startTime, endTime
|
|
)
|
|
_daily_entries: dict = {} # {"Jun 23": {"model1": 0.002, "model2": 0.003}}
|
|
if db_response is not None:
|
|
for model_data in db_response:
|
|
_api_base = model_data["api_base"]
|
|
_model = model_data["model"]
|
|
time_to_first_token = model_data["time_to_first_token"]
|
|
unique_key = ""
|
|
if is_same_day:
|
|
_request_id = model_data["request_id"]
|
|
unique_key = _request_id
|
|
if _request_id not in _daily_entries:
|
|
_daily_entries[_request_id] = {}
|
|
else:
|
|
_day = model_data["day"]
|
|
unique_key = _day
|
|
time_to_first_token = model_data["time_to_first_token"]
|
|
if _day not in _daily_entries:
|
|
_daily_entries[_day] = {}
|
|
_combined_model_name = str(_model)
|
|
if "https://" in _api_base:
|
|
_combined_model_name = str(_api_base)
|
|
if "/openai/" in _combined_model_name:
|
|
_combined_model_name = _combined_model_name.split("/openai/")[0]
|
|
|
|
_all_api_bases.add(_combined_model_name)
|
|
|
|
_daily_entries[unique_key][_combined_model_name] = time_to_first_token
|
|
|
|
"""
|
|
each entry needs to be like this:
|
|
{
|
|
date: 'Jun 23',
|
|
'gpt-4-https://api.openai.com/v1/': 0.002,
|
|
'gpt-43-https://api.openai.com-12/v1/': 0.002,
|
|
}
|
|
"""
|
|
# convert daily entries to list of dicts
|
|
|
|
response: List[dict] = []
|
|
|
|
# sort daily entries by date
|
|
_daily_entries = dict(sorted(_daily_entries.items(), key=lambda item: item[0]))
|
|
for day in _daily_entries:
|
|
entry = {"date": str(day)}
|
|
for model_key, latency in _daily_entries[day].items():
|
|
entry[model_key] = latency
|
|
response.append(entry)
|
|
|
|
return {
|
|
"data": response,
|
|
"all_api_bases": list(_all_api_bases),
|
|
}
|
|
|
|
|
|
@router.get(
|
|
"/model/metrics",
|
|
description="View number of requests & avg latency per model on config.yaml",
|
|
tags=["model management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_metrics(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
_selected_model_group: Optional[str] = "gpt-4-32k",
|
|
startTime: Optional[datetime] = None,
|
|
endTime: Optional[datetime] = None,
|
|
api_key: Optional[str] = None,
|
|
customer: Optional[str] = None,
|
|
):
|
|
global prisma_client, llm_router
|
|
if prisma_client is None:
|
|
raise ProxyException(
|
|
message="Prisma Client is not initialized",
|
|
type="internal_error",
|
|
param="None",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
startTime = startTime or datetime.now() - timedelta(days=30)
|
|
endTime = endTime or datetime.now()
|
|
|
|
if api_key is None or api_key == "undefined":
|
|
api_key = "null"
|
|
|
|
if customer is None or customer == "undefined":
|
|
customer = "null"
|
|
|
|
sql_query = """
|
|
SELECT
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
DATE_TRUNC('day', "startTime")::DATE AS day,
|
|
AVG(EXTRACT(epoch FROM ("endTime" - "startTime")) / NULLIF("completion_tokens", 0)) AS avg_latency_per_token
|
|
FROM
|
|
"LiteLLM_SpendLogs"
|
|
WHERE
|
|
"startTime" >= $2::timestamp AND "startTime" <= $3::timestamp
|
|
AND "model_group" = $1 AND "cache_hit" != 'True'
|
|
AND (
|
|
CASE
|
|
WHEN $4 != 'null' THEN "api_key" = $4
|
|
ELSE TRUE
|
|
END
|
|
)
|
|
AND (
|
|
CASE
|
|
WHEN $5 != 'null' THEN "end_user" = $5
|
|
ELSE TRUE
|
|
END
|
|
)
|
|
GROUP BY
|
|
api_base,
|
|
model_group,
|
|
model,
|
|
day
|
|
HAVING
|
|
SUM(completion_tokens) > 0
|
|
ORDER BY
|
|
avg_latency_per_token DESC;
|
|
"""
|
|
_all_api_bases = set()
|
|
db_response = await prisma_client.db.query_raw(
|
|
sql_query, _selected_model_group, startTime, endTime, api_key, customer
|
|
)
|
|
_daily_entries: dict = {} # {"Jun 23": {"model1": 0.002, "model2": 0.003}}
|
|
|
|
if db_response is not None:
|
|
for model_data in db_response:
|
|
_api_base = model_data["api_base"]
|
|
_model = model_data["model"]
|
|
_day = model_data["day"]
|
|
_avg_latency_per_token = model_data["avg_latency_per_token"]
|
|
if _day not in _daily_entries:
|
|
_daily_entries[_day] = {}
|
|
_combined_model_name = str(_model)
|
|
if "https://" in _api_base:
|
|
_combined_model_name = str(_api_base)
|
|
if "/openai/" in _combined_model_name:
|
|
_combined_model_name = _combined_model_name.split("/openai/")[0]
|
|
|
|
_all_api_bases.add(_combined_model_name)
|
|
_daily_entries[_day][_combined_model_name] = _avg_latency_per_token
|
|
|
|
"""
|
|
each entry needs to be like this:
|
|
{
|
|
date: 'Jun 23',
|
|
'gpt-4-https://api.openai.com/v1/': 0.002,
|
|
'gpt-43-https://api.openai.com-12/v1/': 0.002,
|
|
}
|
|
"""
|
|
# convert daily entries to list of dicts
|
|
|
|
response: List[dict] = []
|
|
|
|
# sort daily entries by date
|
|
_daily_entries = dict(sorted(_daily_entries.items(), key=lambda item: item[0]))
|
|
for day in _daily_entries:
|
|
entry = {"date": str(day)}
|
|
for model_key, latency in _daily_entries[day].items():
|
|
entry[model_key] = latency
|
|
response.append(entry)
|
|
|
|
return {
|
|
"data": response,
|
|
"all_api_bases": list(_all_api_bases),
|
|
}
|
|
|
|
|
|
@router.get(
|
|
"/model/metrics/slow_responses",
|
|
description="View number of hanging requests per model_group",
|
|
tags=["model management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_metrics_slow_responses(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
_selected_model_group: Optional[str] = "gpt-4-32k",
|
|
startTime: Optional[datetime] = None,
|
|
endTime: Optional[datetime] = None,
|
|
api_key: Optional[str] = None,
|
|
customer: Optional[str] = None,
|
|
):
|
|
global prisma_client, llm_router, proxy_logging_obj
|
|
if prisma_client is None:
|
|
raise ProxyException(
|
|
message="Prisma Client is not initialized",
|
|
type="internal_error",
|
|
param="None",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if api_key is None or api_key == "undefined":
|
|
api_key = "null"
|
|
|
|
if customer is None or customer == "undefined":
|
|
customer = "null"
|
|
|
|
startTime = startTime or datetime.now() - timedelta(days=30)
|
|
endTime = endTime or datetime.now()
|
|
|
|
alerting_threshold = (
|
|
proxy_logging_obj.slack_alerting_instance.alerting_threshold or 300
|
|
)
|
|
alerting_threshold = int(alerting_threshold)
|
|
|
|
sql_query = """
|
|
SELECT
|
|
api_base,
|
|
COUNT(*) AS total_count,
|
|
SUM(CASE
|
|
WHEN ("endTime" - "startTime") >= (INTERVAL '1 SECOND' * CAST($1 AS INTEGER)) THEN 1
|
|
ELSE 0
|
|
END) AS slow_count
|
|
FROM
|
|
"LiteLLM_SpendLogs"
|
|
WHERE
|
|
"model_group" = $2
|
|
AND "cache_hit" != 'True'
|
|
AND "startTime" >= $3::timestamp
|
|
AND "startTime" <= $4::timestamp
|
|
AND (
|
|
CASE
|
|
WHEN $5 != 'null' THEN "api_key" = $5
|
|
ELSE TRUE
|
|
END
|
|
)
|
|
AND (
|
|
CASE
|
|
WHEN $6 != 'null' THEN "end_user" = $6
|
|
ELSE TRUE
|
|
END
|
|
)
|
|
GROUP BY
|
|
api_base
|
|
ORDER BY
|
|
slow_count DESC;
|
|
"""
|
|
|
|
db_response = await prisma_client.db.query_raw(
|
|
sql_query,
|
|
alerting_threshold,
|
|
_selected_model_group,
|
|
startTime,
|
|
endTime,
|
|
api_key,
|
|
customer,
|
|
)
|
|
|
|
if db_response is not None:
|
|
for row in db_response:
|
|
_api_base = row.get("api_base") or ""
|
|
if "/openai/" in _api_base:
|
|
_api_base = _api_base.split("/openai/")[0]
|
|
row["api_base"] = _api_base
|
|
return db_response
|
|
|
|
|
|
@router.get(
|
|
"/model/metrics/exceptions",
|
|
description="View number of failed requests per model on config.yaml",
|
|
tags=["model management"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_metrics_exceptions(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
_selected_model_group: Optional[str] = None,
|
|
startTime: Optional[datetime] = None,
|
|
endTime: Optional[datetime] = None,
|
|
api_key: Optional[str] = None,
|
|
customer: Optional[str] = None,
|
|
):
|
|
global prisma_client, llm_router
|
|
if prisma_client is None:
|
|
raise ProxyException(
|
|
message="Prisma Client is not initialized",
|
|
type="internal_error",
|
|
param="None",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
|
|
startTime = startTime or datetime.now() - timedelta(days=30)
|
|
endTime = endTime or datetime.now()
|
|
|
|
if api_key is None or api_key == "undefined":
|
|
api_key = "null"
|
|
|
|
"""
|
|
"""
|
|
sql_query = """
|
|
WITH cte AS (
|
|
SELECT
|
|
CASE WHEN api_base = '' THEN litellm_model_name ELSE CONCAT(litellm_model_name, '-', api_base) END AS combined_model_api_base,
|
|
exception_type,
|
|
COUNT(*) AS num_rate_limit_exceptions
|
|
FROM "LiteLLM_ErrorLogs"
|
|
WHERE
|
|
"startTime" >= $1::timestamp
|
|
AND "endTime" <= $2::timestamp
|
|
AND model_group = $3
|
|
GROUP BY combined_model_api_base, exception_type
|
|
)
|
|
SELECT
|
|
combined_model_api_base,
|
|
COUNT(*) AS total_exceptions,
|
|
json_object_agg(exception_type, num_rate_limit_exceptions) AS exception_counts
|
|
FROM cte
|
|
GROUP BY combined_model_api_base
|
|
ORDER BY total_exceptions DESC
|
|
LIMIT 200;
|
|
"""
|
|
db_response = await prisma_client.db.query_raw(
|
|
sql_query, startTime, endTime, _selected_model_group, api_key
|
|
)
|
|
response: List[dict] = []
|
|
exception_types = set()
|
|
|
|
"""
|
|
Return Data
|
|
{
|
|
"combined_model_api_base": "gpt-3.5-turbo-https://api.openai.com/v1/,
|
|
"total_exceptions": 5,
|
|
"BadRequestException": 5,
|
|
"TimeoutException": 2
|
|
}
|
|
"""
|
|
|
|
if db_response is not None:
|
|
# loop through all models
|
|
for model_data in db_response:
|
|
model = model_data.get("combined_model_api_base", "")
|
|
total_exceptions = model_data.get("total_exceptions", 0)
|
|
exception_counts = model_data.get("exception_counts", {})
|
|
curr_row = {
|
|
"model": model,
|
|
"total_exceptions": total_exceptions,
|
|
}
|
|
curr_row.update(exception_counts)
|
|
response.append(curr_row)
|
|
for k, v in exception_counts.items():
|
|
exception_types.add(k)
|
|
|
|
return {"data": response, "exception_types": list(exception_types)}
|
|
|
|
|
|
@router.get(
|
|
"/model/info",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
@router.get(
|
|
"/v1/model/info",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_info_v1(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
litellm_model_id: Optional[str] = None,
|
|
):
|
|
"""
|
|
Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)
|
|
|
|
Parameters:
|
|
litellm_model_id: Optional[str] = None (this is the value of `x-litellm-model-id` returned in response headers)
|
|
|
|
- When litellm_model_id is passed, it will return the info for that specific model
|
|
- When litellm_model_id is not passed, it will return the info for all models
|
|
|
|
Returns:
|
|
Returns a dictionary containing information about each model.
|
|
|
|
Example Response:
|
|
```json
|
|
{
|
|
"data": [
|
|
{
|
|
"model_name": "fake-openai-endpoint",
|
|
"litellm_params": {
|
|
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
|
"model": "openai/fake"
|
|
},
|
|
"model_info": {
|
|
"id": "112f74fab24a7a5245d2ced3536dd8f5f9192c57ee6e332af0f0512e08bed5af",
|
|
"db_model": false
|
|
}
|
|
}
|
|
]
|
|
}
|
|
|
|
```
|
|
"""
|
|
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
|
|
|
|
if llm_model_list is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": "LLM Model List not loaded in"}
|
|
)
|
|
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "LLM Router is not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI."
|
|
},
|
|
)
|
|
|
|
if litellm_model_id is not None:
|
|
# user is trying to get specific model from litellm router
|
|
deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
|
|
if deployment_info is None:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail={
|
|
"error": f"Model id = {litellm_model_id} not found on litellm proxy"
|
|
},
|
|
)
|
|
_deployment_info_dict = deployment_info.model_dump()
|
|
_deployment_info_dict = remove_sensitive_info_from_deployment(
|
|
deployment_dict=_deployment_info_dict
|
|
)
|
|
return {"data": _deployment_info_dict}
|
|
|
|
all_models: List[dict] = []
|
|
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
|
if llm_model_list is None:
|
|
proxy_model_list = []
|
|
else:
|
|
proxy_model_list = [m["model_name"] for m in llm_model_list]
|
|
key_models = get_key_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
team_models = get_team_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
all_models_str = get_complete_model_list(
|
|
key_models=key_models,
|
|
team_models=team_models,
|
|
proxy_model_list=proxy_model_list,
|
|
user_model=user_model,
|
|
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
|
)
|
|
|
|
if len(all_models_str) > 0:
|
|
model_names = all_models_str
|
|
_relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
|
|
all_models = copy.deepcopy(_relevant_models)
|
|
|
|
for model in all_models:
|
|
# provided model_info in config.yaml
|
|
model_info = model.get("model_info", {})
|
|
|
|
# read litellm model_prices_and_context_window.json to get the following:
|
|
# input_cost_per_token, output_cost_per_token, max_tokens
|
|
litellm_model_info = get_litellm_model_info(model=model)
|
|
|
|
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
|
|
if litellm_model_info == {}:
|
|
# use litellm_param model_name to get model_info
|
|
litellm_params = model.get("litellm_params", {})
|
|
litellm_model = litellm_params.get("model", None)
|
|
try:
|
|
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
|
except:
|
|
litellm_model_info = {}
|
|
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
|
|
if litellm_model_info == {}:
|
|
# use litellm_param model_name to get model_info
|
|
litellm_params = model.get("litellm_params", {})
|
|
litellm_model = litellm_params.get("model", None)
|
|
split_model = litellm_model.split("/")
|
|
if len(split_model) > 0:
|
|
litellm_model = split_model[-1]
|
|
try:
|
|
litellm_model_info = litellm.get_model_info(
|
|
model=litellm_model, custom_llm_provider=split_model[0]
|
|
)
|
|
except:
|
|
litellm_model_info = {}
|
|
for k, v in litellm_model_info.items():
|
|
if k not in model_info:
|
|
model_info[k] = v
|
|
model["model_info"] = model_info
|
|
# don't return the llm credentials
|
|
model = remove_sensitive_info_from_deployment(deployment_dict=model)
|
|
|
|
verbose_proxy_logger.debug("all_models: %s", all_models)
|
|
return {"data": all_models}
|
|
|
|
|
|
@router.get(
|
|
"/model_group/info",
|
|
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def model_group_info(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Returns model info at the model group level.
|
|
"""
|
|
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
|
|
|
|
if llm_model_list is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": "LLM Model List not loaded in"}
|
|
)
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": "LLM Router is not loaded in"}
|
|
)
|
|
all_models: List[dict] = []
|
|
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
|
if llm_model_list is None:
|
|
proxy_model_list = []
|
|
else:
|
|
proxy_model_list = [m["model_name"] for m in llm_model_list]
|
|
key_models = get_key_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
team_models = get_team_models(
|
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
|
)
|
|
all_models_str = get_complete_model_list(
|
|
key_models=key_models,
|
|
team_models=team_models,
|
|
proxy_model_list=proxy_model_list,
|
|
user_model=user_model,
|
|
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
|
)
|
|
|
|
model_groups: List[ModelGroupInfo] = []
|
|
for model in all_models_str:
|
|
|
|
_model_group_info = llm_router.get_model_group_info(model_group=model)
|
|
if _model_group_info is not None:
|
|
model_groups.append(_model_group_info)
|
|
|
|
return {"data": model_groups}
|
|
|
|
|
|
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
|
@router.post(
|
|
"/model/delete",
|
|
description="Allows deleting models in the model list in the config.yaml",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def delete_model(model_info: ModelInfoDelete):
|
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
|
|
try:
|
|
"""
|
|
[BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
|
|
|
- Check if id in db
|
|
- Delete
|
|
"""
|
|
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
},
|
|
)
|
|
|
|
# update DB
|
|
if store_model_in_db == True:
|
|
"""
|
|
- store model_list in db
|
|
- store keys separately
|
|
"""
|
|
# encrypt litellm params #
|
|
result = await prisma_client.db.litellm_proxymodeltable.delete(
|
|
where={"model_id": model_info.id}
|
|
)
|
|
|
|
if result is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": f"Model with id={model_info.id} not found in db"},
|
|
)
|
|
|
|
## DELETE FROM ROUTER ##
|
|
if llm_router is not None:
|
|
llm_router.delete_deployment(id=model_info.id)
|
|
|
|
return {"message": f"Model: {result.model_id} deleted successfully"}
|
|
else:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
|
|
},
|
|
)
|
|
|
|
except Exception as e:
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/model/settings",
|
|
description="Returns provider name, description, and required parameters for each provider",
|
|
tags=["model management"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def model_settings():
|
|
"""
|
|
Used by UI to generate 'model add' page
|
|
{
|
|
field_name=field_name,
|
|
field_type=allowed_args[field_name]["type"], # string/int
|
|
field_description=field_info.description or "", # human-friendly description
|
|
field_value=general_settings.get(field_name, None), # example value
|
|
}
|
|
"""
|
|
|
|
returned_list = []
|
|
for provider in litellm.provider_list:
|
|
returned_list.append(
|
|
ProviderInfo(
|
|
name=provider,
|
|
fields=litellm.get_provider_fields(custom_llm_provider=provider),
|
|
)
|
|
)
|
|
|
|
return returned_list
|
|
|
|
|
|
#### ALERTING MANAGEMENT ENDPOINTS ####
|
|
|
|
|
|
@router.get(
|
|
"/alerting/settings",
|
|
description="Return the configurable alerting param, description, and current value",
|
|
tags=["alerting"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def alerting_settings(
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global proxy_logging_obj, prisma_client
|
|
"""
|
|
Used by UI to generate 'alerting settings' page
|
|
{
|
|
field_name=field_name,
|
|
field_type=allowed_args[field_name]["type"], # string/int
|
|
field_description=field_info.description or "", # human-friendly description
|
|
field_value=general_settings.get(field_name, None), # example value
|
|
}
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
## get general settings from db
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
|
|
if db_general_settings is not None and db_general_settings.param_value is not None:
|
|
db_general_settings_dict = dict(db_general_settings.param_value)
|
|
alerting_args_dict: dict = db_general_settings_dict.get("alerting_args", {}) # type: ignore
|
|
else:
|
|
alerting_args_dict = {}
|
|
|
|
allowed_args = {
|
|
"daily_report_frequency": {"type": "Integer"},
|
|
"report_check_interval": {"type": "Integer"},
|
|
"budget_alert_ttl": {"type": "Integer"},
|
|
"outage_alert_ttl": {"type": "Integer"},
|
|
"region_outage_alert_ttl": {"type": "Integer"},
|
|
"minor_outage_alert_threshold": {"type": "Integer"},
|
|
"major_outage_alert_threshold": {"type": "Integer"},
|
|
"max_outage_alert_list_size": {"type": "Integer"},
|
|
}
|
|
|
|
_slack_alerting: SlackAlerting = proxy_logging_obj.slack_alerting_instance
|
|
_slack_alerting_args_dict = _slack_alerting.alerting_args.model_dump()
|
|
|
|
return_val = []
|
|
|
|
for field_name, field_info in SlackAlertingArgs.model_fields.items():
|
|
if field_name in allowed_args:
|
|
|
|
_stored_in_db: Optional[bool] = None
|
|
if field_name in alerting_args_dict:
|
|
_stored_in_db = True
|
|
else:
|
|
_stored_in_db = False
|
|
|
|
_response_obj = ConfigList(
|
|
field_name=field_name,
|
|
field_type=allowed_args[field_name]["type"],
|
|
field_description=field_info.description or "",
|
|
field_value=_slack_alerting_args_dict.get(field_name, None),
|
|
stored_in_db=_stored_in_db,
|
|
field_default_value=field_info.default,
|
|
premium_field=(
|
|
True if field_name == "region_outage_alert_ttl" else False
|
|
),
|
|
)
|
|
return_val.append(_response_obj)
|
|
return return_val
|
|
|
|
|
|
#### EXPERIMENTAL QUEUING ####
|
|
@router.post(
|
|
"/queue/chat/completions",
|
|
tags=["experimental"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def async_queue_request(
|
|
request: Request,
|
|
fastapi_response: Response,
|
|
model: Optional[str] = None,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global general_settings, user_debug, proxy_logging_obj
|
|
"""
|
|
v2 attempt at a background worker to handle queuing.
|
|
|
|
Just supports /chat/completion calls currently.
|
|
|
|
Now using a FastAPI background task + /chat/completions compatible endpoint
|
|
"""
|
|
data = {}
|
|
try:
|
|
data = await request.json() # type: ignore
|
|
|
|
# Include original request and headers in the data
|
|
data["proxy_server_request"] = {
|
|
"url": str(request.url),
|
|
"method": request.method,
|
|
"headers": dict(request.headers),
|
|
"body": copy.copy(data), # use copy instead of deepcopy
|
|
}
|
|
|
|
verbose_proxy_logger.debug("receiving data: %s", data)
|
|
data["model"] = (
|
|
general_settings.get("completion_model", None) # server default
|
|
or user_model # model name passed via cli args
|
|
or model # for azure deployments
|
|
or data["model"] # default passed in http request
|
|
)
|
|
|
|
# users can pass in 'user' param to /chat/completions. Don't override it
|
|
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
|
# if users are using user_api_key_auth, set `user` in `data`
|
|
data["user"] = user_api_key_dict.user_id
|
|
|
|
if "metadata" not in data:
|
|
data["metadata"] = {}
|
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
|
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
|
|
_headers = dict(request.headers)
|
|
_headers.pop(
|
|
"authorization", None
|
|
) # do not store the original `sk-..` api key in the db
|
|
data["metadata"]["headers"] = _headers
|
|
data["metadata"]["user_api_key_alias"] = getattr(
|
|
user_api_key_dict, "key_alias", None
|
|
)
|
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
|
user_api_key_dict, "team_id", None
|
|
)
|
|
data["metadata"]["endpoint"] = str(request.url)
|
|
|
|
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
|
# override with user settings, these are params passed via cli
|
|
if user_temperature:
|
|
data["temperature"] = user_temperature
|
|
if user_request_timeout:
|
|
data["request_timeout"] = user_request_timeout
|
|
if user_max_tokens:
|
|
data["max_tokens"] = user_max_tokens
|
|
if user_api_base:
|
|
data["api_base"] = user_api_base
|
|
|
|
if llm_router is None:
|
|
raise HTTPException(
|
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
|
)
|
|
|
|
response = await llm_router.schedule_acompletion(**data)
|
|
|
|
if (
|
|
"stream" in data and data["stream"] == True
|
|
): # use generate_responses to stream responses
|
|
return StreamingResponse(
|
|
async_data_generator(
|
|
user_api_key_dict=user_api_key_dict,
|
|
response=response,
|
|
request_data=data,
|
|
),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
|
|
return response
|
|
except Exception as e:
|
|
await proxy_logging_obj.post_call_failure_hook(
|
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
|
)
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
#### LOGIN ENDPOINTS ####
|
|
|
|
|
|
@app.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
|
|
async def google_login(request: Request):
|
|
"""
|
|
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
|
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
|
Example:
|
|
"""
|
|
global premium_user, prisma_client, master_key
|
|
|
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
|
|
|
####### Check if user is a Enterprise / Premium User #######
|
|
if (
|
|
microsoft_client_id is not None
|
|
or google_client_id is not None
|
|
or generic_client_id is not None
|
|
):
|
|
if premium_user != True:
|
|
raise ProxyException(
|
|
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
|
type="auth_error",
|
|
param="premium_user",
|
|
code=status.HTTP_403_FORBIDDEN,
|
|
)
|
|
|
|
####### Detect DB + MASTER KEY in .env #######
|
|
if prisma_client is None or master_key is None:
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
return HTMLResponse(content=missing_keys_html_form, status_code=200)
|
|
|
|
# get url from request
|
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
ui_username = os.getenv("UI_USERNAME")
|
|
if redirect_url.endswith("/"):
|
|
redirect_url += "sso/callback"
|
|
else:
|
|
redirect_url += "/sso/callback"
|
|
# Google SSO Auth
|
|
if google_client_id is not None:
|
|
from fastapi_sso.sso.google import GoogleSSO
|
|
|
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
if google_client_secret is None:
|
|
raise ProxyException(
|
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GOOGLE_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
google_sso = GoogleSSO(
|
|
client_id=google_client_id,
|
|
client_secret=google_client_secret,
|
|
redirect_uri=redirect_url,
|
|
)
|
|
verbose_proxy_logger.info(
|
|
f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
|
|
)
|
|
with google_sso:
|
|
return await google_sso.get_login_redirect()
|
|
# Microsoft SSO Auth
|
|
elif microsoft_client_id is not None:
|
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
|
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
if microsoft_client_secret is None:
|
|
raise ProxyException(
|
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="MICROSOFT_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
microsoft_sso = MicrosoftSSO(
|
|
client_id=microsoft_client_id,
|
|
client_secret=microsoft_client_secret,
|
|
tenant=microsoft_tenant,
|
|
redirect_uri=redirect_url,
|
|
allow_insecure_http=True,
|
|
)
|
|
with microsoft_sso:
|
|
return await microsoft_sso.get_login_redirect()
|
|
elif generic_client_id is not None:
|
|
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
|
|
|
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
generic_authorization_endpoint = os.getenv(
|
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
)
|
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
if generic_client_secret is None:
|
|
raise ProxyException(
|
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_authorization_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_token_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_TOKEN_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_userinfo_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_USERINFO_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
)
|
|
discovery = DiscoveryDocument(
|
|
authorization_endpoint=generic_authorization_endpoint,
|
|
token_endpoint=generic_token_endpoint,
|
|
userinfo_endpoint=generic_userinfo_endpoint,
|
|
)
|
|
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
generic_sso = SSOProvider(
|
|
client_id=generic_client_id,
|
|
client_secret=generic_client_secret,
|
|
redirect_uri=redirect_url,
|
|
allow_insecure_http=True,
|
|
scope=generic_scope,
|
|
)
|
|
with generic_sso:
|
|
# TODO: state should be a random string and added to the user session with cookie
|
|
# or a cryptographicly signed state that we can verify stateless
|
|
# For simplification we are using a static state, this is not perfect but some
|
|
# SSO providers do not allow stateless verification
|
|
redirect_params = {}
|
|
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
|
if state:
|
|
redirect_params["state"] = state
|
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
|
elif ui_username is not None:
|
|
# No Google, Microsoft SSO
|
|
# Use UI Credentials set in .env
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
return HTMLResponse(content=html_form, status_code=200)
|
|
else:
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
return HTMLResponse(content=html_form, status_code=200)
|
|
|
|
|
|
@app.get("/fallback/login", tags=["experimental"], include_in_schema=False)
|
|
async def fallback_login(request: Request):
|
|
"""
|
|
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
|
|
PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
|
|
Example:
|
|
"""
|
|
# get url from request
|
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
ui_username = os.getenv("UI_USERNAME")
|
|
if redirect_url.endswith("/"):
|
|
redirect_url += "sso/callback"
|
|
else:
|
|
redirect_url += "/sso/callback"
|
|
|
|
if ui_username is not None:
|
|
# No Google, Microsoft SSO
|
|
# Use UI Credentials set in .env
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
return HTMLResponse(content=html_form, status_code=200)
|
|
else:
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
return HTMLResponse(content=html_form, status_code=200)
|
|
|
|
|
|
@router.post(
|
|
"/login", include_in_schema=False
|
|
) # hidden since this is a helper for UI sso login
|
|
async def login(request: Request):
|
|
global premium_user
|
|
try:
|
|
import multipart
|
|
except ImportError:
|
|
subprocess.run(["pip", "install", "python-multipart"])
|
|
global master_key
|
|
if master_key is None:
|
|
raise ProxyException(
|
|
message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="master_key",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
form = await request.form()
|
|
username = str(form.get("username"))
|
|
password = str(form.get("password"))
|
|
ui_username = os.getenv("UI_USERNAME", "admin")
|
|
ui_password = os.getenv("UI_PASSWORD", None)
|
|
if ui_password is None:
|
|
ui_password = str(master_key) if master_key is not None else None
|
|
if ui_password is None:
|
|
raise ProxyException(
|
|
message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="UI_PASSWORD",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
|
|
# check if we can find the `username` in the db. on the ui, users can enter username=their email
|
|
_user_row = None
|
|
if prisma_client is not None:
|
|
_user_row = await prisma_client.db.litellm_usertable.find_first(
|
|
where={"user_email": {"equals": username}}
|
|
)
|
|
"""
|
|
To login to Admin UI, we support the following
|
|
- Login with UI_USERNAME and UI_PASSWORD
|
|
- Login with Invite Link `user_email` and `password` combination
|
|
"""
|
|
if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
|
|
password, ui_password
|
|
):
|
|
# Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
|
|
user_role = LitellmUserRoles.PROXY_ADMIN
|
|
user_id = username
|
|
|
|
# we want the key created to have PROXY_ADMIN_PERMISSIONS
|
|
key_user_id = litellm_proxy_admin_name
|
|
if (
|
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
|
and os.environ["PROXY_ADMIN_ID"] == user_id
|
|
) or user_id == "admin":
|
|
# checks if user is admin
|
|
key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id")
|
|
|
|
# Admin is Authe'd in - generate key for the UI to access Proxy
|
|
|
|
# ensure this user is set as the proxy admin, in this route there is no sso, we can assume this user is only the admin
|
|
await user_update(
|
|
data=UpdateUserRequest(
|
|
user_id=key_user_id,
|
|
user_role=user_role,
|
|
)
|
|
)
|
|
if os.getenv("DATABASE_URL") is not None:
|
|
response = await generate_key_helper_fn(
|
|
request_type="key",
|
|
**{
|
|
"user_role": LitellmUserRoles.PROXY_ADMIN,
|
|
"duration": "2hr",
|
|
"key_max_budget": 5,
|
|
"models": [],
|
|
"aliases": {},
|
|
"config": {},
|
|
"spend": 0,
|
|
"user_id": key_user_id,
|
|
"team_id": "litellm-dashboard",
|
|
}, # type: ignore
|
|
)
|
|
else:
|
|
raise ProxyException(
|
|
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="DATABASE_URL",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
key = response["token"] # type: ignore
|
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
|
if litellm_dashboard_ui.endswith("/"):
|
|
litellm_dashboard_ui += "ui/"
|
|
else:
|
|
litellm_dashboard_ui += "/ui/"
|
|
import jwt
|
|
|
|
jwt_token = jwt.encode(
|
|
{
|
|
"user_id": user_id,
|
|
"key": key,
|
|
"user_email": user_id,
|
|
"user_role": user_role, # this is the path without sso - we can assume only admins will use this
|
|
"login_method": "username_password",
|
|
"premium_user": premium_user,
|
|
},
|
|
master_key,
|
|
algorithm="HS256",
|
|
)
|
|
litellm_dashboard_ui += "?userID=" + user_id
|
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
return redirect_response
|
|
elif _user_row is not None:
|
|
"""
|
|
When sharing invite links
|
|
|
|
-> if the user has no role in the DB assume they are only a viewer
|
|
"""
|
|
user_id = getattr(_user_row, "user_id", "unknown")
|
|
user_role = getattr(
|
|
_user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
|
)
|
|
user_email = getattr(_user_row, "user_email", "unknown")
|
|
_password = getattr(_user_row, "password", "unknown")
|
|
|
|
# check if password == _user_row.password
|
|
hash_password = hash_token(token=password)
|
|
if secrets.compare_digest(password, _password) or secrets.compare_digest(
|
|
hash_password, _password
|
|
):
|
|
if os.getenv("DATABASE_URL") is not None:
|
|
response = await generate_key_helper_fn(
|
|
request_type="key",
|
|
**{ # type: ignore
|
|
"user_role": user_role,
|
|
"duration": "2hr",
|
|
"key_max_budget": 5,
|
|
"models": [],
|
|
"aliases": {},
|
|
"config": {},
|
|
"spend": 0,
|
|
"user_id": user_id,
|
|
"team_id": "litellm-dashboard",
|
|
},
|
|
)
|
|
else:
|
|
raise ProxyException(
|
|
message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="DATABASE_URL",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
key = response["token"] # type: ignore
|
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
|
if litellm_dashboard_ui.endswith("/"):
|
|
litellm_dashboard_ui += "ui/"
|
|
else:
|
|
litellm_dashboard_ui += "/ui/"
|
|
import jwt
|
|
|
|
jwt_token = jwt.encode(
|
|
{
|
|
"user_id": user_id,
|
|
"key": key,
|
|
"user_email": user_email,
|
|
"user_role": user_role,
|
|
"login_method": "username_password",
|
|
"premium_user": premium_user,
|
|
},
|
|
master_key,
|
|
algorithm="HS256",
|
|
)
|
|
litellm_dashboard_ui += "?userID=" + user_id
|
|
redirect_response = RedirectResponse(
|
|
url=litellm_dashboard_ui, status_code=303
|
|
)
|
|
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
return redirect_response
|
|
else:
|
|
raise ProxyException(
|
|
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nNot valid credentials for {username}",
|
|
type="auth_error",
|
|
param="invalid_credentials",
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
|
)
|
|
else:
|
|
raise ProxyException(
|
|
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
|
|
type="auth_error",
|
|
param="invalid_credentials",
|
|
code=status.HTTP_401_UNAUTHORIZED,
|
|
)
|
|
|
|
|
|
@app.get(
|
|
"/sso/get/logout_url",
|
|
tags=["experimental"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def get_logout_url(request: Request):
|
|
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
|
_logout_url = os.getenv("PROXY_LOGOUT_URL", None)
|
|
|
|
return {"PROXY_BASE_URL": _proxy_base_url, "PROXY_LOGOUT_URL": _logout_url}
|
|
|
|
|
|
@app.get("/onboarding/get_token", include_in_schema=False)
|
|
async def onboarding(invite_link: str):
|
|
"""
|
|
- Get the invite link
|
|
- Validate it's still 'valid'
|
|
- Invalidate the link (prevents abuse)
|
|
- Get user from db
|
|
- Pass in user_email if set
|
|
"""
|
|
global prisma_client, master_key
|
|
if master_key is None:
|
|
raise ProxyException(
|
|
message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="master_key",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
### VALIDATE INVITE LINK ###
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
|
|
where={"id": invite_link}
|
|
)
|
|
if invite_obj is None:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "Invitation link does not exist in db."}
|
|
)
|
|
#### CHECK IF EXPIRED
|
|
# Extract the date part from both datetime objects
|
|
utc_now_date = litellm.utils.get_utc_datetime().date()
|
|
expires_at_date = invite_obj.expires_at.date()
|
|
if expires_at_date < utc_now_date:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "Invitation link has expired."}
|
|
)
|
|
|
|
#### INVALIDATE LINK
|
|
current_time = litellm.utils.get_utc_datetime()
|
|
|
|
_ = await prisma_client.db.litellm_invitationlink.update(
|
|
where={"id": invite_link},
|
|
data={
|
|
"accepted_at": current_time,
|
|
"updated_at": current_time,
|
|
"is_accepted": True,
|
|
"updated_by": invite_obj.user_id, # type: ignore
|
|
},
|
|
)
|
|
|
|
### GET USER OBJECT ###
|
|
user_obj = await prisma_client.db.litellm_usertable.find_unique(
|
|
where={"user_id": invite_obj.user_id}
|
|
)
|
|
|
|
if user_obj is None:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "User does not exist in db."}
|
|
)
|
|
|
|
user_email = user_obj.user_email
|
|
|
|
response = await generate_key_helper_fn(
|
|
request_type="key",
|
|
**{
|
|
"user_role": user_obj.user_role,
|
|
"duration": "2hr",
|
|
"key_max_budget": 5,
|
|
"models": [],
|
|
"aliases": {},
|
|
"config": {},
|
|
"spend": 0,
|
|
"user_id": user_obj.user_id,
|
|
"team_id": "litellm-dashboard",
|
|
}, # type: ignore
|
|
)
|
|
key = response["token"] # type: ignore
|
|
|
|
litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
|
|
if litellm_dashboard_ui.endswith("/"):
|
|
litellm_dashboard_ui += "ui/onboarding"
|
|
else:
|
|
litellm_dashboard_ui += "/ui/onboarding"
|
|
import jwt
|
|
|
|
jwt_token = jwt.encode(
|
|
{
|
|
"user_id": user_obj.user_id,
|
|
"key": key,
|
|
"user_email": user_obj.user_email,
|
|
"user_role": user_obj.user_role,
|
|
"login_method": "username_password",
|
|
"premium_user": premium_user,
|
|
},
|
|
master_key,
|
|
algorithm="HS256",
|
|
)
|
|
|
|
litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email)
|
|
return {
|
|
"login_url": litellm_dashboard_ui,
|
|
"token": jwt_token,
|
|
"user_email": user_email,
|
|
}
|
|
|
|
|
|
@app.post("/onboarding/claim_token", include_in_schema=False)
|
|
async def claim_onboarding_link(data: InvitationClaim):
|
|
"""
|
|
Special route. Allows UI link share user to update their password.
|
|
|
|
- Get the invite link
|
|
- Validate it's still 'valid'
|
|
- Check if user within initial session (prevents abuse)
|
|
- Get user from db
|
|
- Update user password
|
|
|
|
This route can only update user password.
|
|
"""
|
|
global prisma_client
|
|
### VALIDATE INVITE LINK ###
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
|
|
where={"id": data.invitation_link}
|
|
)
|
|
if invite_obj is None:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "Invitation link does not exist in db."}
|
|
)
|
|
#### CHECK IF EXPIRED
|
|
# Extract the date part from both datetime objects
|
|
utc_now_date = litellm.utils.get_utc_datetime().date()
|
|
expires_at_date = invite_obj.expires_at.date()
|
|
if expires_at_date < utc_now_date:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "Invitation link has expired."}
|
|
)
|
|
|
|
#### CHECK IF CLAIMED
|
|
##### if claimed - accept
|
|
##### if unclaimed - reject
|
|
|
|
if invite_obj.is_accepted is True:
|
|
# this is a valid invite that was accepted
|
|
pass
|
|
else:
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail={
|
|
"error": "The invitation link was never validated. Please file an issue, if this is not intended - https://github.com/BerriAI/litellm/issues."
|
|
},
|
|
)
|
|
|
|
#### CHECK IF VALID USER ID
|
|
if invite_obj.user_id != data.user_id:
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail={
|
|
"error": "Invalid invitation link. The user id submitted does not match the user id this link is attached to. Got={}, Expected={}".format(
|
|
data.user_id, invite_obj.user_id
|
|
)
|
|
},
|
|
)
|
|
### UPDATE USER OBJECT ###
|
|
hash_password = hash_token(token=data.password)
|
|
user_obj = await prisma_client.db.litellm_usertable.update(
|
|
where={"user_id": invite_obj.user_id}, data={"password": hash_password}
|
|
)
|
|
|
|
if user_obj is None:
|
|
raise HTTPException(
|
|
status_code=401, detail={"error": "User does not exist in db."}
|
|
)
|
|
|
|
return user_obj
|
|
|
|
|
|
@app.get("/get_image", include_in_schema=False)
|
|
def get_image():
|
|
"""Get logo to show on admin UI"""
|
|
from fastapi.responses import FileResponse
|
|
|
|
# get current_dir
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
default_logo = os.path.join(current_dir, "logo.jpg")
|
|
|
|
logo_path = os.getenv("UI_LOGO_PATH", default_logo)
|
|
verbose_proxy_logger.debug("Reading logo from path: %s", logo_path)
|
|
|
|
# Check if the logo path is an HTTP/HTTPS URL
|
|
if logo_path.startswith(("http://", "https://")):
|
|
# Download the image and cache it
|
|
response = requests.get(logo_path)
|
|
if response.status_code == 200:
|
|
# Save the image to a local file
|
|
cache_path = os.path.join(current_dir, "cached_logo.jpg")
|
|
with open(cache_path, "wb") as f:
|
|
f.write(response.content)
|
|
|
|
# Return the cached image as a FileResponse
|
|
return FileResponse(cache_path, media_type="image/jpeg")
|
|
else:
|
|
# Handle the case when the image cannot be downloaded
|
|
return FileResponse(default_logo, media_type="image/jpeg")
|
|
else:
|
|
# Return the local image file if the logo path is not an HTTP/HTTPS URL
|
|
return FileResponse(logo_path, media_type="image/jpeg")
|
|
|
|
|
|
@app.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
|
async def auth_callback(request: Request):
|
|
"""Verify login"""
|
|
global general_settings, ui_access_mode, premium_user, master_key
|
|
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
|
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
|
# get url from request
|
|
if master_key is None:
|
|
raise ProxyException(
|
|
message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
|
|
type="auth_error",
|
|
param="master_key",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
|
|
if redirect_url.endswith("/"):
|
|
redirect_url += "sso/callback"
|
|
else:
|
|
redirect_url += "/sso/callback"
|
|
if google_client_id is not None:
|
|
from fastapi_sso.sso.google import GoogleSSO
|
|
|
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
|
|
if google_client_secret is None:
|
|
raise ProxyException(
|
|
message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GOOGLE_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
google_sso = GoogleSSO(
|
|
client_id=google_client_id,
|
|
redirect_uri=redirect_url,
|
|
client_secret=google_client_secret,
|
|
)
|
|
result = await google_sso.verify_and_process(request)
|
|
elif microsoft_client_id is not None:
|
|
from fastapi_sso.sso.microsoft import MicrosoftSSO
|
|
|
|
microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
|
|
microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
|
|
if microsoft_client_secret is None:
|
|
raise ProxyException(
|
|
message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="MICROSOFT_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if microsoft_tenant is None:
|
|
raise ProxyException(
|
|
message="MICROSOFT_TENANT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="MICROSOFT_TENANT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
microsoft_sso = MicrosoftSSO(
|
|
client_id=microsoft_client_id,
|
|
client_secret=microsoft_client_secret,
|
|
tenant=microsoft_tenant,
|
|
redirect_uri=redirect_url,
|
|
allow_insecure_http=True,
|
|
)
|
|
result = await microsoft_sso.verify_and_process(request)
|
|
elif generic_client_id is not None:
|
|
# make generic sso provider
|
|
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
|
|
|
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
|
generic_authorization_endpoint = os.getenv(
|
|
"GENERIC_AUTHORIZATION_ENDPOINT", None
|
|
)
|
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
|
generic_include_client_id = (
|
|
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
|
)
|
|
if generic_client_secret is None:
|
|
raise ProxyException(
|
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_CLIENT_SECRET",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_authorization_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_AUTHORIZATION_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_token_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_TOKEN_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
if generic_userinfo_endpoint is None:
|
|
raise ProxyException(
|
|
message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
|
|
type="auth_error",
|
|
param="GENERIC_USERINFO_ENDPOINT",
|
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
|
|
)
|
|
verbose_proxy_logger.debug(
|
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
|
)
|
|
|
|
generic_user_id_attribute_name = os.getenv(
|
|
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
|
)
|
|
generic_user_display_name_attribute_name = os.getenv(
|
|
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
|
)
|
|
generic_user_email_attribute_name = os.getenv(
|
|
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
|
)
|
|
generic_user_role_attribute_name = os.getenv(
|
|
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
|
)
|
|
generic_user_first_name_attribute_name = os.getenv(
|
|
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
|
)
|
|
generic_user_last_name_attribute_name = os.getenv(
|
|
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
|
)
|
|
|
|
verbose_proxy_logger.debug(
|
|
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
|
)
|
|
|
|
discovery = DiscoveryDocument(
|
|
authorization_endpoint=generic_authorization_endpoint,
|
|
token_endpoint=generic_token_endpoint,
|
|
userinfo_endpoint=generic_userinfo_endpoint,
|
|
)
|
|
|
|
def response_convertor(response, client):
|
|
return OpenID(
|
|
id=response.get(generic_user_id_attribute_name),
|
|
display_name=response.get(generic_user_display_name_attribute_name),
|
|
email=response.get(generic_user_email_attribute_name),
|
|
first_name=response.get(generic_user_first_name_attribute_name),
|
|
last_name=response.get(generic_user_last_name_attribute_name),
|
|
)
|
|
|
|
SSOProvider = create_provider(
|
|
name="oidc",
|
|
discovery_document=discovery,
|
|
response_convertor=response_convertor,
|
|
)
|
|
generic_sso = SSOProvider(
|
|
client_id=generic_client_id,
|
|
client_secret=generic_client_secret,
|
|
redirect_uri=redirect_url,
|
|
allow_insecure_http=True,
|
|
scope=generic_scope,
|
|
)
|
|
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
|
result = await generic_sso.verify_and_process(
|
|
request, params={"include_client_id": generic_include_client_id}
|
|
)
|
|
verbose_proxy_logger.debug("generic result: %s", result)
|
|
|
|
# User is Authe'd in - generate key for the UI to access Proxy
|
|
user_email = getattr(result, "email", None)
|
|
user_id = getattr(result, "id", None)
|
|
|
|
# generic client id
|
|
if generic_client_id is not None:
|
|
user_id = getattr(result, "id", None)
|
|
user_email = getattr(result, "email", None)
|
|
user_role = getattr(result, generic_user_role_attribute_name, None)
|
|
|
|
if user_id is None:
|
|
_first_name = getattr(result, "first_name", "") or ""
|
|
_last_name = getattr(result, "last_name", "") or ""
|
|
user_id = _first_name + _last_name
|
|
|
|
user_info = None
|
|
user_id_models: List = []
|
|
|
|
# User might not be already created on first generation of key
|
|
# But if it is, we want their models preferences
|
|
default_ui_key_values = {
|
|
"duration": "2hr",
|
|
"key_max_budget": 0.01,
|
|
"aliases": {},
|
|
"config": {},
|
|
"spend": 0,
|
|
"team_id": "litellm-dashboard",
|
|
}
|
|
user_defined_values = {
|
|
"models": user_id_models,
|
|
"user_id": user_id,
|
|
"user_email": user_email,
|
|
}
|
|
_user_id_from_sso = user_id
|
|
try:
|
|
user_role = None
|
|
if prisma_client is not None:
|
|
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
|
verbose_proxy_logger.debug(
|
|
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
|
)
|
|
if user_info is not None:
|
|
user_defined_values = {
|
|
"models": getattr(user_info, "models", []),
|
|
"user_id": getattr(user_info, "user_id", user_id),
|
|
"user_email": getattr(user_info, "user_id", user_email),
|
|
"user_role": getattr(user_info, "user_role", None),
|
|
}
|
|
user_role = getattr(user_info, "user_role", None)
|
|
|
|
## check if user-email in db ##
|
|
user_info = await prisma_client.db.litellm_usertable.find_first(
|
|
where={"user_email": user_email}
|
|
)
|
|
if user_info is not None:
|
|
user_defined_values = {
|
|
"models": getattr(user_info, "models", user_id_models),
|
|
"user_id": user_id,
|
|
"user_email": getattr(user_info, "user_id", user_email),
|
|
"user_role": getattr(user_info, "user_role", None),
|
|
}
|
|
user_role = getattr(user_info, "user_role", None)
|
|
|
|
# update id
|
|
await prisma_client.db.litellm_usertable.update_many(
|
|
where={"user_email": user_email}, data={"user_id": user_id} # type: ignore
|
|
)
|
|
elif litellm.default_user_params is not None and isinstance(
|
|
litellm.default_user_params, dict
|
|
):
|
|
user_defined_values = {
|
|
"models": litellm.default_user_params.get("models", user_id_models),
|
|
"user_id": litellm.default_user_params.get("user_id", user_id),
|
|
"user_email": litellm.default_user_params.get(
|
|
"user_email", user_email
|
|
),
|
|
}
|
|
except Exception as e:
|
|
pass
|
|
|
|
verbose_proxy_logger.info(
|
|
f"user_defined_values for creating ui key: {user_defined_values}"
|
|
)
|
|
|
|
default_ui_key_values.update(user_defined_values)
|
|
default_ui_key_values["request_type"] = "key"
|
|
response = await generate_key_helper_fn(
|
|
**default_ui_key_values, # type: ignore
|
|
)
|
|
key = response["token"] # type: ignore
|
|
user_id = response["user_id"] # type: ignore
|
|
|
|
# This should always be true
|
|
# User_id on SSO == user_id in the LiteLLM_VerificationToken Table
|
|
assert user_id == _user_id_from_sso
|
|
litellm_dashboard_ui = "/ui/"
|
|
user_role = user_role or "app_owner"
|
|
if (
|
|
os.getenv("PROXY_ADMIN_ID", None) is not None
|
|
and os.environ["PROXY_ADMIN_ID"] == user_id
|
|
):
|
|
# checks if user is admin
|
|
user_role = "app_admin"
|
|
|
|
verbose_proxy_logger.debug(
|
|
f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
|
|
)
|
|
## CHECK IF ROLE ALLOWED TO USE PROXY ##
|
|
if ui_access_mode == "admin_only" and "admin" not in user_role:
|
|
verbose_proxy_logger.debug("EXCEPTION RAISED")
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail={
|
|
"error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
|
|
},
|
|
)
|
|
|
|
import jwt
|
|
|
|
jwt_token = jwt.encode(
|
|
{
|
|
"user_id": user_id,
|
|
"key": key,
|
|
"user_email": user_email,
|
|
"user_role": user_role,
|
|
"login_method": "sso",
|
|
"premium_user": premium_user,
|
|
},
|
|
master_key,
|
|
algorithm="HS256",
|
|
)
|
|
litellm_dashboard_ui += "?userID=" + user_id
|
|
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
|
redirect_response.set_cookie(key="token", value=jwt_token)
|
|
return redirect_response
|
|
|
|
|
|
#### INVITATION MANAGEMENT ####
|
|
|
|
|
|
@router.post(
|
|
"/invitation/new",
|
|
tags=["Invite Links"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=InvitationModel,
|
|
include_in_schema=False,
|
|
)
|
|
async def new_invitation(
|
|
data: InvitationNew, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
|
|
):
|
|
"""
|
|
Allow admin to create invite links, to onboard new users to Admin UI.
|
|
|
|
```
|
|
curl -X POST 'http://localhost:4000/invitation/new' \
|
|
-H 'Content-Type: application/json' \
|
|
-D '{
|
|
"user_id": "1234" // 👈 id of user in 'LiteLLM_UserTable'
|
|
}'
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
current_time = litellm.utils.get_utc_datetime()
|
|
expires_at = current_time + timedelta(days=7)
|
|
|
|
try:
|
|
response = await prisma_client.db.litellm_invitationlink.create(
|
|
data={
|
|
"user_id": data.user_id,
|
|
"created_at": current_time,
|
|
"expires_at": expires_at,
|
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
"updated_at": current_time,
|
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
|
} # type: ignore
|
|
)
|
|
except Exception as e:
|
|
if "Foreign key constraint failed on the field" in str(e):
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
|
|
},
|
|
)
|
|
return response
|
|
|
|
|
|
@router.get(
|
|
"/invitation/info",
|
|
tags=["Invite Links"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=InvitationModel,
|
|
include_in_schema=False,
|
|
)
|
|
async def invitation_info(
|
|
invitation_id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
|
|
):
|
|
"""
|
|
Allow admin to create invite links, to onboard new users to Admin UI.
|
|
|
|
```
|
|
curl -X POST 'http://localhost:4000/invitation/new' \
|
|
-H 'Content-Type: application/json' \
|
|
-D '{
|
|
"user_id": "1234" // 👈 id of user in 'LiteLLM_UserTable'
|
|
}'
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
response = await prisma_client.db.litellm_invitationlink.find_unique(
|
|
where={"id": invitation_id}
|
|
)
|
|
|
|
if response is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invitation id does not exist in the database."},
|
|
)
|
|
return response
|
|
|
|
|
|
@router.post(
|
|
"/invitation/update",
|
|
tags=["Invite Links"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=InvitationModel,
|
|
include_in_schema=False,
|
|
)
|
|
async def invitation_update(
|
|
data: InvitationUpdate,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Update when invitation is accepted
|
|
|
|
```
|
|
curl -X POST 'http://localhost:4000/invitation/update' \
|
|
-H 'Content-Type: application/json' \
|
|
-D '{
|
|
"invitation_id": "1234" // 👈 id of invitation in 'LiteLLM_InvitationTable'
|
|
"is_accepted": True // when invitation is accepted
|
|
}'
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_id is None:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "Unable to identify user id. Received={}".format(
|
|
user_api_key_dict.user_id
|
|
)
|
|
},
|
|
)
|
|
|
|
current_time = litellm.utils.get_utc_datetime()
|
|
response = await prisma_client.db.litellm_invitationlink.update(
|
|
where={"id": data.invitation_id},
|
|
data={
|
|
"id": data.invitation_id,
|
|
"is_accepted": data.is_accepted,
|
|
"accepted_at": current_time,
|
|
"updated_at": current_time,
|
|
"updated_by": user_api_key_dict.user_id, # type: ignore
|
|
},
|
|
)
|
|
|
|
if response is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invitation id does not exist in the database."},
|
|
)
|
|
return response
|
|
|
|
|
|
@router.post(
|
|
"/invitation/delete",
|
|
tags=["Invite Links"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=InvitationModel,
|
|
include_in_schema=False,
|
|
)
|
|
async def invitation_delete(
|
|
data: InvitationDelete,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Delete invitation link
|
|
|
|
```
|
|
curl -X POST 'http://localhost:4000/invitation/delete' \
|
|
-H 'Content-Type: application/json' \
|
|
-D '{
|
|
"invitation_id": "1234" // 👈 id of invitation in 'LiteLLM_InvitationTable'
|
|
}'
|
|
```
|
|
"""
|
|
global prisma_client
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
response = await prisma_client.db.litellm_invitationlink.delete(
|
|
where={"id": data.invitation_id}
|
|
)
|
|
|
|
if response is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invitation id does not exist in the database."},
|
|
)
|
|
return response
|
|
|
|
|
|
#### CONFIG MANAGEMENT ####
|
|
@router.post(
|
|
"/config/update",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def update_config(config_info: ConfigYAML):
|
|
"""
|
|
For Admin UI - allows admin to update config via UI
|
|
|
|
Currently supports modifying General Settings + LiteLLM settings
|
|
"""
|
|
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key, prisma_client
|
|
try:
|
|
import base64
|
|
|
|
"""
|
|
- Update the ConfigTable DB
|
|
- Run 'add_deployment'
|
|
"""
|
|
if prisma_client is None:
|
|
raise Exception("No DB Connected")
|
|
|
|
if store_model_in_db is not True:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail={
|
|
"error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
|
|
},
|
|
)
|
|
|
|
updated_settings = config_info.json(exclude_none=True)
|
|
updated_settings = prisma_client.jsonify_object(updated_settings)
|
|
for k, v in updated_settings.items():
|
|
if k == "router_settings":
|
|
await prisma_client.db.litellm_config.upsert(
|
|
where={"param_name": k},
|
|
data={
|
|
"create": {"param_name": k, "param_value": v},
|
|
"update": {"param_value": v},
|
|
},
|
|
)
|
|
|
|
### OLD LOGIC [TODO] MOVE TO DB ###
|
|
|
|
import base64
|
|
|
|
# Load existing config
|
|
config = await proxy_config.get_config()
|
|
verbose_proxy_logger.debug("Loaded config: %s", config)
|
|
|
|
# update the general settings
|
|
if config_info.general_settings is not None:
|
|
config.setdefault("general_settings", {})
|
|
updated_general_settings = config_info.general_settings.dict(
|
|
exclude_none=True
|
|
)
|
|
|
|
_existing_settings = config["general_settings"]
|
|
for k, v in updated_general_settings.items():
|
|
# overwrite existing settings with updated values
|
|
if k == "alert_to_webhook_url":
|
|
# check if slack is already enabled. if not, enable it
|
|
if "alerting" not in _existing_settings:
|
|
_existing_settings["alerting"].append("slack")
|
|
elif isinstance(_existing_settings["alerting"], list):
|
|
if "slack" not in _existing_settings["alerting"]:
|
|
_existing_settings["alerting"].append("slack")
|
|
_existing_settings[k] = v
|
|
config["general_settings"] = _existing_settings
|
|
|
|
if config_info.environment_variables is not None:
|
|
config.setdefault("environment_variables", {})
|
|
_updated_environment_variables = config_info.environment_variables
|
|
|
|
# encrypt updated_environment_variables #
|
|
for k, v in _updated_environment_variables.items():
|
|
if isinstance(v, str):
|
|
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
|
|
_updated_environment_variables[k] = base64.b64encode(
|
|
encrypted_value
|
|
).decode("utf-8")
|
|
|
|
_existing_env_variables = config["environment_variables"]
|
|
|
|
for k, v in _updated_environment_variables.items():
|
|
# overwrite existing env variables with updated values
|
|
_existing_env_variables[k] = _updated_environment_variables[k]
|
|
|
|
# update the litellm settings
|
|
if config_info.litellm_settings is not None:
|
|
config.setdefault("litellm_settings", {})
|
|
updated_litellm_settings = config_info.litellm_settings
|
|
config["litellm_settings"] = {
|
|
**updated_litellm_settings,
|
|
**config["litellm_settings"],
|
|
}
|
|
|
|
# if litellm.success_callback in updated_litellm_settings and config["litellm_settings"]
|
|
if (
|
|
"success_callback" in updated_litellm_settings
|
|
and "success_callback" in config["litellm_settings"]
|
|
):
|
|
|
|
# check both success callback are lists
|
|
if isinstance(
|
|
config["litellm_settings"]["success_callback"], list
|
|
) and isinstance(updated_litellm_settings["success_callback"], list):
|
|
combined_success_callback = (
|
|
config["litellm_settings"]["success_callback"]
|
|
+ updated_litellm_settings["success_callback"]
|
|
)
|
|
combined_success_callback = list(set(combined_success_callback))
|
|
config["litellm_settings"][
|
|
"success_callback"
|
|
] = combined_success_callback
|
|
|
|
# Save the updated config
|
|
await proxy_config.save_config(new_config=config)
|
|
|
|
await proxy_config.add_deployment(
|
|
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
|
)
|
|
|
|
return {"message": "Config updated successfully"}
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.update_config(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
### CONFIG GENERAL SETTINGS
|
|
"""
|
|
- Update config settings
|
|
- Get config settings
|
|
|
|
Keep it more precise, to prevent overwrite other values unintentially
|
|
"""
|
|
|
|
|
|
@router.post(
|
|
"/config/field/update",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def update_config_general_settings(
|
|
data: ConfigFieldUpdate,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Update a specific field in litellm general settings
|
|
"""
|
|
global prisma_client
|
|
## VALIDATION ##
|
|
"""
|
|
- Check if prisma_client is None
|
|
- Check if user allowed to call this endpoint (admin-only)
|
|
- Check if param in general settings
|
|
- Check if config value is valid type
|
|
"""
|
|
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.not_allowed_access.value},
|
|
)
|
|
|
|
if data.field_name not in ConfigGeneralSettings.model_fields:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invalid field={} passed in.".format(data.field_name)},
|
|
)
|
|
|
|
try:
|
|
cgs = ConfigGeneralSettings(**{data.field_name: data.field_value})
|
|
except:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "Invalid type of field value={} passed in.".format(
|
|
type(data.field_value),
|
|
)
|
|
},
|
|
)
|
|
|
|
## get general settings from db
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
### update value
|
|
|
|
if db_general_settings is None or db_general_settings.param_value is None:
|
|
general_settings = {}
|
|
else:
|
|
general_settings = dict(db_general_settings.param_value)
|
|
|
|
## update db
|
|
|
|
general_settings[data.field_name] = data.field_value
|
|
|
|
response = await prisma_client.db.litellm_config.upsert(
|
|
where={"param_name": "general_settings"},
|
|
data={
|
|
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
|
|
"update": {"param_value": json.dumps(general_settings)}, # type: ignore
|
|
},
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@router.get(
|
|
"/config/field/info",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
response_model=ConfigFieldInfo,
|
|
include_in_schema=False,
|
|
)
|
|
async def get_config_general_settings(
|
|
field_name: str,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
global prisma_client
|
|
|
|
## VALIDATION ##
|
|
"""
|
|
- Check if prisma_client is None
|
|
- Check if user allowed to call this endpoint (admin-only)
|
|
- Check if param in general settings
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.not_allowed_access.value},
|
|
)
|
|
|
|
if field_name not in ConfigGeneralSettings.model_fields:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invalid field={} passed in.".format(field_name)},
|
|
)
|
|
|
|
## get general settings from db
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
### pop the value
|
|
|
|
if db_general_settings is None or db_general_settings.param_value is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Field name={} not in DB".format(field_name)},
|
|
)
|
|
else:
|
|
general_settings = dict(db_general_settings.param_value)
|
|
|
|
if field_name in general_settings:
|
|
return ConfigFieldInfo(
|
|
field_name=field_name, field_value=general_settings[field_name]
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Field name={} not in DB".format(field_name)},
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/config/list",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def get_config_list(
|
|
config_type: Literal["general_settings"],
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
) -> List[ConfigList]:
|
|
"""
|
|
List the available fields + current values for a given type of setting (currently just 'general_settings'user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),)
|
|
"""
|
|
global prisma_client, general_settings
|
|
|
|
## VALIDATION ##
|
|
"""
|
|
- Check if prisma_client is None
|
|
- Check if user allowed to call this endpoint (admin-only)
|
|
- Check if param in general settings
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
## get general settings from db
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
|
|
if db_general_settings is not None and db_general_settings.param_value is not None:
|
|
db_general_settings_dict = dict(db_general_settings.param_value)
|
|
else:
|
|
db_general_settings_dict = {}
|
|
|
|
allowed_args = {
|
|
"max_parallel_requests": {"type": "Integer"},
|
|
"global_max_parallel_requests": {"type": "Integer"},
|
|
}
|
|
|
|
return_val = []
|
|
|
|
for field_name, field_info in ConfigGeneralSettings.model_fields.items():
|
|
if field_name in allowed_args:
|
|
|
|
_stored_in_db = None
|
|
if field_name in db_general_settings_dict:
|
|
_stored_in_db = True
|
|
elif field_name in general_settings:
|
|
_stored_in_db = False
|
|
|
|
_response_obj = ConfigList(
|
|
field_name=field_name,
|
|
field_type=allowed_args[field_name]["type"],
|
|
field_description=field_info.description or "",
|
|
field_value=general_settings.get(field_name, None),
|
|
stored_in_db=_stored_in_db,
|
|
field_default_value=field_info.default,
|
|
)
|
|
return_val.append(_response_obj)
|
|
|
|
return return_val
|
|
|
|
|
|
@router.post(
|
|
"/config/field/delete",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def delete_config_general_settings(
|
|
data: ConfigFieldDelete,
|
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
):
|
|
"""
|
|
Delete the db value of this field in litellm general settings. Resets it to it's initial default value on litellm.
|
|
"""
|
|
global prisma_client
|
|
## VALIDATION ##
|
|
"""
|
|
- Check if prisma_client is None
|
|
- Check if user allowed to call this endpoint (admin-only)
|
|
- Check if param in general settings
|
|
"""
|
|
if prisma_client is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
|
)
|
|
|
|
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"error": "{}, your role={}".format(
|
|
CommonProxyErrors.not_allowed_access.value,
|
|
user_api_key_dict.user_role,
|
|
)
|
|
},
|
|
)
|
|
|
|
if data.field_name not in ConfigGeneralSettings.model_fields:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Invalid field={} passed in.".format(data.field_name)},
|
|
)
|
|
|
|
## get general settings from db
|
|
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
|
where={"param_name": "general_settings"}
|
|
)
|
|
### pop the value
|
|
|
|
if db_general_settings is None or db_general_settings.param_value is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={"error": "Field name={} not in config".format(data.field_name)},
|
|
)
|
|
else:
|
|
general_settings = dict(db_general_settings.param_value)
|
|
|
|
## update db
|
|
|
|
general_settings.pop(data.field_name, None)
|
|
|
|
response = await prisma_client.db.litellm_config.upsert(
|
|
where={"param_name": "general_settings"},
|
|
data={
|
|
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
|
|
"update": {"param_value": json.dumps(general_settings)}, # type: ignore
|
|
},
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@router.get(
|
|
"/get/config/callbacks",
|
|
tags=["config.yaml"],
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def get_config():
|
|
"""
|
|
For Admin UI - allows admin to view config via UI
|
|
# return the callbacks and the env variables for the callback
|
|
|
|
"""
|
|
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key
|
|
try:
|
|
import base64
|
|
|
|
all_available_callbacks = AllCallbacks()
|
|
|
|
config_data = await proxy_config.get_config()
|
|
_litellm_settings = config_data.get("litellm_settings", {})
|
|
_general_settings = config_data.get("general_settings", {})
|
|
environment_variables = config_data.get("environment_variables", {})
|
|
|
|
# check if "langfuse" in litellm_settings
|
|
_success_callbacks = _litellm_settings.get("success_callback", [])
|
|
_data_to_return = []
|
|
"""
|
|
[
|
|
{
|
|
"name": "langfuse",
|
|
"variables": {
|
|
"LANGFUSE_PUB_KEY": "value",
|
|
"LANGFUSE_SECRET_KEY": "value",
|
|
"LANGFUSE_HOST": "value"
|
|
},
|
|
}
|
|
]
|
|
|
|
"""
|
|
for _callback in _success_callbacks:
|
|
if _callback == "openmeter":
|
|
env_vars = [
|
|
"OPENMETER_API_KEY",
|
|
]
|
|
env_vars_dict = {}
|
|
for _var in env_vars:
|
|
env_variable = environment_variables.get(_var, None)
|
|
if env_variable is None:
|
|
env_vars_dict[_var] = None
|
|
else:
|
|
# decode + decrypt the value
|
|
decoded_b64 = base64.b64decode(env_variable)
|
|
_decrypted_value = decrypt_value(
|
|
value=decoded_b64, master_key=master_key
|
|
)
|
|
env_vars_dict[_var] = _decrypted_value
|
|
|
|
_data_to_return.append({"name": _callback, "variables": env_vars_dict})
|
|
elif _callback == "langfuse":
|
|
_langfuse_vars = [
|
|
"LANGFUSE_PUBLIC_KEY",
|
|
"LANGFUSE_SECRET_KEY",
|
|
"LANGFUSE_HOST",
|
|
]
|
|
_langfuse_env_vars = {}
|
|
for _var in _langfuse_vars:
|
|
env_variable = environment_variables.get(_var, None)
|
|
if env_variable is None:
|
|
_langfuse_env_vars[_var] = None
|
|
else:
|
|
# decode + decrypt the value
|
|
decoded_b64 = base64.b64decode(env_variable)
|
|
_decrypted_value = decrypt_value(
|
|
value=decoded_b64, master_key=master_key
|
|
)
|
|
_langfuse_env_vars[_var] = _decrypted_value
|
|
|
|
_data_to_return.append(
|
|
{"name": _callback, "variables": _langfuse_env_vars}
|
|
)
|
|
|
|
# Check if slack alerting is on
|
|
_alerting = _general_settings.get("alerting", [])
|
|
alerting_data = []
|
|
if "slack" in _alerting:
|
|
_slack_vars = [
|
|
"SLACK_WEBHOOK_URL",
|
|
]
|
|
_slack_env_vars = {}
|
|
for _var in _slack_vars:
|
|
env_variable = environment_variables.get(_var, None)
|
|
if env_variable is None:
|
|
_value = os.getenv("SLACK_WEBHOOK_URL", None)
|
|
_slack_env_vars[_var] = _value
|
|
else:
|
|
# decode + decrypt the value
|
|
decoded_b64 = base64.b64decode(env_variable)
|
|
_decrypted_value = decrypt_value(
|
|
value=decoded_b64, master_key=master_key
|
|
)
|
|
_slack_env_vars[_var] = _decrypted_value
|
|
|
|
_alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types
|
|
_all_alert_types = (
|
|
proxy_logging_obj.slack_alerting_instance._all_possible_alert_types()
|
|
)
|
|
_alerts_to_webhook = (
|
|
proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url
|
|
)
|
|
alerting_data.append(
|
|
{
|
|
"name": "slack",
|
|
"variables": _slack_env_vars,
|
|
"active_alerts": _alerting_types,
|
|
"alerts_to_webhook": _alerts_to_webhook,
|
|
}
|
|
)
|
|
# pass email alerting vars
|
|
_email_vars = [
|
|
"SMTP_HOST",
|
|
"SMTP_PORT",
|
|
"SMTP_USERNAME",
|
|
"SMTP_PASSWORD",
|
|
"SMTP_SENDER_EMAIL",
|
|
"TEST_EMAIL_ADDRESS",
|
|
"EMAIL_LOGO_URL",
|
|
"EMAIL_SUPPORT_CONTACT",
|
|
]
|
|
_email_env_vars = {}
|
|
for _var in _email_vars:
|
|
env_variable = environment_variables.get(_var, None)
|
|
if env_variable is None:
|
|
_email_env_vars[_var] = None
|
|
else:
|
|
# decode + decrypt the value
|
|
decoded_b64 = base64.b64decode(env_variable)
|
|
_decrypted_value = decrypt_value(
|
|
value=decoded_b64, master_key=master_key
|
|
)
|
|
_email_env_vars[_var] = _decrypted_value
|
|
|
|
alerting_data.append(
|
|
{
|
|
"name": "email",
|
|
"variables": _email_env_vars,
|
|
}
|
|
)
|
|
|
|
if llm_router is None:
|
|
_router_settings = {}
|
|
else:
|
|
_router_settings = llm_router.get_settings()
|
|
|
|
return {
|
|
"status": "success",
|
|
"callbacks": _data_to_return,
|
|
"alerts": alerting_data,
|
|
"router_settings": _router_settings,
|
|
"available_callbacks": all_available_callbacks,
|
|
}
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
"litellm.proxy.proxy_server.get_config(): Exception occured - {}".format(
|
|
str(e)
|
|
)
|
|
)
|
|
verbose_proxy_logger.debug(traceback.format_exc())
|
|
if isinstance(e, HTTPException):
|
|
raise ProxyException(
|
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
|
)
|
|
elif isinstance(e, ProxyException):
|
|
raise e
|
|
raise ProxyException(
|
|
message="Authentication Error, " + str(e),
|
|
type="auth_error",
|
|
param=getattr(e, "param", "None"),
|
|
code=status.HTTP_400_BAD_REQUEST,
|
|
)
|
|
|
|
|
|
@router.get(
|
|
"/config/yaml",
|
|
tags=["config.yaml"],
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def config_yaml_endpoint(config_info: ConfigYAML):
|
|
"""
|
|
This is a mock endpoint, to show what you can set in config.yaml details in the Swagger UI.
|
|
|
|
Parameters:
|
|
|
|
The config.yaml object has the following attributes:
|
|
- **model_list**: *Optional[List[ModelParams]]* - A list of supported models on the server, along with model-specific configurations. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc).
|
|
|
|
- **litellm_settings**: *Optional[dict]*: Settings for the litellm module. You can specify multiple properties like "drop_params", "set_verbose", "api_base", "cache".
|
|
|
|
- **general_settings**: *Optional[ConfigGeneralSettings]*: General settings for the server like "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others.
|
|
|
|
Please, refer to each class's description for a better understanding of the specific attributes within them.
|
|
|
|
Note: This is a mock endpoint primarily meant for demonstration purposes, and does not actually provide or change any configurations.
|
|
"""
|
|
return {"hello": "world"}
|
|
|
|
|
|
@router.get(
|
|
"/get/litellm_model_cost_map",
|
|
include_in_schema=False,
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
)
|
|
async def get_litellm_model_cost_map():
|
|
try:
|
|
_model_cost_map = litellm.model_cost
|
|
return _model_cost_map
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Internal Server Error ({str(e)})",
|
|
)
|
|
|
|
|
|
@router.get("/", dependencies=[Depends(user_api_key_auth)])
|
|
async def home(request: Request):
|
|
return "LiteLLM: RUNNING"
|
|
|
|
|
|
@router.get("/routes", dependencies=[Depends(user_api_key_auth)])
|
|
async def get_routes():
|
|
"""
|
|
Get a list of available routes in the FastAPI application.
|
|
"""
|
|
routes = []
|
|
for route in app.routes:
|
|
route_info = {
|
|
"path": getattr(route, "path", None),
|
|
"methods": getattr(route, "methods", None),
|
|
"name": getattr(route, "name", None),
|
|
"endpoint": (
|
|
getattr(route, "endpoint", None).__name__
|
|
if getattr(route, "endpoint", None)
|
|
else None
|
|
),
|
|
}
|
|
routes.append(route_info)
|
|
|
|
return {"routes": routes}
|
|
|
|
|
|
#### TEST ENDPOINTS ####
|
|
@router.get(
|
|
"/token/generate",
|
|
dependencies=[Depends(user_api_key_auth)],
|
|
include_in_schema=False,
|
|
)
|
|
async def token_generate():
|
|
"""
|
|
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
|
"""
|
|
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
|
from fastapi_sso import AuthJWTSSO
|
|
|
|
auth_jwt_sso = AuthJWTSSO(
|
|
issuer=os.getenv("OPENID_BASE_URL"),
|
|
client_id=os.getenv("OPENID_CLIENT_ID"),
|
|
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
|
scopes=["litellm_proxy_admin"],
|
|
)
|
|
|
|
token = auth_jwt_sso.create_access_token()
|
|
|
|
return {"token": token}
|
|
|
|
|
|
@router.on_event("shutdown")
|
|
async def shutdown_event():
|
|
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
|
|
verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")
|
|
if prisma_client:
|
|
verbose_proxy_logger.debug("Disconnecting from Prisma")
|
|
await prisma_client.disconnect()
|
|
|
|
if litellm.cache is not None:
|
|
await litellm.cache.disconnect()
|
|
|
|
await jwt_handler.close()
|
|
|
|
if db_writer_client is not None:
|
|
await db_writer_client.close()
|
|
|
|
# flush remaining langfuse logs
|
|
if "langfuse" in litellm.success_callback:
|
|
try:
|
|
# flush langfuse logs on shutdow
|
|
from litellm.utils import langFuseLogger
|
|
|
|
langFuseLogger.Langfuse.flush()
|
|
except:
|
|
# [DO NOT BLOCK shutdown events for this]
|
|
pass
|
|
|
|
## RESET CUSTOM VARIABLES ##
|
|
cleanup_router_config_variables()
|
|
|
|
|
|
def cleanup_router_config_variables():
|
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
|
|
|
# Set all variables to None
|
|
master_key = None
|
|
user_config_file_path = None
|
|
otel_logging = None
|
|
user_custom_auth = None
|
|
user_custom_auth_path = None
|
|
user_custom_key_generate = None
|
|
use_background_health_checks = None
|
|
health_check_interval = None
|
|
prisma_client = None
|
|
custom_db_client = None
|
|
|
|
|
|
app.include_router(router)
|
|
app.include_router(health_router)
|
|
app.include_router(key_management_router)
|
|
app.include_router(internal_user_router)
|
|
app.include_router(team_router)
|
|
app.include_router(spend_management_router)
|
|
app.include_router(caching_router)
|
|
app.include_router(analytics_router)
|
|
app.include_router(debugging_endpoints_router)
|