fix(utils.py): await async function in client wrapper

Krrish Dholakia 2023-11-14 22:07:28 -08:00
parent efe81032f4
commit e07bf0a8de
2 changed files with 3 additions and 4 deletions


@@ -42,6 +42,7 @@ def test_get_response_streaming():
         user_message = "write a short poem in one sentence"
         messages = [{"content": user_message, "role": "user"}]
         try:
+            litellm.set_verbose = True
             response = await acompletion(model="command-nightly", messages=messages, stream=True)
             print(type(response))
@@ -65,7 +66,7 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())
-# test_get_response_streaming()
+test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio
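For context, a runnable sketch of what this test exercises. The stream-consumption loop at the end is an assumption (the lines shown above only print the response type), and running it needs the same provider credentials as the original test:

import asyncio
import litellm
from litellm import acompletion

async def test_async_call():
    user_message = "write a short poem in one sentence"
    messages = [{"content": user_message, "role": "user"}]
    litellm.set_verbose = True  # surface per-request debug logs while iterating on the fix
    response = await acompletion(model="command-nightly", messages=messages, stream=True)
    print(type(response))
    # Consume the stream; assumes the returned wrapper supports async iteration.
    async for chunk in response:
        print(chunk)

asyncio.run(test_async_call())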


@@ -1238,7 +1238,7 @@ def client(original_function):
             else:
                 return cached_result
         # MODEL CALL
-        result = original_function(*args, **kwargs)
+        result = await original_function(*args, **kwargs)
         end_time = datetime.datetime.now()
         if "stream" in kwargs and kwargs["stream"] == True:
             if "complete_response" in kwargs and kwargs["complete_response"] == True:
@@ -1248,7 +1248,6 @@ def client(original_function):
                 return litellm.stream_chunk_builder(chunks)
             else:
                 return result
-        result = await result
         # [OPTIONAL] ADD TO CACHE
         if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
             litellm.cache.add_cache(result, *args, **kwargs)
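The change above is easier to see against a stripped-down decorator. This is a minimal sketch of the async path only: the caching and logging hooks of the real wrapper are omitted, and fake_completion is a hypothetical stand-in, not a litellm function.

import asyncio
import datetime
import functools

def client(original_function):
    # Minimal sketch of the async path of a client()-style decorator.
    @functools.wraps(original_function)
    async def wrapper_async(*args, **kwargs):
        start_time = datetime.datetime.now()
        # The fix above: await the coroutine function directly. Without the
        # await, `result` is an un-run coroutine object, and the streaming
        # branch below would hand that back instead of the actual response.
        result = await original_function(*args, **kwargs)
        end_time = datetime.datetime.now()
        if kwargs.get("stream") == True:
            return result  # streaming responses are returned as-is here
        print(f"call took {(end_time - start_time).total_seconds():.3f}s")
        return result
    return wrapper_async

@client
async def fake_completion(prompt, stream=False):
    await asyncio.sleep(0)  # stand-in for a real network call
    return {"choices": [{"message": {"content": f"echo: {prompt}"}}]}

print(asyncio.run(fake_completion("hello")))

Because the await now happens at the call site, the later "result = await result" line becomes dead code after the streaming early-returns and is removed in the hunk above.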
@@ -4459,7 +4458,6 @@ class CustomStreamWrapper:
             traceback.print_exc()
             raise e
-
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             str_line = chunk