support stream_options

Ishaan Jaff 2024-05-08 21:53:33 -07:00
parent f2965660dd
commit 80ca011a64


@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
         system_fingerprint=None,
         usage=None,
         stream=None,
+        stream_options=None,
         response_ms=None,
         hidden_params=None,
         **params,
@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
usage = usage usage = usage
elif stream is None or stream == False: elif stream is None or stream == False:
usage = Usage() usage = Usage()
elif (
stream == True
and stream_options is not None
and stream_options.get("include_usage") == True
):
usage = Usage()
if hidden_params: if hidden_params:
self._hidden_params = hidden_params self._hidden_params = hidden_params
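Taken together, the two hunks above let a streaming ModelResponse carry a Usage object when the caller opts in via OpenAI's stream_options flag. A minimal sketch of the call pattern this enables (model name and prompt are illustrative):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in response:
    # With include_usage set, the final chunk reports token counts in chunk.usage.
    print(chunk)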
@@ -4839,6 +4846,7 @@ def get_optional_params(
    top_p=None,
    n=None,
    stream=False,
+   stream_options=None,
    stop=None,
    max_tokens=None,
    presence_penalty=None,
@ -4908,6 +4916,7 @@ def get_optional_params(
"top_p": None, "top_p": None,
"n": None, "n": None,
"stream": None, "stream": None,
"stream_options": None,
"stop": None, "stop": None,
"max_tokens": None, "max_tokens": None,
"presence_penalty": None, "presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
            optional_params["n"] = n
        if stream is not None:
            optional_params["stream"] = stream
+       if stream_options is not None:
+           optional_params["stream_options"] = stream_options
        if stop is not None:
            optional_params["stop"] = stop
        if max_tokens is not None:
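The three get_optional_params hunks follow the function's existing pass-through pattern: declare the parameter, register it in the known-params map, and forward it only when the caller actually set it, so providers that do not understand stream_options never receive the key. A standalone sketch of that pattern (build_optional_params is a hypothetical stand-in for the real function):

def build_optional_params(stream=None, stream_options=None, stop=None):
    optional_params = {}
    if stream is not None:
        optional_params["stream"] = stream
    # Forward stream_options only when the caller explicitly provided it.
    if stream_options is not None:
        optional_params["stream_options"] = stream_options
    if stop is not None:
        optional_params["stop"] = stop
    return optional_params

assert build_optional_params(stream=True) == {"stream": True}
assert "stream_options" in build_optional_params(
    stream=True, stream_options={"include_usage": True}
)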
@@ -6049,6 +6060,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
        "top_p",
        "n",
        "stream",
+       "stream_options",
        "stop",
        "max_tokens",
        "presence_penalty",
@ -9466,7 +9478,12 @@ def get_secret(
# replicate/anthropic/cohere # replicate/anthropic/cohere
class CustomStreamWrapper: class CustomStreamWrapper:
def __init__( def __init__(
self, completion_stream, model, custom_llm_provider=None, logging_obj=None self,
completion_stream,
model,
custom_llm_provider=None,
logging_obj=None,
stream_options=None,
): ):
self.model = model self.model = model
self.custom_llm_provider = custom_llm_provider self.custom_llm_provider = custom_llm_provider
@ -9492,6 +9509,7 @@ class CustomStreamWrapper:
self.response_id = None self.response_id = None
self.logging_loop = None self.logging_loop = None
self.rules = Rules() self.rules = Rules()
self.stream_options = stream_options
def __iter__(self): def __iter__(self):
return self return self
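With the widened signature, whichever provider integration constructs the wrapper can thread the flag through. A hypothetical call site (the chunk iterator and logging object are assumed to come from the surrounding completion code):

stream_wrapper = CustomStreamWrapper(
    completion_stream=raw_sse_iterator,  # assumed upstream chunk iterator
    model="gpt-3.5-turbo",
    custom_llm_provider="openai",
    logging_obj=logging_obj,             # assumed existing logging object
    stream_options={"include_usage": True},
)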
@ -9932,6 +9950,7 @@ class CustomStreamWrapper:
is_finished = False is_finished = False
finish_reason = None finish_reason = None
logprobs = None logprobs = None
usage = None
original_chunk = None # this is used for function/tool calling original_chunk = None # this is used for function/tool calling
if len(str_line.choices) > 0: if len(str_line.choices) > 0:
if ( if (
@ -9966,12 +9985,15 @@ class CustomStreamWrapper:
else: else:
logprobs = None logprobs = None
usage = getattr(str_line, "usage", None)
return { return {
"text": text, "text": text,
"is_finished": is_finished, "is_finished": is_finished,
"finish_reason": finish_reason, "finish_reason": finish_reason,
"logprobs": logprobs, "logprobs": logprobs,
"original_chunk": str_line, "original_chunk": str_line,
"usage": usage,
} }
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
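getattr with a None default matters here: when include_usage is set, OpenAI sends token counts only on the final chunk, and earlier chunks either omit the attribute or carry None. A small stand-in demonstrating the same extraction (FakeChunk is illustrative):

class FakeChunk:
    def __init__(self, choices, usage=None):
        self.choices = choices
        self.usage = usage

mid_chunk = FakeChunk(choices=[{"delta": {"content": "hi"}}])
final_chunk = FakeChunk(choices=[], usage={"prompt_tokens": 5, "completion_tokens": 1})

assert getattr(mid_chunk, "usage", None) is None
assert getattr(final_chunk, "usage", None)["completion_tokens"] == 1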
@ -10274,7 +10296,9 @@ class CustomStreamWrapper:
raise e raise e
def model_response_creator(self): def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model) model_response = ModelResponse(
stream=True, model=self.model, stream_options=self.stream_options
)
if self.response_id is not None: if self.response_id is not None:
model_response.id = self.response_id model_response.id = self.response_id
else: else:
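Forwarding self.stream_options into every per-chunk ModelResponse is what arms the constructor branch added at the top of this commit, so each streamed chunk gets a Usage object initialized. Roughly equivalent to (model name illustrative):

model_response = ModelResponse(
    stream=True,
    model="gpt-3.5-turbo",
    stream_options={"include_usage": True},  # triggers usage = Usage() in __init__
)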
@ -10594,6 +10618,12 @@ class CustomStreamWrapper:
if response_obj["logprobs"] is not None: if response_obj["logprobs"] is not None:
model_response.choices[0].logprobs = response_obj["logprobs"] model_response.choices[0].logprobs = response_obj["logprobs"]
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
model_response.usage = response_obj["usage"]
model_response.model = self.model model_response.model = self.model
print_verbose( print_verbose(
f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}" f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@ -10681,6 +10711,11 @@ class CustomStreamWrapper:
except Exception as e: except Exception as e:
model_response.choices[0].delta = Delta() model_response.choices[0].delta = Delta()
else: else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
return model_response
return return
print_verbose( print_verbose(
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}" f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"