Merge pull request #3537 from BerriAI/litellm_support_stream_options_param

[Feat] support `stream_options` param for OpenAI
Ishaan Jaff 2024-05-09 08:34:08 -07:00 committed by GitHub
commit 0b1885ca99
6 changed files with 101 additions and 7 deletions
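For context, a minimal usage sketch of the new parameter, assuming `litellm` is installed and `OPENAI_API_KEY` is set in the environment (the model name and prompt are placeholders):

```python
import litellm

# The new `stream_options` parameter is forwarded to OpenAI.
# It only takes effect together with `stream=True`.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in response:
    print(chunk)
```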

View file

@ -83,6 +83,7 @@ def completion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@ -139,6 +140,10 @@ def completion(
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
- `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
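To make the `include_usage` behavior documented above concrete, here is a sketch of how a caller might read the per-chunk `usage` field; it mirrors what the new streaming test later in this PR asserts (model, prompt, and `max_tokens` are illustrative):

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=10,
)

for chunk in response:
    if chunk.usage is not None:
        # The extra final chunk: usage covers the entire request.
        print("total tokens:", chunk.usage.total_tokens)
    else:
        # Regular content chunks carry usage=None; deltas live in choices.
        print(chunk.choices[0].delta.content or "", end="")
```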

View file

@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
return streamwrapper
@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
return streamwrapper
except (

View file

@ -187,6 +187,7 @@ async def acompletion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@ -206,6 +207,7 @@ async def acompletion(
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
extra_headers: Optional[dict] = None,
# Optional liteLLM function params
**kwargs,
):
@ -223,6 +225,7 @@ async def acompletion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@ -260,6 +263,7 @@ async def acompletion(
"top_p": top_p,
"n": n,
"stream": stream,
"stream_options": stream_options,
"stop": stop,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
@ -457,6 +461,7 @@ def completion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@ -496,6 +501,7 @@ def completion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@ -573,6 +579,7 @@ def completion(
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
@ -785,6 +792,7 @@ def completion(
top_p=top_p,
n=n,
stream=stream,
stream_options=stream_options,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
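Since `acompletion` gains the same `stream_options` parameter (per the docstring hunk above), a minimal async sketch, assuming the async streaming path mirrors the synchronous behavior:

```python
import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in response:
        if chunk.usage is not None:
            # Usage statistics arrive on the final chunk.
            print("usage:", chunk.usage)

asyncio.run(main())
```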

View file

@ -1,5 +1,6 @@
import pytest
from litellm import acompletion
from litellm import completion
def test_acompletion_params():
@ -7,17 +8,29 @@ def test_acompletion_params():
from litellm.types.completion import CompletionRequest
acompletion_params_odict = inspect.signature(acompletion).parameters
acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
completion_params_dict = inspect.signature(completion).parameters
# remove kwargs
acompletion_params.pop("kwargs", None)
acompletion_params = {
name: param.annotation for name, param in acompletion_params_odict.items()
}
completion_params = {
name: param.annotation for name, param in completion_params_dict.items()
}
keys_acompletion = set(acompletion_params.keys())
keys_completion = set(completion_params.keys())
print(keys_acompletion)
print("\n\n\n")
print(keys_completion)
print("diff=", keys_completion - keys_acompletion)
# Assert that the parameters are the same
if keys_acompletion != keys_completion:
pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
pytest.fail(
"The parameters of the litellm.acompletion function and litellm.completion are not the same."
)
# test_acompletion_params()

View file

@ -1501,6 +1501,37 @@ def test_openai_chat_completion_complete_response_call():
# test_openai_chat_completion_complete_response_call()
def test_openai_stream_options_call():
litellm.set_verbose = False
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": "say GM - we're going to make it "}],
stream=True,
stream_options={"include_usage": True},
max_tokens=10,
)
usage = None
chunks = []
for chunk in response:
print("chunk: ", chunk)
chunks.append(chunk)
last_chunk = chunks[-1]
print("last chunk: ", last_chunk)
"""
Assert that:
- Last Chunk includes Usage
- All chunks prior to last chunk have usage=None
"""
assert last_chunk.usage is not None
assert last_chunk.usage.total_tokens > 0
assert last_chunk.usage.prompt_tokens > 0
assert last_chunk.usage.completion_tokens > 0
# assert all non last chunks have usage=None
assert all(chunk.usage is None for chunk in chunks[:-1])
def test_openai_text_completion_call():

View file

@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
system_fingerprint=None,
usage=None,
stream=None,
stream_options=None,
response_ms=None,
hidden_params=None,
**params,
@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
usage = usage
elif stream is None or stream == False:
usage = Usage()
elif (
stream == True
and stream_options is not None
and stream_options.get("include_usage") == True
):
usage = Usage()
if hidden_params:
self._hidden_params = hidden_params
@ -4839,6 +4846,7 @@ def get_optional_params(
top_p=None,
n=None,
stream=False,
stream_options=None,
stop=None,
max_tokens=None,
presence_penalty=None,
@ -4908,6 +4916,7 @@ def get_optional_params(
"top_p": None,
"n": None,
"stream": None,
"stream_options": None,
"stop": None,
"max_tokens": None,
"presence_penalty": None,
@ -5779,6 +5788,8 @@ def get_optional_params(
optional_params["n"] = n
if stream is not None:
optional_params["stream"] = stream
if stream_options is not None:
optional_params["stream_options"] = stream_options
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
@ -6085,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
@ -9502,7 +9514,12 @@ def get_secret(
# replicate/anthropic/cohere
class CustomStreamWrapper:
def __init__(
self, completion_stream, model, custom_llm_provider=None, logging_obj=None
self,
completion_stream,
model,
custom_llm_provider=None,
logging_obj=None,
stream_options=None,
):
self.model = model
self.custom_llm_provider = custom_llm_provider
@ -9528,6 +9545,7 @@ class CustomStreamWrapper:
self.response_id = None
self.logging_loop = None
self.rules = Rules()
self.stream_options = stream_options
def __iter__(self):
return self
@ -9968,6 +9986,7 @@ class CustomStreamWrapper:
is_finished = False
finish_reason = None
logprobs = None
usage = None
original_chunk = None # this is used for function/tool calling
if len(str_line.choices) > 0:
if (
@ -10002,12 +10021,15 @@ class CustomStreamWrapper:
else:
logprobs = None
usage = getattr(str_line, "usage", None)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"logprobs": logprobs,
"original_chunk": str_line,
"usage": usage,
}
except Exception as e:
traceback.print_exc()
@ -10310,7 +10332,9 @@ class CustomStreamWrapper:
raise e
def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model)
model_response = ModelResponse(
stream=True, model=self.model, stream_options=self.stream_options
)
if self.response_id is not None:
model_response.id = self.response_id
else:
@ -10630,6 +10654,12 @@ class CustomStreamWrapper:
if response_obj["logprobs"] is not None:
model_response.choices[0].logprobs = response_obj["logprobs"]
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
model_response.usage = response_obj["usage"]
model_response.model = self.model
print_verbose(
f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@ -10717,6 +10747,11 @@ class CustomStreamWrapper:
except Exception as e:
model_response.choices[0].delta = Delta()
else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
return model_response
return
print_verbose(
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"