forked from phoenix/litellm-mirror
fix(proxy_server.py): fix /key/generate post endpoint
This commit is contained in:
parent
d7d8c5f6e6
commit
63e55f1865
6 changed files with 115 additions and 27 deletions
|
@ -213,11 +213,11 @@ class GenerateKeyRequest(BaseModel):
|
||||||
aliases: dict = {}
|
aliases: dict = {}
|
||||||
config: dict = {}
|
config: dict = {}
|
||||||
spend: int = 0
|
spend: int = 0
|
||||||
user_id: Optional[str]
|
user_id: Optional[str] = None
|
||||||
|
|
||||||
class GenerateKeyResponse(BaseModel):
|
class GenerateKeyResponse(BaseModel):
|
||||||
key: str
|
key: str
|
||||||
expires: str
|
expires: datetime
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
class _DeleteKeyObject(BaseModel):
|
class _DeleteKeyObject(BaseModel):
|
||||||
|
@ -277,6 +277,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
"api_key": None
|
"api_key": None
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
|
if api_key is None:
|
||||||
|
raise Exception("No api key passed in.")
|
||||||
route = request.url.path
|
route = request.url.path
|
||||||
|
|
||||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||||
|
@ -491,8 +493,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
## PRINT YAML FOR CONFIRMING IT WORKS
|
## PRINT YAML FOR CONFIRMING IT WORKS
|
||||||
printed_yaml = copy.deepcopy(config)
|
printed_yaml = copy.deepcopy(config)
|
||||||
printed_yaml.pop("environment_variables", None)
|
printed_yaml.pop("environment_variables", None)
|
||||||
for model in printed_yaml["model_list"]:
|
if "model_list" in printed_yaml:
|
||||||
model["litellm_params"].pop("api_key", None)
|
for model in printed_yaml["model_list"]:
|
||||||
|
model["litellm_params"].pop("api_key", None)
|
||||||
|
|
||||||
print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
|
print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
|
||||||
|
|
||||||
|
@ -507,22 +510,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
if general_settings is None:
|
if general_settings is None:
|
||||||
general_settings = {}
|
general_settings = {}
|
||||||
if general_settings:
|
if general_settings:
|
||||||
### MASTER KEY ###
|
### LOAD FROM AZURE KEY VAULT ###
|
||||||
master_key = general_settings.get("master_key", None)
|
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||||
master_key_env_name = master_key.replace("os.environ/", "")
|
|
||||||
master_key = os.getenv(master_key_env_name)
|
|
||||||
### CONNECT TO DATABASE ###
|
### CONNECT TO DATABASE ###
|
||||||
database_url = general_settings.get("database_url", None)
|
database_url = general_settings.get("database_url", None)
|
||||||
|
if database_url and database_url.startswith("os.environ/"):
|
||||||
|
database_url = litellm.get_secret(database_url)
|
||||||
prisma_setup(database_url=database_url)
|
prisma_setup(database_url=database_url)
|
||||||
## COST TRACKING ##
|
## COST TRACKING ##
|
||||||
cost_tracking()
|
cost_tracking()
|
||||||
### START REDIS QUEUE ###
|
### START REDIS QUEUE ###
|
||||||
use_queue = general_settings.get("use_queue", False)
|
use_queue = general_settings.get("use_queue", False)
|
||||||
celery_setup(use_queue=use_queue)
|
celery_setup(use_queue=use_queue)
|
||||||
### LOAD FROM AZURE KEY VAULT ###
|
### MASTER KEY ###
|
||||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
master_key = general_settings.get("master_key", None)
|
||||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
|
master_key = litellm.get_secret(master_key)
|
||||||
|
|
||||||
#### OpenTelemetry Logging (OTEL) ########
|
#### OpenTelemetry Logging (OTEL) ########
|
||||||
otel_logging = general_settings.get("otel", False)
|
otel_logging = general_settings.get("otel", False)
|
||||||
|
@ -540,9 +545,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
print(f"{blue_color_code}\nSetting Cache on Proxy")
|
print(f"{blue_color_code}\nSetting Cache on Proxy")
|
||||||
from litellm.caching import Cache
|
from litellm.caching import Cache
|
||||||
cache_type = value["type"]
|
cache_type = value["type"]
|
||||||
cache_host = os.environ.get("REDIS_HOST")
|
cache_host = litellm.get_secret("REDIS_HOST")
|
||||||
cache_port = os.environ.get("REDIS_PORT")
|
cache_port = litellm.get_secret("REDIS_PORT")
|
||||||
cache_password = os.environ.get("REDIS_PASSWORD")
|
cache_password = litellm.get_secret("REDIS_PASSWORD")
|
||||||
|
|
||||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||||
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
||||||
|
@ -794,12 +799,14 @@ def litellm_completion(*args, **kwargs):
|
||||||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@app.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global prisma_client, master_key
|
global prisma_client, master_key
|
||||||
import json
|
import json
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
|
print(f"worker_config: {worker_config}")
|
||||||
initialize(**worker_config)
|
initialize(**worker_config)
|
||||||
|
print(f"prisma client - {prisma_client}")
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
await prisma_client.connect()
|
await prisma_client.connect()
|
||||||
|
|
||||||
|
@ -807,7 +814,7 @@ async def startup_event():
|
||||||
# add master key to db
|
# add master key to db
|
||||||
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@router.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
global prisma_client
|
global prisma_client
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
|
@ -1022,8 +1029,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
||||||
- key: The generated api key
|
- key: The generated api key
|
||||||
- expires: Datetime object for when key expires.
|
- expires: Datetime object for when key expires.
|
||||||
"""
|
"""
|
||||||
data = await request.json()
|
# data = await request.json()
|
||||||
|
|
||||||
duration_str = data.duration # Default to 1 hour if duration is not provided
|
duration_str = data.duration # Default to 1 hour if duration is not provided
|
||||||
models = data.models # Default to an empty list (meaning allow token to call all models)
|
models = data.models # Default to an empty list (meaning allow token to call all models)
|
||||||
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||||
|
@ -1042,8 +1048,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
||||||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
keys = data.keys
|
keys = data.keys
|
||||||
|
|
||||||
deleted_keys = await delete_verification_token(tokens=keys)
|
deleted_keys = await delete_verification_token(tokens=keys)
|
||||||
|
|
|
@ -5,8 +5,18 @@ class PrismaClient:
|
||||||
def __init__(self, database_url: str):
|
def __init__(self, database_url: str):
|
||||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
subprocess.run(['prisma', 'generate'])
|
# Save the current working directory
|
||||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
original_dir = os.getcwd()
|
||||||
|
# set the working directory to where this script is
|
||||||
|
abspath = os.path.abspath(__file__)
|
||||||
|
dname = os.path.dirname(abspath)
|
||||||
|
os.chdir(dname)
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(['prisma', 'generate'])
|
||||||
|
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
# Now you can import the Prisma Client
|
# Now you can import the Prisma Client
|
||||||
from prisma import Client
|
from prisma import Client
|
||||||
self.db = Client() #Client to connect to Prisma db
|
self.db = Client() #Client to connect to Prisma db
|
||||||
|
|
7
litellm/tests/test_config.yaml
Normal file
7
litellm/tests/test_config.yaml
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
set_verbose: True
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
master_key: "os.environ/PROXY_MASTER_KEY"
|
||||||
|
database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy
|
|
@ -164,4 +164,4 @@ def test_chat_completion_optional_params():
|
||||||
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
||||||
|
|
||||||
# Run the test
|
# Run the test
|
||||||
test_chat_completion_optional_params()
|
test_chat_completion_optional_params()
|
||||||
|
|
66
litellm/tests/test_proxy_server_keys.py
Normal file
66
litellm/tests/test_proxy_server_keys.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
import sys, os
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os, io
|
||||||
|
|
||||||
|
# this file is to test litellm/proxy
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest, logging
|
||||||
|
import litellm
|
||||||
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
|
from litellm import RateLimitError
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG, # Set the desired logging level
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
# test /chat/completion request to the proxy
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
||||||
|
cwd = os.getcwd()
|
||||||
|
config_fp = f"{cwd}/test_config.yaml"
|
||||||
|
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router) # Include your router in the test app
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def wrapper_startup_event():
|
||||||
|
await startup_event()
|
||||||
|
|
||||||
|
# Here you create a fixture that will be used by your tests
|
||||||
|
# Make sure the fixture returns TestClient(app)
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def client():
|
||||||
|
with TestClient(app) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
def test_add_new_key(client):
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],
|
||||||
|
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||||
|
"duration": "20m"
|
||||||
|
}
|
||||||
|
print("testing proxy server")
|
||||||
|
# Your bearer token
|
||||||
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}"
|
||||||
|
}
|
||||||
|
response = client.post("/key/generate", json=test_data, headers=headers)
|
||||||
|
print(f"response: {response.text}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
print(f"Received response: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
||||||
|
|
||||||
|
# # Run the test - only runs via pytest
|
|
@ -2421,8 +2421,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
|
||||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||||
|
|
||||||
if api_key and api_key.startswith("os.environ/"):
|
if api_key and api_key.startswith("os.environ/"):
|
||||||
api_key_env_name = api_key.replace("os.environ/", "")
|
dynamic_api_key = get_secret(api_key)
|
||||||
dynamic_api_key = get_secret(api_key_env_name)
|
|
||||||
# check if llm provider part of model name
|
# check if llm provider part of model name
|
||||||
if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
|
if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
|
||||||
custom_llm_provider = model.split("/", 1)[0]
|
custom_llm_provider = model.split("/", 1)[0]
|
||||||
|
@ -4722,7 +4721,9 @@ def litellm_telemetry(data):
|
||||||
######### Secret Manager ############################
|
######### Secret Manager ############################
|
||||||
# checks if user has passed in a secret manager client
|
# checks if user has passed in a secret manager client
|
||||||
# if passed in then checks the secret there
|
# if passed in then checks the secret there
|
||||||
def get_secret(secret_name):
|
def get_secret(secret_name: str):
|
||||||
|
if secret_name.startswith("os.environ/"):
|
||||||
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
if litellm.secret_manager_client is not None:
|
if litellm.secret_manager_client is not None:
|
||||||
# TODO: check which secret manager is being used
|
# TODO: check which secret manager is being used
|
||||||
# currently only supports Infisical
|
# currently only supports Infisical
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue