fix(proxy_server.py): fix /key/generate post endpoint

This commit is contained in:
Krrish Dholakia 2023-12-04 10:43:42 -08:00
parent d7d8c5f6e6
commit 63e55f1865
6 changed files with 115 additions and 27 deletions

View file

@ -213,11 +213,11 @@ class GenerateKeyRequest(BaseModel):
aliases: dict = {} aliases: dict = {}
config: dict = {} config: dict = {}
spend: int = 0 spend: int = 0
user_id: Optional[str] user_id: Optional[str] = None
class GenerateKeyResponse(BaseModel): class GenerateKeyResponse(BaseModel):
key: str key: str
expires: str expires: datetime
user_id: str user_id: str
class _DeleteKeyObject(BaseModel): class _DeleteKeyObject(BaseModel):
@ -277,6 +277,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
"api_key": None "api_key": None
} }
try: try:
if api_key is None:
raise Exception("No api key passed in.")
route = request.url.path route = request.url.path
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
@ -491,8 +493,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
## PRINT YAML FOR CONFIRMING IT WORKS ## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config) printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None) printed_yaml.pop("environment_variables", None)
for model in printed_yaml["model_list"]: if "model_list" in printed_yaml:
model["litellm_params"].pop("api_key", None) for model in printed_yaml["model_list"]:
model["litellm_params"].pop("api_key", None)
print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}") print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
@ -507,22 +510,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None: if general_settings is None:
general_settings = {} general_settings = {}
if general_settings: if general_settings:
### MASTER KEY ### ### LOAD FROM AZURE KEY VAULT ###
master_key = general_settings.get("master_key", None) use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
if master_key and master_key.startswith("os.environ/"): load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
master_key_env_name = master_key.replace("os.environ/", "")
master_key = os.getenv(master_key_env_name)
### CONNECT TO DATABASE ### ### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None) database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
database_url = litellm.get_secret(database_url)
prisma_setup(database_url=database_url) prisma_setup(database_url=database_url)
## COST TRACKING ## ## COST TRACKING ##
cost_tracking() cost_tracking()
### START REDIS QUEUE ### ### START REDIS QUEUE ###
use_queue = general_settings.get("use_queue", False) use_queue = general_settings.get("use_queue", False)
celery_setup(use_queue=use_queue) celery_setup(use_queue=use_queue)
### LOAD FROM AZURE KEY VAULT ### ### MASTER KEY ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False) master_key = general_settings.get("master_key", None)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
#### OpenTelemetry Logging (OTEL) ######## #### OpenTelemetry Logging (OTEL) ########
otel_logging = general_settings.get("otel", False) otel_logging = general_settings.get("otel", False)
@ -540,9 +545,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"{blue_color_code}\nSetting Cache on Proxy") print(f"{blue_color_code}\nSetting Cache on Proxy")
from litellm.caching import Cache from litellm.caching import Cache
cache_type = value["type"] cache_type = value["type"]
cache_host = os.environ.get("REDIS_HOST") cache_host = litellm.get_secret("REDIS_HOST")
cache_port = os.environ.get("REDIS_PORT") cache_port = litellm.get_secret("REDIS_PORT")
cache_password = os.environ.get("REDIS_PASSWORD") cache_password = litellm.get_secret("REDIS_PASSWORD")
# Assuming cache_type, cache_host, cache_port, and cache_password are strings # Assuming cache_type, cache_host, cache_port, and cache_password are strings
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
@ -794,12 +799,14 @@ def litellm_completion(*args, **kwargs):
return StreamingResponse(data_generator(response), media_type='text/event-stream') return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response return response
@app.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key global prisma_client, master_key
import json import json
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
print(f"worker_config: {worker_config}")
initialize(**worker_config) initialize(**worker_config)
print(f"prisma client - {prisma_client}")
if prisma_client: if prisma_client:
await prisma_client.connect() await prisma_client.connect()
@ -807,7 +814,7 @@ async def startup_event():
# add master key to db # add master key to db
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key) await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
@app.on_event("shutdown") @router.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
global prisma_client global prisma_client
if prisma_client: if prisma_client:
@ -1022,8 +1029,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
- key: The generated api key - key: The generated api key
- expires: Datetime object for when key expires. - expires: Datetime object for when key expires.
""" """
data = await request.json() # data = await request.json()
duration_str = data.duration # Default to 1 hour if duration is not provided duration_str = data.duration # Default to 1 hour if duration is not provided
models = data.models # Default to an empty list (meaning allow token to call all models) models = data.models # Default to an empty list (meaning allow token to call all models)
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list) aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
@ -1042,8 +1048,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest): async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try: try:
data = await request.json()
keys = data.keys keys = data.keys
deleted_keys = await delete_verification_token(tokens=keys) deleted_keys = await delete_verification_token(tokens=keys)

View file

@ -5,8 +5,18 @@ class PrismaClient:
def __init__(self, database_url: str): def __init__(self, database_url: str):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate']) # Save the current working directory
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client # Now you can import the Prisma Client
from prisma import Client from prisma import Client
self.db = Client() #Client to connect to Prisma db self.db = Client() #Client to connect to Prisma db

View file

@ -0,0 +1,7 @@
# Test configuration consumed by the proxy test-suite (test_config.yaml).
litellm_settings:
  drop_params: True   # silently drop unsupported params instead of erroring
  set_verbose: True   # verbose litellm logging for easier test debugging
general_settings:
  # "os.environ/NAME" values are resolved from the environment at load time.
  master_key: "os.environ/PROXY_MASTER_KEY"
  database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy

View file

@ -164,4 +164,4 @@ def test_chat_completion_optional_params():
pytest.fail("LiteLLM Proxy test failed. Exception", e) pytest.fail("LiteLLM Proxy test failed. Exception", e)
# Run the test # Run the test
test_chat_completion_optional_params() test_chat_completion_optional_params()

View file

@ -0,0 +1,66 @@
import sys, os
import traceback
from dotenv import load_dotenv

# Load PROXY_MASTER_KEY / PROXY_DATABASE_URL from a local .env before the
# proxy code reads the environment.
load_dotenv()
import os, io

# this file is to test litellm/proxy
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,  # Set the desired logging level
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined

# Point the proxy at the test config shipped next to this file.
cwd = os.getcwd()
config_fp = f"{cwd}/test_config.yaml"
# NOTE(review): presumably stores the worker settings where startup_event can
# read them back (e.g. the WORKER_CONFIG env var) — confirm in proxy_server.
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router)  # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
await startup_event()
# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
def client():
    # Entering TestClient as a context manager fires the app's startup (and on
    # exit, shutdown) events, so each test gets a fully initialized proxy.
    with TestClient(app) as client:
        yield client
def test_add_new_key(client):
    """POST /key/generate as the proxy admin and verify a key is issued.

    Requires PROXY_MASTER_KEY in the environment (loaded via dotenv) matching
    the master_key configured in test_config.yaml.
    """
    try:
        # Request body for the key-generation endpoint.
        test_data = {
            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],
            "aliases": {"mistral-7b": "gpt-3.5-turbo"},
            "duration": "20m"
        }
        print("testing proxy server")
        # Your bearer token — the proxy master key authorizes key generation.
        token = os.getenv("PROXY_MASTER_KEY")
        headers = {
            "Authorization": f"Bearer {token}"
        }
        response = client.post("/key/generate", json=test_data, headers=headers)
        print(f"response: {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        # BUG FIX: pytest.fail(reason, pytrace) takes a bool as its second
        # positional argument; passing the exception there raises a TypeError
        # and hides the real failure. Embed the exception in the message.
        pytest.fail(f"LiteLLM Proxy test failed. Exception: {e}")
# # Run the test - only runs via pytest

View file

@ -2421,8 +2421,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
if api_key and api_key.startswith("os.environ/"): if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "") dynamic_api_key = get_secret(api_key)
dynamic_api_key = get_secret(api_key_env_name)
# check if llm provider part of model name # check if llm provider part of model name
if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list: if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
custom_llm_provider = model.split("/", 1)[0] custom_llm_provider = model.split("/", 1)[0]
@ -4722,7 +4721,9 @@ def litellm_telemetry(data):
######### Secret Manager ############################ ######### Secret Manager ############################
# checks if user has passed in a secret manager client # checks if user has passed in a secret manager client
# if passed in then checks the secret there # if passed in then checks the secret there
def get_secret(secret_name): def get_secret(secret_name: str):
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
if litellm.secret_manager_client is not None: if litellm.secret_manager_client is not None:
# TODO: check which secret manager is being used # TODO: check which secret manager is being used
# currently only supports Infisical # currently only supports Infisical