diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d4ad0834d..252ccafa7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -118,6 +118,47 @@ def data_generator(response):
         print_verbose(f"returned chunk: {chunk}")
         yield f"data: {json.dumps(chunk)}\n\n"
 
+def custom_callback(
+    kwargs,                 # kwargs to completion
+    completion_response,    # response from completion
+    start_time, end_time    # start/end time
+):
+    """Log the cost of a successful completion to cost.log.
+
+    Registered below as a litellm success callback. The cost is computed
+    from kwargs["complete_streaming_response"] when that key is present,
+    otherwise from completion_response.
+
+    Cost tracking is best-effort: any failure is reported through
+    print_verbose and is never propagated to the caller.
+    """
+    try:
+        print_verbose("LITELLM: in custom callback function")
+        if "complete_streaming_response" in kwargs:
+            # Prefer the complete streaming response when it is supplied.
+            cost_source = kwargs["complete_streaming_response"]
+            print_verbose(f"GOT COMPLETE STREAMING RESPONSE {cost_source}")
+        else:
+            cost_source = completion_response
+            print_verbose(f"completion_response {completion_response}")
+        response_cost = litellm.completion_cost(completion_response=cost_source)
+        print_verbose(f"response_cost {response_cost}")
+
+        # basicConfig is a no-op once the root logger has handlers,
+        # so calling it on every callback is safe.
+        logging.basicConfig(
+            filename='cost.log',
+            level=logging.INFO,
+            format='%(asctime)s - %(message)s',
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
+        logging.info(f"Model {completion_response.model} Cost: ${response_cost:.8f}")
+    except Exception as e:
+        # Cost tracking must never break request handling.
+        print_verbose(f"cost tracking failed: {e}")
+
+
+litellm.success_callback = [custom_callback]
+
 def litellm_completion(data, type):
     try:
         if user_model:
@@ -203,22 +244,8 @@ async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
     response = litellm_completion(data, type="chat_completion")
-    # track cost of this response, using litellm.completion_cost
-    track_cost(response)
     return response
 
-async def track_cost(response):
-    try:
-        logging.basicConfig(
-            filename='cost.log',
-            level=logging.INFO,
-            format='%(asctime)s - %(message)s',
-            datefmt='%Y-%m-%d %H:%M:%S'
-        )
-        response_cost = litellm.completion_cost(completion_response=response)
-        logging.info(f"Model {response.model} Cost: ${response_cost:.8f}")
-    except:
-        pass
 
 def print_cost_logs():
     with open('cost.log', 'r') as f: