feat(utils.py): add async success callbacks for custom functions

This commit is contained in:
Krrish Dholakia 2023-12-04 16:36:21 -08:00
parent b90fcbdac4
commit e0ccb281d8
8 changed files with 232 additions and 138 deletions

View file

@ -272,10 +272,16 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
global master_key, prisma_client, llm_model_list
print(f"master_key - {master_key}; api_key - {api_key}")
if master_key is None:
return {
"api_key": None
}
if isinstance(api_key, str):
return {
"api_key": api_key.replace("Bearer ", "")
}
else:
return {
"api_key": api_key
}
try:
if api_key is None:
raise Exception("No api key passed in.")
@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")
def cost_tracking():
global prisma_client, master_key
if prisma_client is not None and master_key is not None:
global prisma_client
if prisma_client is not None:
if isinstance(litellm.success_callback, list):
print("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
@ -391,7 +397,7 @@ def cost_tracking():
else:
litellm.success_callback = track_cost_callback # type: ignore
def track_cost_callback(
async def track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time = None,
@ -420,31 +426,13 @@ def track_cost_callback(
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
print("regular response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
# asyncio.run(update_prisma_database(user_api_key, response_cost))
# Create new event loop for async function execution in the new thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
# Run the async function using the newly created event loop
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
except Exception as e:
print(f"error in creating async loop - {str(e)}")
await update_prisma_database(token=user_api_key, response_cost=response_cost)
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost):
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e:
print(f"Error updating Prisma database: {traceback.format_exc()}")
pass
@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List):
global prisma_client