diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index bb3011b3f3..033b657be3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -149,14 +149,18 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
     if master_key is None:
-        return
+        return {
+            "api_key": None
+        }
     try:
         route = request.url.path

         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
         if is_master_key_valid:
-            return
+            return {
+                "api_key": master_key
+            }

         if (route == "/key/generate" or route == "/key/delete") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
@@ -186,7 +190,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                 llm_model_list = model_list
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                return
+                return {
+                    "api_key": valid_token.token
+                }
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -194,7 +200,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                     model = litellm.model_alias_map[model]
                 if model and model not in valid_token.models:
                     raise Exception(f"Token not allowed to access model")
-            return
+            return {
+                "api_key": valid_token.token
+            }
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
@@ -231,6 +239,83 @@ def celery_setup(use_queue: bool):
         async_result = AsyncResult
         celery_app_conn = celery_app

+def cost_tracking():
+    global prisma_client, master_key
+    if prisma_client is not None and master_key is not None:
+        if isinstance(litellm.success_callback, list):
+            print("setting litellm success callback to track cost")
+            if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(track_cost_callback) # type: ignore
+        else:
+            litellm.success_callback = track_cost_callback # type: ignore
+
+def track_cost_callback(
+    kwargs,                                        # kwargs to completion
+    completion_response: litellm.ModelResponse,    # response from completion
+    start_time = None,
+    end_time = None,                               # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in custom callback tracking cost", llm_model_list)
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response = kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion = output_text
+            )
+            print("streaming response_cost", response_cost)
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+                print("regular response_cost", response_cost)
+        print(f"metadata in kwargs: {kwargs}")
+        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        if user_api_key:
+            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+    except Exception as e:
+        print(f"error in tracking cost callback - {str(e)}")
+
+async def update_prisma_database(token, response_cost):
+    global prisma_client
+    try:
+        print(f"Enters prisma db call, token: {token}")
+        # Fetch the existing cost for the given token
+        existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
+            where={
+                "token": token
+            }
+        )
+        print(f"existing spend: {existing_spend}")
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_spend = existing_spend.spend + response_cost
+
+        print(f"new cost: {new_spend}")
+        # Update the cost column for the given token
+        await prisma_client.litellm_verificationtoken.update(
+            where={
+                "token": token
+            },
+            data={
+                "spend": new_spend
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {traceback.format_exc()}")
+        pass
+
 def run_ollama_serve():
     try:
         command = ['ollama', 'serve']
@@ -272,15 +357,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     ### CONNECT TO DATABASE ###
     database_url = general_settings.get("database_url", None)
     prisma_setup(database_url=database_url)
-    ## Cost Tracking for master key + auth setup ##
-    if master_key is not None:
-        if isinstance(litellm.success_callback, list):
-            import utils
-            print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                litellm.success_callback.append(utils.track_cost_callback) # type: ignore
-        else:
-            litellm.success_callback = utils.track_cost_callback # type: ignore
+    ## COST TRACKING ##
+    cost_tracking()
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
@@ -386,12 +464,10 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens

-
 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
     await task

-
 def save_worker_config(**data):
     import json
     os.environ["WORKER_CONFIG"] = json.dumps(data)
@@ -487,7 +563,6 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

-
 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
@@ -572,7 +647,7 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         body = await request.body()
         body_str = body.decode()
@@ -589,6 +664,7 @@ async def completion(request: Request, model: Optional[str] = None):
         if user_model:
             data["model"] = user_model
         data["call_type"] = "text_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -609,7 +685,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None):
+async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     global general_settings
     try:
         body = await request.body()
@@ -626,6 +702,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -644,7 +721,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
+async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         data = await request.json()
         print(f"data: {data}")
@@ -655,7 +732,7 @@
         )
         if user_model:
             data["model"] = user_model
-
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index c61bcfa85d..946c23e7cc 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -32,7 +32,40 @@ def track_cost_callback(
         else:
             # we pass the completion_response obj
             if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
                 print("regular response_cost", response_cost)
     except:
-        pass
\ No newline at end of file
+        pass
+
+def update_prisma_database(token, response_cost):
+    try:
+        # Import your Prisma client
+        from your_prisma_module import prisma
+
+        # Fetch the existing cost for the given token
+        existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
+            where={
+                "token": token
+            }
+        ).cost
+
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_cost = existing_cost + response_cost
+
+        # Update the cost column for the given token
+        prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
+            where={
+                "token": token
+            },
+            data={
+                "cost": new_cost
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_cost}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {e}")
+        pass
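
Reviewer sketch: how the pieces above fit together end to end. The snippet below is illustrative only, not proxy code; the key value and request payload are made up, and it uses only the structures introduced in this diff.

# user_api_key_auth now returns a dict rather than None, which FastAPI
# hands to each route handler via Depends(user_api_key_auth).
user_api_key_dict = {"api_key": "sk-1234"}  # hypothetical token value

# The patched completion/chat_completion/embeddings handlers thread the
# key into the request metadata:
data = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
}
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}

# track_cost_callback later reads the key back out of litellm_params and
# charges the computed cost to that token via update_prisma_database:
litellm_params = {"metadata": data["metadata"]}
user_api_key = litellm_params["metadata"].get("user_api_key", None)
assert user_api_key == "sk-1234"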
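
For the cost calculation itself, the diff prices streaming responses from the prompt messages plus the reassembled output text, while non-streaming responses reuse the full completion_response object. A standalone sketch of the streaming path, assuming litellm.completion_cost accepts the keyword arguments used in the diff (model, messages, completion):

import litellm

messages = [{"role": "user", "content": "hello"}]
# price the prompt plus the completion text joined from streamed chunks
response_cost = litellm.completion_cost(
    model="gpt-3.5-turbo",             # illustrative model name
    messages=messages,
    completion="Hi! How can I help?",  # stand-in for the reassembled stream
)
print("streaming response_cost", response_cost)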