class HTTPHandler:
    """Small async HTTP helper built on one pooled ``httpx.AsyncClient``.

    A single client is created per handler and reused for every request.
    Call :meth:`close` when finished, or use the handler as an async
    context manager (``async with HTTPHandler() as handler: ...``) so the
    client is closed even if the body raises.
    """

    def __init__(self, concurrent_limit=1000):
        """Create the underlying client with a bounded connection pool.

        Args:
            concurrent_limit: Passed to ``httpx.Limits`` as both
                ``max_connections`` and ``max_keepalive_connections``.
        """
        self.client = httpx.AsyncClient(
            limits=httpx.Limits(
                max_connections=concurrent_limit,
                max_keepalive_connections=concurrent_limit,
            )
        )

    async def __aenter__(self) -> "HTTPHandler":
        """Return the handler itself for use inside ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the client on exit; exceptions from the body still propagate."""
        await self.close()

    async def close(self):
        """Close the underlying client via ``aclose()``."""
        await self.client.aclose()

    async def get(
        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
    ) -> "httpx.Response":
        """Send a GET request through the shared client.

        Args:
            url: Target URL.
            params: Optional query parameters.
            headers: Optional request headers.

        Returns:
            Whatever ``self.client.get`` returns.

        Raises:
            Any exception raised by the client is propagated unchanged.
        """
        return await self.client.get(url, params=params, headers=headers)

    async def post(
        self,
        url: str,
        data: Optional[dict] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> "httpx.Response":
        """Send a POST request through the shared client.

        Args:
            url: Target URL.
            data: Optional request data, forwarded as the client's ``data``
                argument.
            params: Optional query parameters.
            headers: Optional request headers.

        Returns:
            Whatever ``self.client.post`` returns.

        Raises:
            Any exception raised by the client is propagated unchanged.
        """
        # No try/except here: catching only to ``raise e`` changed nothing
        # for callers and made ``post`` inconsistent with ``get``.
        return await self.client.post(
            url, data=data, params=params, headers=headers
        )
general_settings: master_key: sk-1234 + disable_spend_logs: true database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require" \ No newline at end of file diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 08ffc0955b..4689ffe7bf 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -6,7 +6,6 @@ Currently only supports admin. JWT token must have 'litellm_proxy_admin' in scope. """ -import httpx import jwt import json import os @@ -14,42 +13,10 @@ from litellm.caching import DualCache from litellm._logging import verbose_proxy_logger from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable from litellm.proxy.utils import PrismaClient +from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from typing import Optional -class HTTPHandler: - def __init__(self, concurrent_limit=1000): - # Create a client with a connection pool - self.client = httpx.AsyncClient( - limits=httpx.Limits( - max_connections=concurrent_limit, - max_keepalive_connections=concurrent_limit, - ) - ) - - async def close(self): - # Close the client when you're done with it - await self.client.aclose() - - async def get( - self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None - ): - response = await self.client.get(url, params=params, headers=headers) - return response - - async def post( - self, - url: str, - data: Optional[dict] = None, - params: Optional[dict] = None, - headers: Optional[dict] = None, - ): - response = await self.client.post( - url, data=data, params=params, headers=headers - ) - return response - - class JWTHandler: """ - treat the sub id passed in as the user id diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index b1d7b8026c..b8d7926963 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -21,8 +21,6 @@ telemetry = None def append_query_params(url, params): from 
litellm._logging import verbose_proxy_logger - from litellm._logging import verbose_proxy_logger - verbose_proxy_logger.debug(f"url: {url}") verbose_proxy_logger.debug(f"params: {params}") parsed_url = urlparse.urlparse(url) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8fa2862f27..f2918a04da 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -97,7 +97,6 @@ from litellm.proxy.utils import ( _is_projected_spend_over_limit, _get_projected_spend_over_limit, update_spend, - monitor_spend_list, ) from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager @@ -118,6 +117,7 @@ from litellm.proxy.auth.auth_checks import ( allowed_routes_check, get_actual_routes, ) +from litellm.llms.custom_httpx.httpx_handler import HTTPHandler try: from litellm._version import version @@ -305,6 +305,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) async_result = None celery_app_conn = None celery_fn = None # Redis Queue for handling requests +### DB WRITER ### +db_writer_client: Optional[HTTPHandler] = None ### logger ### @@ -1363,7 +1365,15 @@ async def update_database( payload["spend"] = response_cost if prisma_client is not None: - await prisma_client.insert_data(data=payload, table_name="spend") + prisma_client.spend_log_transactons.append(payload) + # if db_writer_client is not None: + # print("Tries to make call") + # response = await db_writer_client.post( + # url="http://0.0.0.0:3000/spend/update", + # data=json.dumps(payload), + # headers={"Content-Type": "application/json"}, + # ) + # print(f"response: {response}") except Exception as e: verbose_proxy_logger.debug( f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}" @@ -2693,7 +2703,7 @@ def on_backoff(details): @router.on_event("startup") async def startup_event(): - global prisma_client, master_key, 
use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client import json ### LOAD MASTER KEY ### @@ -2726,6 +2736,8 @@ async def startup_event(): ## COST TRACKING ## cost_tracking() + db_writer_client = HTTPHandler() + proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made ## JWT AUTH ## @@ -2836,7 +2848,7 @@ async def startup_event(): update_spend, "interval", seconds=batch_writing_interval, - args=[prisma_client], + args=[prisma_client, db_writer_client], ) scheduler.start() @@ -7985,6 +7997,8 @@ async def shutdown_event(): await jwt_handler.close() + if db_writer_client is not None: + await db_writer_client.close() ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ba8d708045..1d5f5f8194 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -13,6 +13,7 @@ from litellm.proxy._types import ( Member, ) from litellm.caching import DualCache +from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) @@ -1716,6 +1717,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time): # hash the api_key api_key = hash_token(api_key) + # jsonify datetime object + # if isinstance(start_time, datetime): + # start_time = start_time.isoformat() + # if isinstance(end_time, datetime): + # end_time = end_time.isoformat() # clean up litellm metadata if isinstance(metadata, dict): clean_metadata = {} @@ -1866,9 +1872,7 @@ async def 
reset_budget(prisma_client: PrismaClient): ) -async def update_spend( - prisma_client: PrismaClient, -): +async def update_spend(prisma_client: PrismaClient, db_writer_client: HTTPHandler): """ Batch write updates to db. @@ -1995,13 +1999,23 @@ async def update_spend( except Exception as e: raise e + ### UPDATE SPEND LOGS ### + # if len(prisma_client.spend_log_transactons) > 0: + # response = await db_writer_client.post( + # url="http://0.0.0.0:3000/spend/update", + # data=prisma_client.spend_log_transactons, + # headers={"Content-Type": "application/json"}, + # ) + # if response.status_code == 200: + # prisma_client.spend_log_transactons = [] -async def monitor_spend_list(prisma_client: PrismaClient): - """ - Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db - """ - if len(prisma_client.user_list_transactons) > 10000: - await update_spend(prisma_client=prisma_client) + +# async def monitor_spend_list(prisma_client: PrismaClient): +# """ +# Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db +# """ +# if len(prisma_client.user_list_transactons) > 10000: +# await update_spend(prisma_client=prisma_client) async def _read_request_body(request):