fix(main.py): support cost calculation for text completion streaming object

Krrish Dholakia 2024-01-08 12:41:08 +05:30
parent 442ebdde7c
commit 6333fbfe56
3 changed files with 78 additions and 2 deletions


@@ -3114,6 +3114,70 @@ def config_completion(**kwargs):
            "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
        )

def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
    id = chunks[0]["id"]
    object = chunks[0]["object"]
    created = chunks[0]["created"]
    model = chunks[0]["model"]
    system_fingerprint = chunks[0].get("system_fingerprint", None)
    finish_reason = chunks[-1]["choices"][0]["finish_reason"]
    logprobs = chunks[-1]["choices"][0]["logprobs"]

    response = {
        "id": id,
        "object": object,
        "created": created,
        "model": model,
        "system_fingerprint": system_fingerprint,
        "choices": [
            {
                "text": None,
                "index": 0,
                "logprobs": logprobs,
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    }

    content_list = []
    for chunk in chunks:
        choices = chunk["choices"]
        for choice in choices:
            if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
                _choice = choice.get("text")
                content_list.append(_choice)

    # Combine the "text" strings from every chunk into a single string
    combined_content = "".join(content_list)

    # Update the "text" field within the response dictionary
    response["choices"][0]["text"] = combined_content

    if len(combined_content) > 0:
        completion_output = combined_content
    else:
        completion_output = ""

    # Update usage information
    try:
        response["usage"]["prompt_tokens"] = token_counter(
            model=model, messages=messages
        )
    except:  # don't let a token_counter failure block returning the completed streaming response
        print_verbose("token_counter failed, assuming prompt tokens is 0")
        response["usage"]["prompt_tokens"] = 0
    response["usage"]["completion_tokens"] = token_counter(
        model=model,
        text=combined_content,
        count_response_tokens=True,  # flag telling the token counter this is a response, so no extra input-message tokens are added
    )
    response["usage"]["total_tokens"] = (
        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
    )
    return response
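
Editor's note: for intuition on the usage accounting above, here is a small sketch of how the prompt, completion, and total token counts relate, using litellm's token_counter the same way the new builder does. The model name, prompt, and completion text are made-up illustrations, not part of this commit.

import litellm

# Illustrative only: prompt tokens from the request messages, completion
# tokens from the combined streamed text, total is the sum.
prompt_tokens = litellm.token_counter(
    model="gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "good morning"}],
)
completion_tokens = litellm.token_counter(
    model="gpt-3.5-turbo-instruct",
    text="Good morning to you too!",
    count_response_tokens=True,  # count the raw text, without chat-message framing tokens
)
total_tokens = prompt_tokens + completion_tokens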

def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
    id = chunks[0]["id"]
@@ -3121,6 +3185,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
    created = chunks[0]["created"]
    model = chunks[0]["model"]
    system_fingerprint = chunks[0].get("system_fingerprint", None)
    if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices):  # route to the text completion logic
        return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
    role = chunks[0]["choices"][0]["delta"]["role"]
    finish_reason = chunks[-1]["choices"][0]["finish_reason"]


@@ -2936,18 +2936,21 @@ async def test_async_text_completion_chat_model_stream():
            stream=True,
            max_tokens=10,
        )
        print(f"response: {response}")
        num_finish_reason = 0
        chunks = []
        async for chunk in response:
            print(chunk)
            chunks.append(chunk)
            if chunk["choices"][0].get("finish_reason") is not None:
                num_finish_reason += 1
                print("finish_reason", chunk["choices"][0].get("finish_reason"))

        assert (
            num_finish_reason == 1
        ), f"expected only one finish reason. Got {num_finish_reason}"

        response_obj = litellm.stream_chunk_builder(chunks=chunks)
        cost = litellm.completion_cost(completion_response=response_obj)
        assert cost > 0
    except Exception as e:
        pytest.fail(f"GOT exception for gpt-3.5 in streaming: {e}")


@@ -558,6 +558,13 @@ class TextChoices(OpenAIObject):
    def __setitem__(self, key, value):
        # Allow dictionary-style assignment of attributes
        setattr(self, key, value)

    def json(self, **kwargs):
        try:
            return self.model_dump()  # noqa
        except:
            # if using pydantic v1
            return self.dict()


class TextCompletionResponse(OpenAIObject):
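
Editor's note: the json() override added here lets TextChoices be dumped to a plain dict under both pydantic v2 (model_dump) and v1 (dict). A generic sketch of the same compatibility pattern, using a standalone example model rather than litellm's own classes:

from pydantic import BaseModel


class ExampleChoice(BaseModel):
    text: str = ""
    index: int = 0

    def json(self, **kwargs):
        try:
            # pydantic v2
            return self.model_dump()
        except AttributeError:
            # pydantic v1 has no model_dump
            return self.dict()


print(ExampleChoice(text="hello").json())  # {'text': 'hello', 'index': 0}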