From 6333fbfe5612355460254c21dd37be407d25faeb Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 8 Jan 2024 12:41:08 +0530
Subject: [PATCH] fix(main.py): support cost calculation for text completion
 streaming object

---
 litellm/main.py                       | 66 +++++++++++++++++++++++++++
 litellm/tests/test_text_completion.py |  7 ++-
 litellm/utils.py                      |  7 +++
 3 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 3f1902ec97..59cbbab3c2 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3114,6 +3114,70 @@ def config_completion(**kwargs):
         "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
     )
 
+def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
+    id = chunks[0]["id"]
+    object = chunks[0]["object"]
+    created = chunks[0]["created"]
+    model = chunks[0]["model"]
+    system_fingerprint = chunks[0].get("system_fingerprint", None)
+    finish_reason = chunks[-1]["choices"][0]["finish_reason"]
+    logprobs = chunks[-1]["choices"][0]["logprobs"]
+
+    response = {
+        "id": id,
+        "object": object,
+        "created": created,
+        "model": model,
+        "system_fingerprint": system_fingerprint,
+        "choices": [
+            {
+                "text": None,
+                "index": 0,
+                "logprobs": logprobs,
+                "finish_reason": finish_reason
+            }
+        ],
+        "usage": {
+            "prompt_tokens": None,
+            "completion_tokens": None,
+            "total_tokens": None
+        }
+    }
+    content_list = []
+    for chunk in chunks:
+        choices = chunk["choices"]
+        for choice in choices:
+            if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
+                _choice = choice.get("text")
+                content_list.append(_choice)
+
+    # Combine the "content" strings into a single string || combine the 'function' strings into a single string
+    combined_content = "".join(content_list)
+
+    # Update the "content" field within the response dictionary
+    response["choices"][0]["text"] = combined_content
+
+    if len(combined_content) > 0:
+        completion_output = combined_content
+    else:
+        completion_output = ""
+    # # Update usage information if needed
+    try:
+        response["usage"]["prompt_tokens"] = token_counter(
+            model=model, messages=messages
+        )
+    except:  # don't allow this failing to block a complete streaming response from being returned
+        print_verbose(f"token_counter failed, assuming prompt tokens is 0")
+        response["usage"]["prompt_tokens"] = 0
+    response["usage"]["completion_tokens"] = token_counter(
+        model=model,
+        text=combined_content,
+        count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
+    )
+    response["usage"]["total_tokens"] = (
+        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
+    )
+    return response
 
 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
@@ -3121,6 +3185,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
+    if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices):  # route to the text completion logic
+        return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
 
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 6e9df66087..2aef20322f 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -2936,18 +2936,21 @@ async def test_async_text_completion_chat_model_stream():
             stream=True,
             max_tokens=10,
         )
 
-        print(f"response: {response}")
         num_finish_reason = 0
+        chunks = []
         async for chunk in response:
             print(chunk)
+            chunks.append(chunk)
             if chunk["choices"][0].get("finish_reason") is not None:
                 num_finish_reason += 1
-                print("finish_reason", chunk["choices"][0].get("finish_reason"))
 
         assert (
            num_finish_reason == 1
         ), f"expected only one finish reason. Got {num_finish_reason}"
+        response_obj = litellm.stream_chunk_builder(chunks=chunks)
+        cost = litellm.completion_cost(completion_response=response_obj)
+        assert cost > 0
 
     except Exception as e:
         pytest.fail(f"GOT exception for gpt-3.5 In streaming{e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index da1dba897c..31bafd61fe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -558,6 +558,13 @@ class TextChoices(OpenAIObject):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
 
 
 class TextCompletionResponse(OpenAIObject):
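
Usage sketch (not part of the patch above): the test in this patch exercises the new code path
through the public helpers it already calls, litellm.stream_chunk_builder() and
litellm.completion_cost(). A minimal synchronous version, assuming an OpenAI API key is
configured and using a placeholder model name, prompt, and helper function name, could look
roughly like this:

    import litellm

    def streamed_text_completion_cost(
        model: str = "gpt-3.5-turbo-instruct",  # placeholder; any streaming-capable text completion model
        prompt: str = "Good morning!",          # placeholder prompt
    ) -> float:
        # Stream a text completion; the streamed chunks carry text-completion choices
        # (TextChoices), which is what triggers the new routing in stream_chunk_builder().
        response = litellm.text_completion(
            model=model,
            prompt=prompt,
            stream=True,
            max_tokens=10,
        )
        chunks = [chunk for chunk in response]

        # stream_chunk_builder() detects TextChoices in the first chunk and routes to
        # stream_chunk_builder_text_completion(), so the rebuilt response carries
        # prompt/completion/total token usage.
        rebuilt = litellm.stream_chunk_builder(chunks=chunks)

        # With usage populated, the streamed response can be priced.
        return litellm.completion_cost(completion_response=rebuilt)

    if __name__ == "__main__":
        print(f"streamed text completion cost: ${streamed_text_completion_cost():.8f}")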