diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md
index 11ca13121..451deaac4 100644
--- a/docs/my-website/docs/completion/input.md
+++ b/docs/my-website/docs/completion/input.md
@@ -83,6 +83,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
@@ -139,6 +140,10 @@ def completion(
 
 - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
 
+- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
+
+    - `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
+
 - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
 
 - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index d516334ac..d542cbe07 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper
 
@@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper
         except (
diff --git a/litellm/main.py b/litellm/main.py
index 99e5ec224..99e556bfa 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -187,6 +187,7 @@ async def acompletion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
@@ -206,6 +207,7 @@ async def acompletion(
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
     # Optional liteLLM function params
     **kwargs,
 ):
@@ -223,6 +225,7 @@ async def acompletion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -260,6 +263,7 @@ async def acompletion(
         "top_p": top_p,
         "n": n,
         "stream": stream,
+        "stream_options": stream_options,
         "stop": stop,
         "max_tokens": max_tokens,
         "presence_penalty": presence_penalty,
@@ -457,6 +461,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
@@ -496,6 +501,7 @@ def completion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -573,6 +579,7 @@ def completion(
         "top_p",
         "n",
         "stream",
+        "stream_options",
         "stop",
         "max_tokens",
         "presence_penalty",
@@ -785,6 +792,7 @@ def completion(
             top_p=top_p,
             n=n,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
diff --git a/litellm/tests/test_acompletion.py b/litellm/tests/test_acompletion.py
index e5c09b9b7..b83e34653 100644
--- a/litellm/tests/test_acompletion.py
+++ b/litellm/tests/test_acompletion.py
@@ -1,5 +1,6 @@
 import pytest
 from litellm import acompletion
+from litellm import completion
 
 
 def test_acompletion_params():
@@ -7,17 +8,29 @@ def test_acompletion_params():
     from litellm.types.completion import CompletionRequest
 
     acompletion_params_odict = inspect.signature(acompletion).parameters
-    acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
-    completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
+    completion_params_dict = inspect.signature(completion).parameters
 
-    # remove kwargs
-    acompletion_params.pop("kwargs", None)
+    acompletion_params = {
+        name: param.annotation for name, param in acompletion_params_odict.items()
+    }
+    completion_params = {
+        name: param.annotation for name, param in completion_params_dict.items()
+    }
 
     keys_acompletion = set(acompletion_params.keys())
     keys_completion = set(completion_params.keys())
 
+    print(keys_acompletion)
+    print("\n\n\n")
+    print(keys_completion)
+
+    print("diff=", keys_completion - keys_acompletion)
+
     # Assert that the parameters are the same
     if keys_acompletion != keys_completion:
-        pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
+        pytest.fail(
+            "The parameters of the litellm.acompletion function and litellm.completion are not the same."
+        )
+
 
 # test_acompletion_params()
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 271a53dd4..7d639d7a3 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1501,6 +1501,37 @@ def test_openai_chat_completion_complete_response_call():
 
 
 # test_openai_chat_completion_complete_response_call()
+def test_openai_stream_options_call():
+    litellm.set_verbose = False
+    response = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
+        stream=True,
+        stream_options={"include_usage": True},
+        max_tokens=10,
+    )
+    usage = None
+    chunks = []
+    for chunk in response:
+        print("chunk: ", chunk)
+        chunks.append(chunk)
+
+    last_chunk = chunks[-1]
+    print("last chunk: ", last_chunk)
+
+    """
+    Assert that:
+    - Last Chunk includes Usage
+    - All chunks prior to last chunk have usage=None
+    """
+
+    assert last_chunk.usage is not None
+    assert last_chunk.usage.total_tokens > 0
+    assert last_chunk.usage.prompt_tokens > 0
+    assert last_chunk.usage.completion_tokens > 0
+
+    # assert all non last chunks have usage=None
+    assert all(chunk.usage is None for chunk in chunks[:-1])
 
 
 def test_openai_text_completion_call():
diff --git a/litellm/utils.py b/litellm/utils.py
index 5725e4992..20716f43d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
         system_fingerprint=None,
         usage=None,
         stream=None,
+        stream_options=None,
         response_ms=None,
         hidden_params=None,
         **params,
@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
+        elif (
+            stream == True
+            and stream_options is not None
+            and stream_options.get("include_usage") == True
+        ):
+            usage = Usage()
         if hidden_params:
             self._hidden_params = hidden_params
 
@@ -4839,6 +4846,7 @@ def get_optional_params(
     top_p=None,
     n=None,
     stream=False,
+    stream_options=None,
     stop=None,
     max_tokens=None,
     presence_penalty=None,
@@ -4908,6 +4916,7 @@ def get_optional_params(
         "top_p": None,
         "n": None,
         "stream": None,
+        "stream_options": None,
         "stop": None,
         "max_tokens": None,
         "presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
             optional_params["n"] = n
         if stream is not None:
             optional_params["stream"] = stream
+        if stream_options is not None:
+            optional_params["stream_options"] = stream_options
         if stop is not None:
             optional_params["stop"] = stop
         if max_tokens is not None:
@@ -6085,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "top_p",
             "n",
             "stream",
+            "stream_options",
             "stop",
             "max_tokens",
             "presence_penalty",
@@ -9502,7 +9514,12 @@ def get_secret(
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
     def __init__(
-        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
+        self,
+        completion_stream,
+        model,
+        custom_llm_provider=None,
+        logging_obj=None,
+        stream_options=None,
     ):
         self.model = model
         self.custom_llm_provider = custom_llm_provider
@@ -9528,6 +9545,7 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
+        self.stream_options = stream_options
 
     def __iter__(self):
         return self
@@ -9968,6 +9986,7 @@ class CustomStreamWrapper:
             is_finished = False
             finish_reason = None
             logprobs = None
+            usage = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
                 if (
@@ -10002,12 +10021,15 @@ class CustomStreamWrapper:
                 else:
                     logprobs = None
 
+            usage = getattr(str_line, "usage", None)
+
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
                 "logprobs": logprobs,
                 "original_chunk": str_line,
+                "usage": usage,
             }
         except Exception as e:
             traceback.print_exc()
@@ -10310,7 +10332,9 @@ class CustomStreamWrapper:
             raise e
 
     def model_response_creator(self):
-        model_response = ModelResponse(stream=True, model=self.model)
+        model_response = ModelResponse(
+            stream=True, model=self.model, stream_options=self.stream_options
+        )
         if self.response_id is not None:
             model_response.id = self.response_id
         else:
@@ -10630,6 +10654,12 @@ class CustomStreamWrapper:
             if response_obj["logprobs"] is not None:
                 model_response.choices[0].logprobs = response_obj["logprobs"]
 
+            if (
+                self.stream_options is not None
+                and self.stream_options["include_usage"] == True
+            ):
+                model_response.usage = response_obj["usage"]
+
             model_response.model = self.model
             print_verbose(
                 f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@@ -10717,6 +10747,11 @@ class CustomStreamWrapper:
                     except Exception as e:
                         model_response.choices[0].delta = Delta()
                 else:
+                    if (
+                        self.stream_options is not None
+                        and self.stream_options["include_usage"] == True
+                    ):
+                        return model_response
                     return
             print_verbose(
                 f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
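
A minimal usage sketch of what `stream_options={"include_usage": True}` looks like from the caller's side, mirroring the new `test_openai_stream_options_call` test above; the model name, prompt, and the presence of an `OPENAI_API_KEY` in the environment are assumptions for illustration, not part of the patch.

```python
# Sketch only: assumes OPENAI_API_KEY is set; model and prompt are placeholders.
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},  # only meaningful when stream=True
    max_tokens=10,
)

chunks = []
for chunk in response:
    # Intermediate chunks carry the streamed deltas; their usage field is None.
    chunks.append(chunk)

# Per this patch (and the new test), the final chunk carries the aggregated token counts.
final = chunks[-1]
if final.usage is not None:
    print("prompt tokens:", final.usage.prompt_tokens)
    print("completion tokens:", final.usage.completion_tokens)
    print("total tokens:", final.usage.total_tokens)
```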
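
A second sketch for the async path, since the patch threads `stream_options` through `litellm.acompletion` as well; the same placeholder model, prompt, and API-key assumption apply.

```python
# Sketch only: placeholder model/prompt, OPENAI_API_KEY assumed in the environment.
import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )

    last_chunk = None
    async for chunk in response:
        last_chunk = chunk  # usage stays None until the final chunk

    if last_chunk is not None and last_chunk.usage is not None:
        print("total tokens:", last_chunk.usage.total_tokens)


asyncio.run(main())
```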