Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 19:24:27 +00:00
fix(utils.py): fix cost calculation to handle tool input

parent 9494c2cd9e
commit 1b35736797
3 changed files with 77 additions and 6 deletions
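The failure this commit fixes: completion_cost() used to flatten the conversation with " ".join([message["content"] for message in messages]), and an assistant turn that carries tool_calls has its content set to None, so the join raised before any cost was computed. A minimal repro sketch of that failure mode (message shape borrowed from the new test file below):

# repro sketch: the old join over message["content"] breaks on tool-call turns
messages = [
    {"role": "user", "content": "What's the weather like in Paris?"},
    {
        "role": "assistant",
        "content": None,  # tool-call turns carry no text content
        "tool_calls": [{"id": "3", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}"}}],
    },
]
try:
    prompt = " ".join([message["content"] for message in messages])  # old code path
except TypeError as err:
    print(err)  # sequence item 1: expected str instance, NoneType found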
@@ -366,6 +366,7 @@ async def track_cost_callback(
     global prisma_client
     try:
         # check if it has collected an entire stream response
+        print(f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}")
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response=kwargs["complete_streaming_response"]
@@ -377,16 +378,16 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
-        # for non streaming responses
-        elif kwargs["stream"] is False: # regular response
+        elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
+            print(f"type of input_text: {type(input_text)}")
             if isinstance(input_text, list):
                 response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
             elif isinstance(input_text, str):
                 response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
             print(f"received completion response: {completion_response}")

-            print("regular response_cost", response_cost)
+            print(f"regular response_cost: {response_cost}")
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
         if user_api_key and prisma_client:
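For reference, the two branches above differ only in which keyword they hand to litellm.completion_cost. A minimal sketch of both calls; the dict is a stand-in carrying only the usage/model fields the hunks below read, which is an assumption about what completion_cost actually needs (real callers pass the object returned by litellm.completion()):

import litellm

# stand-in response; only 'model' and 'usage' are read per the utils.py hunks below
fake_response = {
    "model": "gpt-3.5-turbo",
    "usage": {"prompt_tokens": 21, "completion_tokens": 5},
}

# list input (the usual kwargs["messages"]) -> messages= keyword
cost = litellm.completion_cost(completion_response=fake_response, messages=[{"role": "user", "content": "hi"}])

# plain-string input -> prompt= keyword
cost = litellm.completion_cost(completion_response=fake_response, prompt="hi")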
litellm/tests/test_proxy_server_spend.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+# import openai, json
+
+# client = openai.OpenAI(
+#     api_key="sk-1234",
+#     base_url="http://0.0.0.0:8000"
+# )
+
+# super_fake_messages = [
+#     {
+#         "role": "user",
+#         "content": "What's the weather like in San Francisco, Tokyo, and Paris?"
+#     },
+#     {
+#         "content": None,
+#         "role": "assistant",
+#         "tool_calls": [
+#             {
+#                 "id": "1",
+#                 "function": {
+#                     "arguments": "{\"location\": \"San Francisco\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "2",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Tokyo\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             },
+#             {
+#                 "id": "3",
+#                 "function": {
+#                     "arguments": "{\"location\": \"Paris\", \"unit\": \"celsius\"}",
+#                     "name": "get_current_weather"
+#                 },
+#                 "type": "function"
+#             }
+#         ]
+#     },
+#     {
+#         "tool_call_id": "1",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"San Francisco\", \"temperature\": \"90\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "2",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"
+#     },
+#     {
+#         "tool_call_id": "3",
+#         "role": "tool",
+#         "name": "get_current_weather",
+#         "content": "{\"location\": \"Paris\", \"temperature\": \"50\", \"unit\": \"celsius\"}"
+#     }
+# ]
+
+# super_fake_response = client.chat.completions.create(
+#     model="gpt-3.5-turbo",
+#     messages=super_fake_messages,
+#     seed=1337,
+#     stream=False
+# ) # get a new response from the model where it can see the function response
+
+# print(json.dumps(super_fake_response.model_dump(), indent=4))
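The test is checked in fully commented out; uncommented, it expects a proxy listening on http://0.0.0.0:8000 that accepts the key sk-1234. Its value for this commit is the message shape: super_fake_messages mixes an assistant turn whose content is None with role "tool" turns, exactly the input the old prompt-joining path could not handle.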
utils.py
@@ -1858,8 +1858,6 @@ def completion_cost(
     - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
-        if messages != []:
-            prompt = " ".join([message["content"] for message in messages])
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
@@ -1869,6 +1867,9 @@ def completion_cost(
             completion_tokens = completion_response['usage']['completion_tokens']
             model = model or completion_response['model'] # check if user passed an override for model, if it's none check completion_response['model']
         else:
+            if len(messages) > 0:
+                prompt_tokens = token_counter(model=model, messages=messages)
+            elif len(prompt) > 0:
                 prompt_tokens = token_counter(model=model, text=prompt)
             completion_tokens = token_counter(model=model, text=completion)
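After this change, structured messages are counted directly instead of being flattened to a string first. A sketch of the new branch's behavior, assuming litellm.token_counter accepts the same messages list the callback receives (the hunk shows only the call site, not token_counter itself):

import litellm

tool_messages = [
    {"role": "user", "content": "What's the weather like in Tokyo?"},
    {"role": "tool", "tool_call_id": "2", "name": "get_current_weather", "content": "{\"location\": \"Tokyo\", \"temperature\": \"30\", \"unit\": \"celsius\"}"},
]

# new path: len(messages) > 0 -> count the structured messages
print(litellm.token_counter(model="gpt-3.5-turbo", messages=tool_messages))

# fallback path: len(prompt) > 0 -> count plain text
print(litellm.token_counter(model="gpt-3.5-turbo", text="What's the weather like in Tokyo?"))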