forked from phoenix/litellm-mirror
feat(utils.py): add async success callbacks for custom functions
This commit is contained in:
parent
b90fcbdac4
commit
e0ccb281d8
8 changed files with 232 additions and 138 deletions
|
@ -272,10 +272,16 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
|||
|
||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
|
||||
global master_key, prisma_client, llm_model_list
|
||||
print(f"master_key - {master_key}; api_key - {api_key}")
|
||||
if master_key is None:
|
||||
return {
|
||||
"api_key": None
|
||||
}
|
||||
if isinstance(api_key, str):
|
||||
return {
|
||||
"api_key": api_key.replace("Bearer ", "")
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"api_key": api_key
|
||||
}
|
||||
try:
|
||||
if api_key is None:
|
||||
raise Exception("No api key passed in.")
|
||||
|
@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")
|
||||
|
||||
def cost_tracking():
|
||||
global prisma_client, master_key
|
||||
if prisma_client is not None and master_key is not None:
|
||||
global prisma_client
|
||||
if prisma_client is not None:
|
||||
if isinstance(litellm.success_callback, list):
|
||||
print("setting litellm success callback to track cost")
|
||||
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
||||
|
@ -391,7 +397,7 @@ def cost_tracking():
|
|||
else:
|
||||
litellm.success_callback = track_cost_callback # type: ignore
|
||||
|
||||
def track_cost_callback(
|
||||
async def track_cost_callback(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: litellm.ModelResponse, # response from completion
|
||||
start_time = None,
|
||||
|
@ -420,31 +426,13 @@ def track_cost_callback(
|
|||
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
|
||||
print("regular response_cost", response_cost)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
|
||||
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
|
||||
if user_api_key and prisma_client:
|
||||
# asyncio.run(update_prisma_database(user_api_key, response_cost))
|
||||
# Create new event loop for async function execution in the new thread
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
# Run the async function using the newly created event loop
|
||||
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
print(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
|
||||
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
|
||||
except Exception as e:
|
||||
print(f"error in creating async loop - {str(e)}")
|
||||
await update_prisma_database(token=user_api_key, response_cost=response_cost)
|
||||
except Exception as e:
|
||||
print(f"error in tracking cost callback - {str(e)}")
|
||||
|
||||
async def update_prisma_database(token, response_cost):
|
||||
|
||||
try:
|
||||
print(f"Enters prisma db call, token: {token}")
|
||||
# Fetch the existing cost for the given token
|
||||
|
@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost):
|
|||
print(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating Prisma database: {traceback.format_exc()}")
|
||||
pass
|
||||
|
@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
|
||||
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
|
||||
|
||||
async def delete_verification_token(tokens: List):
|
||||
global prisma_client
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue