forked from phoenix/litellm-mirror
Merge pull request #3537 from BerriAI/litellm_support_stream_options_param
[Feat] support `stream_options` param for OpenAI
This commit is contained in: commit 0b1885ca99
6 changed files with 101 additions and 7 deletions
@@ -83,6 +83,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -139,6 +140,10 @@ def completion(
 - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
+
+- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`.
+
+    - `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
 - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.

 - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
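Read together, the parameter descriptions above give the end-to-end behavior: pass `stream_options={"include_usage": True}` alongside `stream=True`, and one extra usage-only chunk arrives before `[DONE]`. A minimal usage sketch, closely mirroring the test added later in this diff (the model name and prompt are placeholders, not part of the change):

```python
import litellm

# Ask for a streamed completion plus a trailing usage-only chunk.
response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

chunks = list(response)

# Every chunk except the last reports usage=None; the final chunk carries
# token counts for the whole request and an empty choices list.
assert all(chunk.usage is None for chunk in chunks[:-1])
print(chunks[-1].usage)  # prompt_tokens / completion_tokens / total_tokens
```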
@@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper

@@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper
         except (
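These two hunks (one on the sync path, one on the async path) only thread the caller's `stream_options` from the request payload into the stream wrapper; nothing else about the OpenAI call changes. A hedged sketch of that hand-off pattern, where `StreamWrapper` and `make_streamwrapper` are illustrative stand-ins rather than litellm names:

```python
from typing import Any, Dict, Iterator, Optional


class StreamWrapper:
    """Stand-in for a streaming wrapper that remembers stream_options."""

    def __init__(self, completion_stream: Iterator[Any], stream_options: Optional[dict] = None) -> None:
        self.completion_stream = completion_stream
        self.stream_options = stream_options  # consulted later when emitting chunks

    def __iter__(self) -> Iterator[Any]:
        return iter(self.completion_stream)


def make_streamwrapper(raw_stream: Iterator[Any], data: Dict[str, Any]) -> StreamWrapper:
    # Same pattern as the diff: pull stream_options back out of the request
    # dict so the wrapper knows whether a usage chunk should be surfaced.
    return StreamWrapper(raw_stream, stream_options=data.get("stream_options", None))
```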
@@ -187,6 +187,7 @@ async def acompletion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -206,6 +207,7 @@ async def acompletion(
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
     # Optional liteLLM function params
     **kwargs,
 ):

@@ -223,6 +225,7 @@ async def acompletion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.

@@ -260,6 +263,7 @@ async def acompletion(
         "top_p": top_p,
         "n": n,
         "stream": stream,
+        "stream_options": stream_options,
         "stop": stop,
         "max_tokens": max_tokens,
         "presence_penalty": presence_penalty,

@@ -457,6 +461,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -496,6 +501,7 @@ def completion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.

@@ -573,6 +579,7 @@ def completion(
         "top_p",
         "n",
         "stream",
+        "stream_options",
         "stop",
         "max_tokens",
         "presence_penalty",

@@ -785,6 +792,7 @@ def completion(
         top_p=top_p,
         n=n,
         stream=stream,
+        stream_options=stream_options,
         stop=stop,
         max_tokens=max_tokens,
         presence_penalty=presence_penalty,
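With the signature, docstring, and pass-through changes above, `acompletion` accepts the same option as `completion`. A hedged sketch of async usage under the same assumptions as the earlier example (placeholder model and prompt; chunks are assumed to expose a `usage` attribute, as the test further down asserts):

```python
import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in response:
        if chunk.usage is not None:  # the trailing usage-only chunk
            print(chunk.usage)


asyncio.run(main())
```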
@@ -1,5 +1,6 @@
 import pytest
 from litellm import acompletion
+from litellm import completion


 def test_acompletion_params():

@@ -7,17 +8,29 @@ def test_acompletion_params():
     from litellm.types.completion import CompletionRequest

     acompletion_params_odict = inspect.signature(acompletion).parameters
-    acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
-    completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
-
-    # remove kwargs
-    acompletion_params.pop("kwargs", None)
+    completion_params_dict = inspect.signature(completion).parameters
+    acompletion_params = {
+        name: param.annotation for name, param in acompletion_params_odict.items()
+    }
+    completion_params = {
+        name: param.annotation for name, param in completion_params_dict.items()
+    }

     keys_acompletion = set(acompletion_params.keys())
     keys_completion = set(completion_params.keys())

+    print(keys_acompletion)
+    print("\n\n\n")
+    print(keys_completion)
+
+    print("diff=", keys_completion - keys_acompletion)
+
     # Assert that the parameters are the same
     if keys_acompletion != keys_completion:
-        pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
+        pytest.fail(
+            "The parameters of the litellm.acompletion function and litellm.completion are not the same."
+        )


 # test_acompletion_params()
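To run just this parity check locally, something along these lines should work (assumes pytest is installed; adjust the working directory or pass the test file's path explicitly, since it is not shown in this diff):

```python
import pytest

# Select only the signature-parity test by keyword; pytest collects from the
# current working directory unless a path is supplied.
pytest.main(["-q", "-k", "test_acompletion_params"])
```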
@@ -1501,6 +1501,37 @@ def test_openai_chat_completion_complete_response_call():


 # test_openai_chat_completion_complete_response_call()
+def test_openai_stream_options_call():
+    litellm.set_verbose = False
+    response = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
+        stream=True,
+        stream_options={"include_usage": True},
+        max_tokens=10,
+    )
+    usage = None
+    chunks = []
+    for chunk in response:
+        print("chunk: ", chunk)
+        chunks.append(chunk)
+
+    last_chunk = chunks[-1]
+    print("last chunk: ", last_chunk)
+
+    """
+    Assert that:
+    - Last Chunk includes Usage
+    - All chunks prior to last chunk have usage=None
+    """
+
+    assert last_chunk.usage is not None
+    assert last_chunk.usage.total_tokens > 0
+    assert last_chunk.usage.prompt_tokens > 0
+    assert last_chunk.usage.completion_tokens > 0
+
+    # assert all non last chunks have usage=None
+    assert all(chunk.usage is None for chunk in chunks[:-1])


 def test_openai_text_completion_call():
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
         system_fingerprint=None,
         usage=None,
         stream=None,
+        stream_options=None,
         response_ms=None,
         hidden_params=None,
         **params,

@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
+        elif (
+            stream == True
+            and stream_options is not None
+            and stream_options.get("include_usage") == True
+        ):
+            usage = Usage()
         if hidden_params:
             self._hidden_params = hidden_params
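Restated outside the diff, the new `ModelResponse` branch only pre-creates a usage object for streaming responses when the caller opted in via `include_usage`. A condensed sketch of that decision (`Usage` here is a minimal stand-in, not litellm's class):

```python
from typing import Optional


class Usage:
    """Minimal stand-in for litellm's Usage object (all counts default to 0)."""

    def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = 0) -> None:
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = total_tokens


def initial_usage(usage: Optional[Usage], stream: Optional[bool], stream_options: Optional[dict]) -> Optional[Usage]:
    # Mirrors the diff: keep a caller-supplied usage object, create one for
    # non-streaming responses, and also create one for streaming responses
    # that asked for include_usage; otherwise leave it unset.
    if usage is not None:
        return usage
    if stream is None or stream is False:
        return Usage()
    if stream is True and stream_options is not None and stream_options.get("include_usage") is True:
        return Usage()
    return None
```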
@@ -4839,6 +4846,7 @@ def get_optional_params(
     top_p=None,
     n=None,
     stream=False,
+    stream_options=None,
     stop=None,
     max_tokens=None,
     presence_penalty=None,

@@ -4908,6 +4916,7 @@ def get_optional_params(
         "top_p": None,
         "n": None,
         "stream": None,
+        "stream_options": None,
         "stop": None,
         "max_tokens": None,
         "presence_penalty": None,

@@ -5779,6 +5788,8 @@ def get_optional_params(
             optional_params["n"] = n
         if stream is not None:
             optional_params["stream"] = stream
+        if stream_options is not None:
+            optional_params["stream_options"] = stream_options
         if stop is not None:
             optional_params["stop"] = stop
         if max_tokens is not None:

@@ -6085,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
         "top_p",
         "n",
         "stream",
+        "stream_options",
         "stop",
         "max_tokens",
         "presence_penalty",
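After these hunks, `stream_options` is both accepted by `get_optional_params` and advertised in the OpenAI provider's supported-params list. A quick hedged check (assuming `get_supported_openai_params` is importable from `litellm.utils`, where the hunk above defines it):

```python
from litellm.utils import get_supported_openai_params

supported = get_supported_openai_params(model="gpt-3.5-turbo", custom_llm_provider="openai")
assert "stream_options" in supported
```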
@@ -9502,7 +9514,12 @@ def get_secret(
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
     def __init__(
-        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
+        self,
+        completion_stream,
+        model,
+        custom_llm_provider=None,
+        logging_obj=None,
+        stream_options=None,
     ):
         self.model = model
         self.custom_llm_provider = custom_llm_provider

@@ -9528,6 +9545,7 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
+        self.stream_options = stream_options

     def __iter__(self):
         return self

@@ -9968,6 +9986,7 @@ class CustomStreamWrapper:
             is_finished = False
             finish_reason = None
             logprobs = None
+            usage = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
                 if (

@@ -10002,12 +10021,15 @@ class CustomStreamWrapper:
                 else:
                     logprobs = None

+            usage = getattr(str_line, "usage", None)
+
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
                 "logprobs": logprobs,
                 "original_chunk": str_line,
+                "usage": usage,
             }
         except Exception as e:
             traceback.print_exc()

@@ -10310,7 +10332,9 @@ class CustomStreamWrapper:
             raise e

     def model_response_creator(self):
-        model_response = ModelResponse(stream=True, model=self.model)
+        model_response = ModelResponse(
+            stream=True, model=self.model, stream_options=self.stream_options
+        )
         if self.response_id is not None:
             model_response.id = self.response_id
         else:

@@ -10630,6 +10654,12 @@ class CustomStreamWrapper:
                     if response_obj["logprobs"] is not None:
                         model_response.choices[0].logprobs = response_obj["logprobs"]

+                    if (
+                        self.stream_options is not None
+                        and self.stream_options["include_usage"] == True
+                    ):
+                        model_response.usage = response_obj["usage"]
+
                 model_response.model = self.model
                 print_verbose(
                     f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"

@@ -10717,6 +10747,11 @@ class CustomStreamWrapper:
                 except Exception as e:
                     model_response.choices[0].delta = Delta()
             else:
+                if (
+                    self.stream_options is not None
+                    and self.stream_options["include_usage"] == True
+                ):
+                    return model_response
                 return
             print_verbose(
                 f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
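Summing up the wrapper-side hunks: each parsed OpenAI chunk now carries a `usage` value, and it is only copied onto the outgoing response (or emitted as the trailing usage-only chunk) when `include_usage` was requested. A condensed, hedged sketch of that gate, written independently of litellm's classes (`model_response` stands for any object with a writable `usage` attribute):

```python
from typing import Any, Optional


def maybe_attach_usage(model_response: Any, parsed_chunk: dict, stream_options: Optional[dict]) -> Any:
    # Mirrors the diff: usage is read off the raw provider chunk (via getattr),
    # stored in the parsed dict, and only surfaced when the caller set
    # stream_options={"include_usage": True}.
    if stream_options is not None and stream_options.get("include_usage") is True:
        model_response.usage = parsed_chunk.get("usage")
    return model_response
```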