mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(proxy_server.py): handle initializing prisma / db connection just once
This commit is contained in:
parent
4a364bcbc0
commit
7a669a36d2
1 changed file with 54 additions and 49 deletions
|
@ -1,4 +1,4 @@
|
||||||
import sys, os, platform, time, copy, re
|
import sys, os, platform, time, copy, re, asyncio
|
||||||
import threading, ast
|
import threading, ast
|
||||||
import shutil, random, traceback, requests
|
import shutil, random, traceback, requests
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -161,7 +161,6 @@ async def user_api_key_auth(request: Request):
|
||||||
return
|
return
|
||||||
print(f"prisma_client: {prisma_client}")
|
print(f"prisma_client: {prisma_client}")
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
await prisma_client.connect()
|
|
||||||
valid_token = await prisma_client.litellm_verificationtoken.find_first(
|
valid_token = await prisma_client.litellm_verificationtoken.find_first(
|
||||||
where={
|
where={
|
||||||
"token": api_key,
|
"token": api_key,
|
||||||
|
@ -174,7 +173,7 @@ async def user_api_key_auth(request: Request):
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(f"An exception occurred - {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail={"error": "invalid user key"},
|
detail={"error": "invalid user key"},
|
||||||
|
@ -291,6 +290,49 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
|
|
||||||
return router, model_list, server_settings
|
return router, model_list, server_settings
|
||||||
|
|
||||||
|
async def generate_key_helper_fn(duration_str: str):
    """Create a new API key and store it through the global ``prisma_client``.

    Args:
        duration_str: Lifetime of the key, written as an integer immediately
            followed by exactly one unit: ``s`` (seconds), ``m`` (minutes),
            ``h`` (hours) or ``d`` (days) -- e.g. ``"30m"`` or ``"1h"``.

    Returns:
        A dict holding the ``token`` and ``expires`` values of the created
        verification-token record.

    Raises:
        ValueError: If ``duration_str`` is not of the form ``<digits><unit>``
            or describes a duration too large to represent.
        HTTPException: With status 500 if the record could not be created.
    """
    seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400}

    def _duration_in_seconds(duration: str):
        # fullmatch, not match: a prefix match silently ignored trailing text,
        # so "1hour" was accepted as 1h and "5m30s" as 5m.
        match = re.fullmatch(r"(\d+)([smhd]?)", duration)
        if not match:
            raise ValueError("Invalid duration format")

        value, unit = match.groups()
        # The unit is optional in the pattern only so that a bare number gets
        # the more specific "unsupported unit" error below.
        if unit not in seconds_per_unit:
            raise ValueError("Unsupported duration unit")
        return int(value) * seconds_per_unit[unit]

    # Validate the input before generating a token or touching the database.
    duration = _duration_in_seconds(duration=duration_str)
    try:
        expires = datetime.utcnow() + timedelta(seconds=duration)
    except OverflowError as err:
        # Arbitrarily many digits are syntactically valid but may not fit in a
        # timedelta/datetime; report that as bad input rather than a crash.
        raise ValueError("Invalid duration format") from err

    token = f"sk-{secrets.token_urlsafe(16)}"
    try:
        db = prisma_client
        # Create a new verification token (you may want to enhance this logic based on your needs)
        verification_token_data = {
            "token": token,
            "expires": expires,
        }
        new_verification_token = await db.litellm_verificationtoken.create(
            {**verification_token_data}
        )
        print(f"new_verification_token: {new_verification_token}")
    except Exception as e:
        # Any failure here (including an unset prisma_client) is reported to
        # the caller as a generic 500; the traceback goes to stderr.
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from e

    return {
        "token": new_verification_token.token,
        "expires": new_verification_token.expires,
    }
|
||||||
|
|
||||||
|
async def generate_key_cli_task(duration_str):
    """Run ``generate_key_helper_fn`` as an asyncio task and wait for it.

    The helper's result is discarded; this coroutine returns ``None``.
    """
    await asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
#### DEPRECATED ####
|
#### DEPRECATED ####
|
||||||
try:
|
try:
|
||||||
|
@ -412,7 +454,7 @@ def initialize(
|
||||||
add_function_to_prompt,
|
add_function_to_prompt,
|
||||||
headers,
|
headers,
|
||||||
save,
|
save,
|
||||||
config
|
config,
|
||||||
):
|
):
|
||||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
|
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
|
||||||
generate_feedback_box()
|
generate_feedback_box()
|
||||||
|
@ -504,13 +546,15 @@ def litellm_completion(*args, **kwargs):
|
||||||
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
|
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
|
||||||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def startup_event():
|
async def startup_event():
|
||||||
|
global prisma_client
|
||||||
import json
|
import json
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
initialize(**worker_config)
|
initialize(**worker_config)
|
||||||
|
if prisma_client:
|
||||||
|
await prisma_client.connect()
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
|
@ -625,51 +669,12 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def generate_key(request: Request):
|
async def generate_key_fn(request: Request):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
|
||||||
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
||||||
|
response = await generate_key_helper_fn(duration_str=duration_str)
|
||||||
def _duration_in_seconds(duration: str):
|
return {"token": response["token"], "expires": response["expires"]}
|
||||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
|
||||||
if not match:
|
|
||||||
raise ValueError("Invalid duration format")
|
|
||||||
|
|
||||||
value, unit = match.groups()
|
|
||||||
value = int(value)
|
|
||||||
|
|
||||||
if unit == "s":
|
|
||||||
return value
|
|
||||||
elif unit == "m":
|
|
||||||
return value * 60
|
|
||||||
elif unit == "h":
|
|
||||||
return value * 3600
|
|
||||||
elif unit == "d":
|
|
||||||
return value * 86400
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported duration unit")
|
|
||||||
|
|
||||||
duration = _duration_in_seconds(duration=duration_str)
|
|
||||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
|
||||||
try:
|
|
||||||
db = prisma_client
|
|
||||||
await db.connect()
|
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
|
||||||
print(dir(db))
|
|
||||||
verification_token_data = {
|
|
||||||
"token": token,
|
|
||||||
"expires": expires
|
|
||||||
}
|
|
||||||
new_verification_token = await db.litellm_verificationtoken.create(
|
|
||||||
{**verification_token_data}
|
|
||||||
)
|
|
||||||
print(f"new_verification_token: {new_verification_token}")
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
|
||||||
|
|
||||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue