diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 16aa4c6b4f..c6046d0595 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -366,6 +366,7 @@ async def track_cost_callback(
     global prisma_client
     try:
         # check if it has collected an entire stream response
+        print(f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}")
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response=kwargs["complete_streaming_response"]
@@ -377,16 +378,16 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
-        # for non streaming responses
-        elif kwargs["stream"] is False: # regular response
+        elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
+            print(f"type of input_text: {type(input_text)}")
             if isinstance(input_text, list):
                 response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
             elif isinstance(input_text, str):
                 response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
             print(f"received completion response: {completion_response}")
-            print("regular response_cost", response_cost)
+            print(f"regular response_cost: {response_cost}")
 
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
         if user_api_key and prisma_client:
diff --git a/litellm/tests/test_proxy_server_spend.py b/litellm/tests/test_proxy_server_spend.py
new file mode 100644
index 0000000000..5e025a334a
--- /dev/null
+++ b/litellm/tests/test_proxy_server_spend.py
@@ -0,0 +1,69 @@
+# import openai, json
+# client = openai.OpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )
+
+# super_fake_messages = [
+#   {
+#     "role": "user",
+#     "content": "What's the weather like in San Francisco, Tokyo, and Paris?"
+#   },
+#   {
+#     "content": None,
+#     "role": "assistant",
+#     "tool_calls": [
+#       {
+#         "id": "1",
+#         "function": {
+#           "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#           "name": "get_current_weather"
+#         },
+#         "type": "function"
+#       },
+#       {
+#         "id": "2",
+#         "function": {
+#           "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#           "name": "get_current_weather"
+#         },
+#         "type": "function"
+#       },
+#       {
+#         "id": "3",
+#         "function": {
+#           "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#           "name": "get_current_weather"
+#         },
+#         "type": "function"
+#       }
+#     ]
+#   },
+#   {
+#     "tool_call_id": "1",
+#     "role": "tool",
+#     "name": "get_current_weather",
+#     "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#   },
+#   {
+#     "tool_call_id": "2",
+#     "role": "tool",
+#     "name": "get_current_weather",
+#     "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#   },
+#   {
+#     "tool_call_id": "3",
+#     "role": "tool",
+#     "name": "get_current_weather",
+#     "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#   }
+# ]
+
+# super_fake_response = client.chat.completions.create(
+#     model="gpt-3.5-turbo",
+#     messages=super_fake_messages,
+#     seed=1337,
+#     stream=False
+# ) # get a new response from the model where it can see the function response
+
+# print(json.dumps(super_fake_response.model_dump(), indent=4))
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index e2feb6a12c..cc11e0fbc3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1858,8 +1858,6 @@ def completion_cost(
         - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
-        if messages != []:
-            prompt = " ".join([message["content"] for message in messages])
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
@@ -1869,7 +1867,10 @@ def completion_cost(
             completion_tokens = completion_response['usage']['completion_tokens']
             model = model or completion_response['model'] # check if user passed an override for model, if it's none check completion_response['model']
         else:
-            prompt_tokens = token_counter(model=model, text=prompt)
+            if len(messages) > 0:
+                prompt_tokens = token_counter(model=model, messages=messages)
+            elif len(prompt) > 0:
+                prompt_tokens = token_counter(model=model, text=prompt)
             completion_tokens = token_counter(model=model, text=completion)
 
         # Calculate cost based on prompt_tokens, completion_tokens
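
Note on the litellm/utils.py change: completion_cost no longer flattens messages into a prompt string up front; it passes messages straight to token_counter when a message list is given and falls back to the prompt string otherwise. A minimal sketch (not part of the patch) of exercising both paths, assuming litellm is installed; the model name and completion text here are illustrative:

import litellm

# messages path: the prompt side is counted via token_counter(model=..., messages=...)
cost_from_messages = litellm.completion_cost(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    completion="It's sunny and around 20 degrees celsius.",
)

# prompt path: the prompt side is counted via token_counter(model=..., text=...)
cost_from_prompt = litellm.completion_cost(
    model="gpt-3.5-turbo",
    prompt="What's the weather like in San Francisco?",
    completion="It's sunny and around 20 degrees celsius.",
)

print(f"cost from messages: {cost_from_messages}, cost from prompt: {cost_from_prompt}")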