feat(utils.py): enable returning complete response when stream=true

Krrish Dholakia 2023-11-09 09:17:43 -08:00
parent 3dae4e9cda
commit 8ee4b1f603
4 changed files with 22 additions and 7 deletions
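
For context, the usage this change enables (mirroring the test added further down) looks roughly like the sketch below; the message payload here is illustrative.

import litellm

# stream=True normally returns an iterator of chunks; with the new
# complete_response=True flag the chunks are collected and rebuilt into a
# single response via litellm.stream_chunk_builder().
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # illustrative payload
    stream=True,
    complete_response=True,
)
print(response)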


@@ -270,7 +270,7 @@ def completion(
eos_token = kwargs.get("eos_token", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
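
A minimal sketch of why "complete_response" has to be added to litellm_params: any kwarg missing from default_params is treated as a model-specific parameter and forwarded to the provider, so the new flag must be listed to stay internal. The short lists below are illustrative stand-ins for the full ones in the hunk above.

# Illustrative stand-ins for the real lists above
openai_params = ["temperature", "stream", "max_tokens"]
litellm_params = ["metadata", "complete_response"]
default_params = openai_params + litellm_params

kwargs = {"temperature": 0.2, "complete_response": True, "top_k": 40}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
# -> {"top_k": 40}; "complete_response" is consumed by litellm rather than passed to the provider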


@@ -24,6 +24,7 @@ function_schema = {
}
def test_stream_chunk_builder():
    litellm.set_verbose = False
    litellm.api_key = os.environ["OPENAI_API_KEY"]
    response = completion(
        model="gpt-3.5-turbo",
@@ -35,10 +36,11 @@ def test_stream_chunk_builder():
    chunks = []
    for chunk in response:
        print(chunk)
        # print(chunk)
        chunks.append(chunk)
    try:
        print(f"chunks: {chunks}")
        rebuilt_response = stream_chunk_builder(chunks)
        # extract the response from the rebuilt response


@@ -902,6 +902,17 @@ def test_openai_chat_completion_call():
# test_openai_chat_completion_call()
def test_openai_chat_completion_complete_response_call():
    try:
        complete_response = completion(
            model="gpt-3.5-turbo", messages=messages, stream=True, complete_response=True
        )
        print(f"complete response: {complete_response}")
    except:
        print(f"error occurred: {traceback.format_exc()}")
        pass
test_openai_chat_completion_complete_response_call()
def test_openai_text_completion_call():
    try:


@@ -949,6 +949,12 @@ def client(original_function):
end_time = datetime.datetime.now()
if "stream" in kwargs and kwargs["stream"] == True:
    # TODO: Add to cache for streaming
    if "complete_response" in kwargs and kwargs["complete_response"] == True:
        chunks = []
        for idx, chunk in enumerate(result):
            chunks.append(chunk)
        return litellm.stream_chunk_builder(chunks)
    else:
        return result
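
For reference, the kind of reassembly litellm.stream_chunk_builder() performs can be sketched as follows. This is an illustrative reconstruction operating on plain dicts, not litellm's actual implementation, and build_complete_response is a hypothetical helper name.

def build_complete_response(chunks):
    # Illustrative only: join the streamed delta contents into one assistant message.
    content = "".join(
        (chunk["choices"][0]["delta"].get("content") or "") for chunk in chunks
    )
    first, last = chunks[0], chunks[-1]
    return {
        "id": first["id"],
        "object": "chat.completion",
        "model": first["model"],
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": last["choices"][0].get("finish_reason"),
            }
        ],
    }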
@@ -956,10 +962,6 @@ def client(original_function):
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
    litellm.cache.add_cache(result, *args, **kwargs)
# [OPTIONAL] Return LiteLLM call_id
if litellm.use_client == True:
    result['litellm_call_id'] = litellm_call_id
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
logging_obj.success_handler(result, start_time, end_time)
# threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()