forked from phoenix/litellm-mirror

feat(proxy/utils.py): enable updating db in a separate server

parent 9b7383ac67
commit 1e856443e1

6 changed files with 89 additions and 57 deletions
litellm/llms/custom_httpx/httpx_handler.py (new file, 38 additions)

@@ -0,0 +1,38 @@
+from typing import Optional
+
+import httpx
+
+
+class HTTPHandler:
+    def __init__(self, concurrent_limit=1000):
+        # Create a client with a connection pool
+        self.client = httpx.AsyncClient(
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            )
+        )
+
+    async def close(self):
+        # Close the client when you're done with it
+        await self.client.aclose()
+
+    async def get(
+        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
+    ):
+        response = await self.client.get(url, params=params, headers=headers)
+        return response
+
+    async def post(
+        self,
+        url: str,
+        data: Optional[dict] = None,
+        params: Optional[dict] = None,
+        headers: Optional[dict] = None,
+    ):
+        try:
+            response = await self.client.post(
+                url, data=data, params=params, headers=headers
+            )
+            return response
+        except Exception as e:
+            raise e
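A minimal usage sketch of the new handler (not part of the diff): the endpoint mirrors the writer URL that appears commented out later in this commit, and the payload is illustrative.

import asyncio

from litellm.llms.custom_httpx.httpx_handler import HTTPHandler


async def main():
    handler = HTTPHandler(concurrent_limit=100)
    try:
        # get() and post() both return the raw httpx.Response
        resp = await handler.post(
            url="http://0.0.0.0:3000/spend/update",  # writer endpoint from the commented-out code below
            data='[{"spend": 0.1}]',  # illustrative payload
            headers={"Content-Type": "application/json"},
        )
        print(resp.status_code)
    finally:
        # release the pooled connections
        await handler.close()


asyncio.run(main())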
@@ -1,21 +1,22 @@
 model_list:
-  - model_name: fake_openai
+  - model_name: fake-openai-endpoint
     litellm_params:
       model: openai/my-fake-model
       api_key: my-fake-key
-      api_base: http://0.0.0.0:8080
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
   - model_name: gpt-3.5-turbo
     litellm_params:
       model: gpt-3.5-turbo-1106
       api_key: os.environ/OPENAI_API_KEY
 
-litellm_settings:
-  cache: true
-  cache_params:
-    type: redis
-  callbacks: ["batch_redis_requests"]
-  # success_callbacks: ["langfuse"]
+# litellm_settings:
+#   cache: true
+#   cache_params:
+#     type: redis
+#   callbacks: ["batch_redis_requests"]
+#   # success_callbacks: ["langfuse"]
 
 general_settings:
   master_key: sk-1234
+  disable_spend_logs: true
   database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"
@@ -6,7 +6,6 @@ Currently only supports admin.
 JWT token must have 'litellm_proxy_admin' in scope.
 """
 
-import httpx
 import jwt
 import json
 import os

@@ -14,42 +13,10 @@ from litellm.caching import DualCache
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
 from litellm.proxy.utils import PrismaClient
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from typing import Optional
 
 
-class HTTPHandler:
-    def __init__(self, concurrent_limit=1000):
-        # Create a client with a connection pool
-        self.client = httpx.AsyncClient(
-            limits=httpx.Limits(
-                max_connections=concurrent_limit,
-                max_keepalive_connections=concurrent_limit,
-            )
-        )
-
-    async def close(self):
-        # Close the client when you're done with it
-        await self.client.aclose()
-
-    async def get(
-        self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
-    ):
-        response = await self.client.get(url, params=params, headers=headers)
-        return response
-
-    async def post(
-        self,
-        url: str,
-        data: Optional[dict] = None,
-        params: Optional[dict] = None,
-        headers: Optional[dict] = None,
-    ):
-        response = await self.client.post(
-            url, data=data, params=params, headers=headers
-        )
-        return response
-
-
 class JWTHandler:
     """
     - treat the sub id passed in as the user id
@@ -21,8 +21,6 @@ telemetry = None
 def append_query_params(url, params):
     from litellm._logging import verbose_proxy_logger
 
-    from litellm._logging import verbose_proxy_logger
-
     verbose_proxy_logger.debug(f"url: {url}")
     verbose_proxy_logger.debug(f"params: {params}")
     parsed_url = urlparse.urlparse(url)
@@ -97,7 +97,6 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
-    monitor_spend_list,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager

@@ -118,6 +117,7 @@ from litellm.proxy.auth.auth_checks import (
     allowed_routes_check,
     get_actual_routes,
 )
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 
 try:
     from litellm._version import version
@@ -305,6 +305,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
+### DB WRITER ###
+db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -1363,7 +1365,15 @@ async def update_database(
 
             payload["spend"] = response_cost
             if prisma_client is not None:
-                await prisma_client.insert_data(data=payload, table_name="spend")
+                prisma_client.spend_log_transactons.append(payload)
+            # if db_writer_client is not None:
+            #     print("Tries to make call")
+            #     response = await db_writer_client.post(
+            #         url="http://0.0.0.0:3000/spend/update",
+            #         data=json.dumps(payload),
+            #         headers={"Content-Type": "application/json"},
+            #     )
+            #     print(f"response: {response}")
         except Exception as e:
             verbose_proxy_logger.debug(
                 f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@@ -2693,7 +2703,7 @@ def on_backoff(details):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client
     import json
 
     ### LOAD MASTER KEY ###

@@ -2726,6 +2736,8 @@ async def startup_event():
     ## COST TRACKING ##
     cost_tracking()
 
+    db_writer_client = HTTPHandler()
+
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     ## JWT AUTH ##
@@ -2836,7 +2848,7 @@ async def startup_event():
             update_spend,
             "interval",
             seconds=batch_writing_interval,
-            args=[prisma_client],
+            args=[prisma_client, db_writer_client],
         )
         scheduler.start()
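The hunk does not show which scheduler this is; assuming APScheduler's AsyncIOScheduler (a common choice in async FastAPI apps), the registration above amounts to roughly the following, with update_spend, prisma_client, db_writer_client, and batch_writing_interval coming from the surrounding module:

from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()
scheduler.add_job(
    update_spend,                    # batch writer from litellm/proxy/utils.py
    "interval",
    seconds=batch_writing_interval,  # flush cadence
    args=[prisma_client, db_writer_client],  # now receives both clients
)
scheduler.start()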
@@ -7985,6 +7997,8 @@ async def shutdown_event():
 
     await jwt_handler.close()
 
+    if db_writer_client is not None:
+        await db_writer_client.close()
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()
@@ -13,6 +13,7 @@ from litellm.proxy._types import (
     Member,
 )
 from litellm.caching import DualCache
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
@@ -1716,6 +1717,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
     # hash the api_key
     api_key = hash_token(api_key)
 
+    # jsonify datetime object
+    # if isinstance(start_time, datetime):
+    #     start_time = start_time.isoformat()
+    # if isinstance(end_time, datetime):
+    #     end_time = end_time.isoformat()
     # clean up litellm metadata
     if isinstance(metadata, dict):
         clean_metadata = {}
@@ -1866,9 +1872,7 @@ async def reset_budget(prisma_client: PrismaClient):
     )
 
 
-async def update_spend(
-    prisma_client: PrismaClient,
-):
+async def update_spend(prisma_client: PrismaClient, db_writer_client: HTTPHandler):
     """
     Batch write updates to db.
@@ -1995,13 +1999,23 @@ async def update_spend(
         except Exception as e:
             raise e
 
+    ### UPDATE SPEND LOGS ###
+    # if len(prisma_client.spend_log_transactons) > 0:
+    #     response = await db_writer_client.post(
+    #         url="http://0.0.0.0:3000/spend/update",
+    #         data=prisma_client.spend_log_transactons,
+    #         headers={"Content-Type": "application/json"},
+    #     )
+    #     if response.status_code == 200:
+    #         prisma_client.spend_log_transactons = []
+
-async def monitor_spend_list(prisma_client: PrismaClient):
-    """
-    Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
-    """
-    if len(prisma_client.user_list_transactons) > 10000:
-        await update_spend(prisma_client=prisma_client)
+# async def monitor_spend_list(prisma_client: PrismaClient):
+#     """
+#     Check the length of each spend list, if it exceeds a threshold (e.g. 100 items) - write to db
+#     """
+#     if len(prisma_client.user_list_transactons) > 10000:
+#         await update_spend(prisma_client=prisma_client)
 
 
 async def _read_request_body(request):
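A runnable rendering of the commented-out flush above, under the same assumptions (the separate writer service accepts a JSON array at /spend/update); json.dumps is added here because the commented code passes the raw list, which httpx will not serialize:

import json


async def flush_spend_logs(prisma_client, db_writer_client):
    # hand the buffered spend logs to the external writer; clear the
    # buffer only after the writer acknowledges the batch
    if len(prisma_client.spend_log_transactons) > 0:
        response = await db_writer_client.post(
            url="http://0.0.0.0:3000/spend/update",
            data=json.dumps(prisma_client.spend_log_transactons),
            headers={"Content-Type": "application/json"},
        )
        if response.status_code == 200:
            prisma_client.spend_log_transactons = []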