mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
fix(main.py): support cost calculation for text completion streaming object
This commit is contained in:
parent
442ebdde7c
commit
6333fbfe56
3 changed files with 78 additions and 2 deletions
|
@ -3114,6 +3114,70 @@ def config_completion(**kwargs):
|
|||
"No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
|
||||
)
|
||||
|
||||
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
|
||||
id = chunks[0]["id"]
|
||||
object = chunks[0]["object"]
|
||||
created = chunks[0]["created"]
|
||||
model = chunks[0]["model"]
|
||||
system_fingerprint = chunks[0].get("system_fingerprint", None)
|
||||
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
|
||||
logprobs = chunks[-1]["choices"][0]["logprobs"]
|
||||
|
||||
response = {
|
||||
"id": id,
|
||||
"object": object,
|
||||
"created": created,
|
||||
"model": model,
|
||||
"system_fingerprint": system_fingerprint,
|
||||
"choices": [
|
||||
{
|
||||
"text": None,
|
||||
"index": 0,
|
||||
"logprobs": logprobs,
|
||||
"finish_reason": finish_reason
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None
|
||||
}
|
||||
}
|
||||
content_list = []
|
||||
for chunk in chunks:
|
||||
choices = chunk["choices"]
|
||||
for choice in choices:
|
||||
if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
|
||||
_choice = choice.get("text")
|
||||
content_list.append(_choice)
|
||||
|
||||
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
|
||||
combined_content = "".join(content_list)
|
||||
|
||||
# Update the "content" field within the response dictionary
|
||||
response["choices"][0]["text"] = combined_content
|
||||
|
||||
if len(combined_content) > 0:
|
||||
completion_output = combined_content
|
||||
else:
|
||||
completion_output = ""
|
||||
# # Update usage information if needed
|
||||
try:
|
||||
response["usage"]["prompt_tokens"] = token_counter(
|
||||
model=model, messages=messages
|
||||
)
|
||||
except: # don't allow this failing to block a complete streaming response from being returned
|
||||
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
|
||||
response["usage"]["prompt_tokens"] = 0
|
||||
response["usage"]["completion_tokens"] = token_counter(
|
||||
model=model,
|
||||
text=combined_content,
|
||||
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
|
||||
)
|
||||
response["usage"]["total_tokens"] = (
|
||||
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
|
||||
)
|
||||
return response
|
||||
|
||||
def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
|
||||
id = chunks[0]["id"]
|
||||
|
@ -3121,6 +3185,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
|
|||
created = chunks[0]["created"]
|
||||
model = chunks[0]["model"]
|
||||
system_fingerprint = chunks[0].get("system_fingerprint", None)
|
||||
if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices): # route to the text completion logic
|
||||
return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
|
||||
role = chunks[0]["choices"][0]["delta"]["role"]
|
||||
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
|
||||
|
||||
|
|
|
@ -2936,18 +2936,21 @@ async def test_async_text_completion_chat_model_stream():
|
|||
stream=True,
|
||||
max_tokens=10,
|
||||
)
|
||||
print(f"response: {response}")
|
||||
|
||||
num_finish_reason = 0
|
||||
chunks = []
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
chunks.append(chunk)
|
||||
if chunk["choices"][0].get("finish_reason") is not None:
|
||||
num_finish_reason += 1
|
||||
print("finish_reason", chunk["choices"][0].get("finish_reason"))
|
||||
|
||||
assert (
|
||||
num_finish_reason == 1
|
||||
), f"expected only one finish reason. Got {num_finish_reason}"
|
||||
response_obj = litellm.stream_chunk_builder(chunks=chunks)
|
||||
cost = litellm.completion_cost(completion_response=response_obj)
|
||||
assert cost > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"GOT exception for gpt-3.5 In streaming{e}")
|
||||
|
||||
|
|
|
@ -558,6 +558,13 @@ class TextChoices(OpenAIObject):
|
|||
def __setitem__(self, key, value):
|
||||
# Allow dictionary-style assignment of attributes
|
||||
setattr(self, key, value)
|
||||
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
class TextCompletionResponse(OpenAIObject):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue