fix(proxy_server.py): fix prisma client connection error

This commit is contained in:
Krrish Dholakia 2024-01-04 18:28:18 +05:30
parent 6b708347f3
commit 74f6f6489a
8 changed files with 29 additions and 46 deletions

View file

@@ -307,9 +307,8 @@ async def user_api_key_auth(
) )
def prisma_setup(database_url: Optional[str]): async def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj, user_api_key_cache global prisma_client, proxy_logging_obj, user_api_key_cache
if ( if (
database_url is not None and prisma_client is None database_url is not None and prisma_client is None
): # don't re-initialize prisma client after initial init ): # don't re-initialize prisma client after initial init
@@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]):
print_verbose( print_verbose(
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}" f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
) )
if prisma_client is not None and prisma_client.db.is_connected() == False:
await prisma_client.connect()
def load_from_azure_key_vault(use_azure_key_vault: bool = False): def load_from_azure_key_vault(use_azure_key_vault: bool = False):
@@ -534,6 +535,7 @@ class ProxyConfig:
prisma_client is not None prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
): ):
await prisma_setup(database_url=None) # in case it's not been connected yet
_tasks = [] _tasks = []
keys = [ keys = [
"model_list", "model_list",
@@ -761,7 +763,7 @@ class ProxyConfig:
print_verbose(f"GOING INTO LITELLM.GET_SECRET!") print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url) database_url = litellm.get_secret(database_url)
print_verbose(f"RETRIEVED DB URL: {database_url}") print_verbose(f"RETRIEVED DB URL: {database_url}")
prisma_setup(database_url=database_url) await prisma_setup(database_url=database_url)
## COST TRACKING ## ## COST TRACKING ##
cost_tracking() cost_tracking()
### MASTER KEY ### ### MASTER KEY ###
@@ -930,7 +932,7 @@ def save_worker_config(**data):
os.environ["WORKER_CONFIG"] = json.dumps(data) os.environ["WORKER_CONFIG"] = json.dumps(data)
def initialize( async def initialize(
model=None, model=None,
alias=None, alias=None,
api_base=None, api_base=None,
@@ -948,13 +950,19 @@ def initialize(
use_queue=False, use_queue=False,
config=None, config=None,
): ):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
generate_feedback_box() generate_feedback_box()
user_model = model user_model = model
user_debug = debug user_debug = debug
if debug == True: # this needs to be first, so users can see Router init debugg if debug == True: # this needs to be first, so users can see Router init debugg
litellm.set_verbose = True litellm.set_verbose = True
dynamic_config = {"general": {}, user_model: {}} dynamic_config = {"general": {}, user_model: {}}
if config:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
if headers: # model-specific param if headers: # model-specific param
user_headers = headers user_headers = headers
dynamic_config[user_model]["headers"] = headers dynamic_config[user_model]["headers"] = headers
@@ -1095,28 +1103,11 @@ async def startup_event():
print_verbose(f"worker_config: {worker_config}") print_verbose(f"worker_config: {worker_config}")
# check if it's a valid file path # check if it's a valid file path
if os.path.isfile(worker_config): if os.path.isfile(worker_config):
if worker_config.get("config", None) is not None: await initialize(**worker_config)
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
else: else:
# if not, assume it's a json string # if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
if worker_config.get("config", None) is not None: await initialize(**worker_config)
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
if use_background_health_checks: if use_background_health_checks:
@@ -1124,10 +1115,6 @@ async def startup_event():
_run_background_health_check() _run_background_health_check()
) # start the background health check coroutine. ) # start the background health check coroutine.
print_verbose(f"prisma client - {prisma_client}")
if prisma_client is not None:
await prisma_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
await generate_key_helper_fn( await generate_key_helper_fn(
@@ -1331,7 +1318,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
background_tasks: BackgroundTasks = BackgroundTasks(), background_tasks: BackgroundTasks = BackgroundTasks(),
): ):
global general_settings, user_debug, proxy_logging_obj global general_settings, user_debug, proxy_logging_obj, llm_model_list
try: try:
data = {} data = {}
body = await request.body() body = await request.body()

View file

@@ -255,7 +255,6 @@ class PrismaClient:
) )
## init logging object ## init logging object
self.proxy_logging_obj = proxy_logging_obj self.proxy_logging_obj = proxy_logging_obj
self.connected = False
os.environ["DATABASE_URL"] = database_url os.environ["DATABASE_URL"] = database_url
# Save the current working directory # Save the current working directory
original_dir = os.getcwd() original_dir = os.getcwd()
@@ -536,11 +535,7 @@ class PrismaClient:
) )
async def connect(self): async def connect(self):
try: try:
if self.connected == False: await self.db.connect()
await self.db.connect()
self.connected = True
else:
return
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -558,7 +553,6 @@ class PrismaClient:
async def disconnect(self): async def disconnect(self):
try: try:
await self.db.disconnect() await self.db.disconnect()
self.connected = False
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)

View file

@@ -10,7 +10,7 @@ import os, io
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest, asyncio
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
@@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import (
router, router,
save_worker_config, save_worker_config,
initialize, initialize,
ProxyConfig,
) # Replace with the actual module where your FastAPI router is defined ) # Replace with the actual module where your FastAPI router is defined
@@ -36,7 +37,7 @@ def client():
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
app = FastAPI() app = FastAPI()
initialize(config=config_fp) asyncio.run(initialize(config=config_fp))
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app
return TestClient(app) return TestClient(app)

View file

@@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import (
router, router,
save_worker_config, save_worker_config,
initialize, initialize,
startup_event,
) # Replace with the actual module where your FastAPI router is defined ) # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
@@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
def client(): def client():
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml" config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
initialize(config=config_fp)
app = FastAPI() app = FastAPI()
asyncio.run(initialize(config=config_fp))
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app
return TestClient(app) return TestClient(app)

View file

@@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import (
def client(): def client():
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_bad_config.yaml" config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
initialize(config=config_fp) asyncio.run(initialize(config=config_fp))
app = FastAPI() app = FastAPI()
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app
return TestClient(app) return TestClient(app)

View file

@@ -10,7 +10,7 @@ import os, io
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest, logging import pytest, logging, asyncio
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
@@ -46,7 +46,7 @@ def client_no_auth():
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
initialize(config=config_fp, debug=True) asyncio.run(initialize(config=config_fp, debug=True))
app = FastAPI() app = FastAPI()
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app

View file

@@ -10,7 +10,7 @@ import os, io
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest, logging import pytest, logging, asyncio
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
@@ -45,7 +45,7 @@ def client_no_auth():
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
initialize(config=config_fp, debug=True) asyncio.run(initialize(config=config_fp, debug=True))
app = FastAPI() app = FastAPI()
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app

View file

@@ -12,7 +12,7 @@ import os, io
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest, logging import pytest, logging, asyncio
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
@@ -47,7 +47,7 @@ def client_no_auth():
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml" config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
initialize(config=config_fp, debug=True) asyncio.run(initialize(config=config_fp, debug=True))
app = FastAPI() app = FastAPI()
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app