Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 19:54:13 +00:00)
feat(utils.py): enable returning complete response when stream=true
This commit is contained in:
parent 3dae4e9cda
commit 8ee4b1f603

4 changed files with 22 additions and 7 deletions
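The feature itself is small: when a caller passes `complete_response=True` alongside `stream=True`, litellm drains the stream internally and returns one rebuilt response instead of a chunk iterator. A minimal usage sketch based on the new test added below (the prompt is a placeholder; `gpt-3.5-turbo` is the model the test uses):

```python
import litellm

# Placeholder prompt; any chat-style message list works.
messages = [{"role": "user", "content": "Write one sentence about streaming."}]

# stream=True normally returns an iterator of delta chunks.
# complete_response=True (added in this commit) tells litellm to collect
# those chunks itself and hand back a single merged response object.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=messages,
    stream=True,
    complete_response=True,
)
print(response)
```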
@@ -270,7 +270,7 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
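The only change in this hunk is appending "complete_response" to `litellm_params`. That matters because of the filter on the last lines: any kwarg not listed in `default_params` is treated as a provider-specific parameter and forwarded to the model, so without this entry the new flag would leak to the provider. A trimmed-down, illustrative sketch of that partitioning (the kwargs below are made up):

```python
# Illustrative only: shortened versions of the lists in the hunk above.
openai_params = ["temperature", "top_p", "stream", "max_tokens"]
litellm_params = ["metadata", "mock_response", "complete_response"]  # now includes the new flag
default_params = openai_params + litellm_params

# Hypothetical call kwargs: top_k is not on either list.
kwargs = {"temperature": 0.2, "complete_response": True, "top_k": 40}

# complete_response is recognized as a litellm-level flag and filtered out;
# the unknown top_k is passed straight through to the model/provider.
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 40}
```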
@@ -24,6 +24,7 @@ function_schema = {
 }

 def test_stream_chunk_builder():
+    litellm.set_verbose = False
     litellm.api_key = os.environ["OPENAI_API_KEY"]
     response = completion(
         model="gpt-3.5-turbo",
@@ -35,10 +36,11 @@ def test_stream_chunk_builder():
     chunks = []

     for chunk in response:
-        print(chunk)
+        # print(chunk)
         chunks.append(chunk)

     try:
+        print(f"chunks: {chunks}")
         rebuilt_response = stream_chunk_builder(chunks)

         # exract the response from the rebuilt response
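The updated test exercises the rebuild path by hand: it streams a completion, collects every chunk, and passes the list to `stream_chunk_builder`. A sketch of the same pattern outside the test harness (the prompt is a placeholder; the import style mirrors the test file):

```python
import os
import litellm
from litellm import completion, stream_chunk_builder

litellm.api_key = os.environ["OPENAI_API_KEY"]
messages = [{"role": "user", "content": "Say hello in three words."}]

# Stream the completion and keep every delta chunk, in order.
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
chunks = [chunk for chunk in response]

# Merge the deltas back into a single chat-completion-style response.
rebuilt_response = stream_chunk_builder(chunks)
print(rebuilt_response)
```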
@@ -902,6 +902,17 @@ def test_openai_chat_completion_call():

 # test_openai_chat_completion_call()

+def test_openai_chat_completion_complete_response_call():
+    try:
+        complete_response = completion(
+            model="gpt-3.5-turbo", messages=messages, stream=True, complete_response=True
+        )
+        print(f"complete response: {complete_response}")
+    except:
+        print(f"error occurred: {traceback.format_exc()}")
+        pass
+
+test_openai_chat_completion_complete_response_call()

 def test_openai_text_completion_call():
     try:
@@ -949,6 +949,12 @@ def client(original_function):
             end_time = datetime.datetime.now()
             if "stream" in kwargs and kwargs["stream"] == True:
                 # TODO: Add to cache for streaming
-                return result
+                if "complete_response" in kwargs and kwargs["complete_response"] == True:
+                    chunks = []
+                    for idx, chunk in enumerate(result):
+                        chunks.append(chunk)
+                    return litellm.stream_chunk_builder(chunks)
+                else:
+                    return result
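This is where the flag takes effect: inside the `client()` wrapper, when both `stream` and `complete_response` are truthy, the generator returned by the wrapped call is drained on the spot and the merged result is returned, so the caller never sees individual chunks. A self-contained sketch of that interception pattern (the decorator, merge helper, and fake completion below are toy stand-ins, not the actual litellm implementation):

```python
from typing import Any, Callable, Iterable


def merge_chunks(chunks: list) -> dict:
    # Toy stand-in for litellm.stream_chunk_builder: join the text deltas.
    text = "".join(chunk.get("delta", "") for chunk in chunks)
    return {"choices": [{"message": {"content": text}}]}


def client(original_function: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        if kwargs.get("stream") == True:  # mirrors the `== True` style in the diff
            if kwargs.get("complete_response") == True:
                # Drain the stream here and return one rebuilt response.
                chunks = []
                for chunk in result:
                    chunks.append(chunk)
                return merge_chunks(chunks)
            else:
                return result
        return result
    return wrapper


@client
def fake_completion(**kwargs) -> Iterable[dict]:
    # Toy generator standing in for a streaming completion call.
    for piece in ("Hello", ", ", "world"):
        yield {"delta": piece}


print(fake_completion(stream=True, complete_response=True))
# {'choices': [{'message': {'content': 'Hello, world'}}]}
```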
@@ -956,10 +962,6 @@ def client(original_function):
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)

-            # [OPTIONAL] Return LiteLLM call_id
-            if litellm.use_client == True:
-                result['litellm_call_id'] = litellm_call_id
-
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             logging_obj.success_handler(result, start_time, end_time)
             # threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()