forked from phoenix/litellm-mirror
fix(proxy_server.py): fix prisma client connection error
This commit is contained in:
parent
6b708347f3
commit
74f6f6489a
8 changed files with 29 additions and 46 deletions
|
@ -307,9 +307,8 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
async def prisma_setup(database_url: Optional[str]):
|
||||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||||
|
|
||||||
if (
|
if (
|
||||||
database_url is not None and prisma_client is None
|
database_url is not None and prisma_client is None
|
||||||
): # don't re-initialize prisma client after initial init
|
): # don't re-initialize prisma client after initial init
|
||||||
|
@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
|
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
|
||||||
)
|
)
|
||||||
|
if prisma_client is not None and prisma_client.db.is_connected() == False:
|
||||||
|
await prisma_client.connect()
|
||||||
|
|
||||||
|
|
||||||
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||||
|
@ -534,6 +535,7 @@ class ProxyConfig:
|
||||||
prisma_client is not None
|
prisma_client is not None
|
||||||
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
|
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
|
||||||
):
|
):
|
||||||
|
await prisma_setup(database_url=None) # in case it's not been connected yet
|
||||||
_tasks = []
|
_tasks = []
|
||||||
keys = [
|
keys = [
|
||||||
"model_list",
|
"model_list",
|
||||||
|
@ -761,7 +763,7 @@ class ProxyConfig:
|
||||||
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
|
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
|
||||||
database_url = litellm.get_secret(database_url)
|
database_url = litellm.get_secret(database_url)
|
||||||
print_verbose(f"RETRIEVED DB URL: {database_url}")
|
print_verbose(f"RETRIEVED DB URL: {database_url}")
|
||||||
prisma_setup(database_url=database_url)
|
await prisma_setup(database_url=database_url)
|
||||||
## COST TRACKING ##
|
## COST TRACKING ##
|
||||||
cost_tracking()
|
cost_tracking()
|
||||||
### MASTER KEY ###
|
### MASTER KEY ###
|
||||||
|
@ -930,7 +932,7 @@ def save_worker_config(**data):
|
||||||
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
||||||
|
|
||||||
|
|
||||||
def initialize(
|
async def initialize(
|
||||||
model=None,
|
model=None,
|
||||||
alias=None,
|
alias=None,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
|
@ -948,13 +950,19 @@ def initialize(
|
||||||
use_queue=False,
|
use_queue=False,
|
||||||
config=None,
|
config=None,
|
||||||
):
|
):
|
||||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
|
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
|
||||||
generate_feedback_box()
|
generate_feedback_box()
|
||||||
user_model = model
|
user_model = model
|
||||||
user_debug = debug
|
user_debug = debug
|
||||||
if debug == True: # this needs to be first, so users can see Router init debugg
|
if debug == True: # this needs to be first, so users can see Router init debugg
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
dynamic_config = {"general": {}, user_model: {}}
|
dynamic_config = {"general": {}, user_model: {}}
|
||||||
|
if config:
|
||||||
|
(
|
||||||
|
llm_router,
|
||||||
|
llm_model_list,
|
||||||
|
general_settings,
|
||||||
|
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
|
||||||
if headers: # model-specific param
|
if headers: # model-specific param
|
||||||
user_headers = headers
|
user_headers = headers
|
||||||
dynamic_config[user_model]["headers"] = headers
|
dynamic_config[user_model]["headers"] = headers
|
||||||
|
@ -1095,28 +1103,11 @@ async def startup_event():
|
||||||
print_verbose(f"worker_config: {worker_config}")
|
print_verbose(f"worker_config: {worker_config}")
|
||||||
# check if it's a valid file path
|
# check if it's a valid file path
|
||||||
if os.path.isfile(worker_config):
|
if os.path.isfile(worker_config):
|
||||||
if worker_config.get("config", None) is not None:
|
await initialize(**worker_config)
|
||||||
(
|
|
||||||
llm_router,
|
|
||||||
llm_model_list,
|
|
||||||
general_settings,
|
|
||||||
) = await proxy_config.load_config(
|
|
||||||
router=llm_router, config_file_path=worker_config.pop("config")
|
|
||||||
)
|
|
||||||
initialize(**worker_config)
|
|
||||||
else:
|
else:
|
||||||
# if not, assume it's a json string
|
# if not, assume it's a json string
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
if worker_config.get("config", None) is not None:
|
await initialize(**worker_config)
|
||||||
(
|
|
||||||
llm_router,
|
|
||||||
llm_model_list,
|
|
||||||
general_settings,
|
|
||||||
) = await proxy_config.load_config(
|
|
||||||
router=llm_router, config_file_path=worker_config.pop("config")
|
|
||||||
)
|
|
||||||
initialize(**worker_config)
|
|
||||||
|
|
||||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
|
@ -1124,10 +1115,6 @@ async def startup_event():
|
||||||
_run_background_health_check()
|
_run_background_health_check()
|
||||||
) # start the background health check coroutine.
|
) # start the background health check coroutine.
|
||||||
|
|
||||||
print_verbose(f"prisma client - {prisma_client}")
|
|
||||||
if prisma_client is not None:
|
|
||||||
await prisma_client.connect()
|
|
||||||
|
|
||||||
if prisma_client is not None and master_key is not None:
|
if prisma_client is not None and master_key is not None:
|
||||||
# add master key to db
|
# add master key to db
|
||||||
await generate_key_helper_fn(
|
await generate_key_helper_fn(
|
||||||
|
@ -1331,7 +1318,7 @@ async def chat_completion(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||||
):
|
):
|
||||||
global general_settings, user_debug, proxy_logging_obj
|
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||||
try:
|
try:
|
||||||
data = {}
|
data = {}
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
|
@ -255,7 +255,6 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
## init logging object
|
## init logging object
|
||||||
self.proxy_logging_obj = proxy_logging_obj
|
self.proxy_logging_obj = proxy_logging_obj
|
||||||
self.connected = False
|
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
# Save the current working directory
|
# Save the current working directory
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
|
@ -536,11 +535,7 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
try:
|
try:
|
||||||
if self.connected == False:
|
|
||||||
await self.db.connect()
|
await self.db.connect()
|
||||||
self.connected = True
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
|
@ -558,7 +553,6 @@ class PrismaClient:
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
try:
|
try:
|
||||||
await self.db.disconnect()
|
await self.db.disconnect()
|
||||||
self.connected = False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import (
|
||||||
router,
|
router,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
initialize,
|
initialize,
|
||||||
|
ProxyConfig,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ def client():
|
||||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
initialize(config=config_fp)
|
asyncio.run(initialize(config=config_fp))
|
||||||
|
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import (
|
||||||
router,
|
router,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
initialize,
|
initialize,
|
||||||
|
startup_event,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
) # Replace with the actual module where your FastAPI router is defined
|
||||||
|
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
||||||
def client():
|
def client():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
||||||
initialize(config=config_fp)
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
asyncio.run(initialize(config=config_fp))
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import (
|
||||||
def client():
|
def client():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
|
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
|
||||||
initialize(config=config_fp)
|
asyncio.run(initialize(config=config_fp))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging
|
import pytest, logging, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -46,7 +46,7 @@ def client_no_auth():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
initialize(config=config_fp, debug=True)
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging
|
import pytest, logging, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -45,7 +45,7 @@ def client_no_auth():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
initialize(config=config_fp, debug=True)
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import os, io
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging
|
import pytest, logging, asyncio
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
@ -47,7 +47,7 @@ def client_no_auth():
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
|
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
|
||||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||||
initialize(config=config_fp, debug=True)
|
asyncio.run(initialize(config=config_fp, debug=True))
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue