feat(proxy/utils.py): enable updating db in a separate server

Krrish Dholakia 2024-03-27 16:02:36 -07:00
parent 9b7383ac67
commit 1e856443e1
6 changed files with 89 additions and 57 deletions


@@ -0,0 +1,38 @@
+from typing import Optional
+import httpx
+
+
+class HTTPHandler:
+    def __init__(self, concurrent_limit=1000):
+        # Create a client with a connection pool
+        self.client = httpx.AsyncClient(
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            )
+        )
+
+    async def close(self):
+        # Close the client when you're done with it
+        await self.client.aclose()
+
+    async def get(
+        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
+    ):
+        response = await self.client.get(url, params=params, headers=headers)
+        return response
+
+    async def post(
+        self,
+        url: str,
+        data: Optional[dict] = None,
+        params: Optional[dict] = None,
+        headers: Optional[dict] = None,
+    ):
+        try:
+            response = await self.client.post(
+                url, data=data, params=params, headers=headers
+            )
+            return response
+        except Exception as e:
+            raise e
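A minimal usage sketch for the new handler (the target URL and query are placeholders, and the import path is the one used by the other files in this commit):

import asyncio

from litellm.llms.custom_httpx.httpx_handler import HTTPHandler


async def main():
    handler = HTTPHandler(concurrent_limit=100)
    try:
        # GET through the pooled client; any httpx errors propagate to the caller
        response = await handler.get("https://example.com", params={"q": "ping"})
        print(response.status_code)
    finally:
        # Always release the pooled connections
        await handler.close()


asyncio.run(main())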


@@ -1,21 +1,22 @@
 model_list:
-  - model_name: fake_openai
+  - model_name: fake-openai-endpoint
     litellm_params:
       model: openai/my-fake-model
       api_key: my-fake-key
-      api_base: http://0.0.0.0:8080
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
   - model_name: gpt-3.5-turbo
     litellm_params:
       model: gpt-3.5-turbo-1106
      api_key: os.environ/OPENAI_API_KEY
 
-litellm_settings:
-  cache: true
-  cache_params:
-    type: redis
-  callbacks: ["batch_redis_requests"]
-  # success_callbacks: ["langfuse"]
+# litellm_settings:
+#   cache: true
+#   cache_params:
+#     type: redis
+#   callbacks: ["batch_redis_requests"]
+#   # success_callbacks: ["langfuse"]
 
 general_settings:
   master_key: sk-1234
+  disable_spend_logs: true
   database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"


@@ -6,7 +6,6 @@ Currently only supports admin.
 JWT token must have 'litellm_proxy_admin' in scope.
 """
 
-import httpx
 import jwt
 import json
 import os
@@ -14,42 +13,10 @@ from litellm.caching import DualCache
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
 from litellm.proxy.utils import PrismaClient
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from typing import Optional
 
-
-class HTTPHandler:
-    def __init__(self, concurrent_limit=1000):
-        # Create a client with a connection pool
-        self.client = httpx.AsyncClient(
-            limits=httpx.Limits(
-                max_connections=concurrent_limit,
-                max_keepalive_connections=concurrent_limit,
-            )
-        )
-
-    async def close(self):
-        # Close the client when you're done with it
-        await self.client.aclose()
-
-    async def get(
-        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
-    ):
-        response = await self.client.get(url, params=params, headers=headers)
-        return response
-
-    async def post(
-        self,
-        url: str,
-        data: Optional[dict] = None,
-        params: Optional[dict] = None,
-        headers: Optional[dict] = None,
-    ):
-        response = await self.client.post(
-            url, data=data, params=params, headers=headers
-        )
-        return response
-
-
 class JWTHandler:
     """
     - treat the sub id passed in as the user id


@@ -21,8 +21,6 @@ telemetry = None
 
 def append_query_params(url, params):
     from litellm._logging import verbose_proxy_logger
-    from litellm._logging import verbose_proxy_logger
-
     verbose_proxy_logger.debug(f"url: {url}")
     verbose_proxy_logger.debug(f"params: {params}")
     parsed_url = urlparse.urlparse(url)


@@ -97,7 +97,6 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
-    monitor_spend_list,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@@ -118,6 +117,7 @@ from litellm.proxy.auth.auth_checks import (
     allowed_routes_check,
     get_actual_routes,
 )
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 
 try:
     from litellm._version import version
@@ -305,6 +305,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
+### DB WRITER ###
+db_writer_client: Optional[HTTPHandler] = None
 
 ### logger ###
@@ -1363,7 +1365,15 @@ async def update_database(
             payload["spend"] = response_cost
             if prisma_client is not None:
-                await prisma_client.insert_data(data=payload, table_name="spend")
+                prisma_client.spend_log_transactons.append(payload)
+                # if db_writer_client is not None:
+                #     print("Tries to make call")
+                #     response = await db_writer_client.post(
+                #         url="http://0.0.0.0:3000/spend/update",
+                #         data=json.dumps(payload),
+                #         headers={"Content-Type": "application/json"},
+                #     )
+                #     print(f"response: {response}")
         except Exception as e:
             verbose_proxy_logger.debug(
                 f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@@ -2693,7 +2703,7 @@ def on_backoff(details):
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client
     import json
 
     ### LOAD MASTER KEY ###
@@ -2726,6 +2736,8 @@ async def startup_event():
     ## COST TRACKING ##
     cost_tracking()
 
+    db_writer_client = HTTPHandler()
+
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     ## JWT AUTH ##
@@ -2836,7 +2848,7 @@ async def startup_event():
             update_spend,
             "interval",
             seconds=batch_writing_interval,
-            args=[prisma_client],
+            args=[prisma_client, db_writer_client],
         )
 
        scheduler.start()
@@ -7985,6 +7997,8 @@ async def shutdown_event():
         await jwt_handler.close()
 
+    if db_writer_client is not None:
+        await db_writer_client.close()
+
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()
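The proxy-server hunks above only show the wiring: a module-level db_writer_client, its creation in startup_event, the extra argument passed to the batch-writing job, and cleanup in shutdown_event. Below is a minimal, self-contained sketch of the same batching pattern using APScheduler (which the proxy's scheduler appears to be); the buffer name, 10-second interval, and writer address are stand-ins, not the proxy's actual values:

import asyncio
import json

from apscheduler.schedulers.asyncio import AsyncIOScheduler

from litellm.llms.custom_httpx.httpx_handler import HTTPHandler

# Stand-in for the in-memory spend-log buffer (prisma_client.spend_log_transactons)
spend_log_transactons: list = []
db_writer_client = HTTPHandler()


async def flush_spend_logs():
    # Drain the buffer and ship it to the external writer in one request
    if len(spend_log_transactons) == 0:
        return
    try:
        response = await db_writer_client.post(
            url="http://0.0.0.0:3000/spend/update",  # assumed writer address
            data=json.dumps(spend_log_transactons),
            headers={"Content-Type": "application/json"},
        )
        if response.status_code == 200:
            spend_log_transactons.clear()
    except Exception as e:
        # Keep the logs buffered and retry on the next interval
        print(f"writer unreachable, keeping {len(spend_log_transactons)} log(s): {e}")


async def main():
    scheduler = AsyncIOScheduler()
    scheduler.add_job(flush_spend_logs, "interval", seconds=10)
    scheduler.start()
    spend_log_transactons.append({"request_id": "abc-123", "spend": 0.002})
    await asyncio.sleep(25)  # let the interval job fire a couple of times
    await db_writer_client.close()


asyncio.run(main())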


@@ -13,6 +13,7 @@ from litellm.proxy._types import (
     Member,
 )
 from litellm.caching import DualCache
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
@@ -1716,6 +1717,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
         # hash the api_key
         api_key = hash_token(api_key)
 
+    # jsonify datetime object
+    # if isinstance(start_time, datetime):
+    #     start_time = start_time.isoformat()
+    # if isinstance(end_time, datetime):
+    #     end_time = end_time.isoformat()
+
     # clean up litellm metadata
     if isinstance(metadata, dict):
         clean_metadata = {}
@@ -1866,9 +1872,7 @@ async def reset_budget(prisma_client: PrismaClient):
         )
 
 
-async def update_spend(
-    prisma_client: PrismaClient,
-):
+async def update_spend(prisma_client: PrismaClient, db_writer_client: HTTPHandler):
     """
     Batch write updates to db.
@@ -1995,13 +1999,23 @@ async def update_spend(
         except Exception as e:
             raise e
 
+    ### UPDATE SPEND LOGS ###
+    # if len(prisma_client.spend_log_transactons) > 0:
+    #     response = await db_writer_client.post(
+    #         url="http://0.0.0.0:3000/spend/update",
+    #         data=prisma_client.spend_log_transactons,
+    #         headers={"Content-Type": "application/json"},
+    #     )
+    #     if response.status_code == 200:
+    #         prisma_client.spend_log_transactons = []
 
-async def monitor_spend_list(prisma_client: PrismaClient):
-    """
-    Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
-    """
-    if len(prisma_client.user_list_transactons) > 10000:
-        await update_spend(prisma_client=prisma_client)
+# async def monitor_spend_list(prisma_client: PrismaClient):
+#     """
+#     Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
+#     """
+#     if len(prisma_client.user_list_transactons) > 10000:
+#         await update_spend(prisma_client=prisma_client)
 
 
 async def _read_request_body(request):
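The commented-out POSTs above and in the proxy-server hunk target a separate DB-writer service that is not part of this commit. A hypothetical sketch of such a service, written with FastAPI; the route path and payload shape simply mirror the commented-out client calls:

# Hypothetical standalone DB-writer service (not part of this commit).
from fastapi import FastAPI, Request

app = FastAPI()


@app.post("/spend/update")
async def spend_update(request: Request):
    # Accept either a single spend-log payload or a batched list of them
    body = await request.json()
    payloads = body if isinstance(body, list) else [body]
    # A real implementation would bulk-insert `payloads` into the spend-logs
    # table here (e.g. via Prisma or asyncpg) instead of printing.
    print(f"received {len(payloads)} spend log(s)")
    return {"status": "ok"}

It could be run with, for example, `uvicorn spend_writer:app --port 3000`, where `spend_writer` is whatever module name the file above is saved under.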