diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index d542cbe07..674cc86a2 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1205,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
             model=model,
             custom_llm_provider="text-completion-openai",
             logging_obj=logging_obj,
+            stream_options=data.get("stream_options", None),
         )
 
         for chunk in streamwrapper:
@@ -1243,6 +1244,7 @@ class OpenAITextCompletion(BaseLLM):
             model=model,
             custom_llm_provider="text-completion-openai",
             logging_obj=logging_obj,
+            stream_options=data.get("stream_options", None),
         )
 
         async for transformed_chunk in streamwrapper:
diff --git a/litellm/main.py b/litellm/main.py
index 5ab3fd7c4..3816f4f4f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3200,6 +3200,7 @@ def text_completion(
         Union[str, List[str]]
     ] = None,  # Optional: Sequences where the API will stop generating further tokens.
     stream: Optional[bool] = None,  # Optional: Whether to stream back partial progress.
+    stream_options: Optional[dict] = None,
     suffix: Optional[
         str
     ] = None,  # Optional: The suffix that comes after a completion of inserted text.
@@ -3277,6 +3278,8 @@ def text_completion(
         optional_params["stop"] = stop
     if stream is not None:
         optional_params["stream"] = stream
+    if stream_options is not None:
+        optional_params["stream_options"] = stream_options
     if suffix is not None:
         optional_params["suffix"] = suffix
     if temperature is not None:
@@ -3387,7 +3390,9 @@ def text_completion(
     if kwargs.get("acompletion", False) == True:
         return response
     if stream == True or kwargs.get("stream", False) == True:
-        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
+        response = TextCompletionStreamWrapper(
+            completion_stream=response, model=model, stream_options=stream_options
+        )
         return response
     transformed_logprobs = None
     # only supported for TGI models
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 7d639d7a3..93d7567eb 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1534,6 +1534,39 @@ def test_openai_stream_options_call():
     assert all(chunk.usage is None for chunk in chunks[:-1])
 
 
+def test_openai_stream_options_call_text_completion():
+    litellm.set_verbose = False
+    response = litellm.text_completion(
+        model="gpt-3.5-turbo-instruct",
+        prompt="say GM - we're going to make it ",
+        stream=True,
+        stream_options={"include_usage": True},
+        max_tokens=10,
+    )
+    usage = None
+    chunks = []
+    for chunk in response:
+        print("chunk: ", chunk)
+        chunks.append(chunk)
+
+    last_chunk = chunks[-1]
+    print("last chunk: ", last_chunk)
+
+    """
+    Assert that:
+    - Last Chunk includes Usage
+    - All chunks prior to last chunk have usage=None
+    """
+
+    assert last_chunk.usage is not None
+    assert last_chunk.usage.total_tokens > 0
+    assert last_chunk.usage.prompt_tokens > 0
+    assert last_chunk.usage.completion_tokens > 0
+
+    # assert all non last chunks have usage=None
+    assert all(chunk.usage is None for chunk in chunks[:-1])
+
+
 def test_openai_text_completion_call():
     try:
         litellm.set_verbose = True
diff --git a/litellm/utils.py b/litellm/utils.py
index 6da296038..23f8ca712 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10062,16 +10062,19 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
+            usage = None
             choices = getattr(chunk, "choices", [])
             if len(choices) > 0:
                 text = choices[0].text
                 if choices[0].finish_reason is not None:
                     is_finished = True
                     finish_reason = choices[0].finish_reason
+            usage = getattr(chunk, "usage", None)
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
+                "usage": usage,
             }
 
         except Exception as e:
@@ -10601,6 +10604,11 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    model_response.usage = response_obj["usage"]
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11130,9 +11138,10 @@ class CustomStreamWrapper:
 
 
 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model):
+    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
         self.completion_stream = completion_stream
         self.model = model
+        self.stream_options = stream_options
 
     def __iter__(self):
         return self
@@ -11156,6 +11165,14 @@ class CustomStreamWrapper:
             text_choices["index"] = chunk["choices"][0]["index"]
             text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
             response["choices"] = [text_choices]
+
+            # only pass usage when stream_options["include_usage"] is True
+            if (
+                self.stream_options
+                and self.stream_options.get("include_usage", False) == True
+            ):
+                response["usage"] = chunk.get("usage", None)
+
             return response
         except Exception as e:
             raise Exception(