test async streaming

This commit is contained in:
Krrish Dholakia 2023-09-04 15:42:22 -07:00
parent fe4caf5c3d
commit e2c143dfbc
5 changed files with 61 additions and 1 deletions

View file

@ -31,3 +31,41 @@ response = asyncio.run(test_get_response())
print(response) print(response)
``` ```
## Async Streaming
We've implemented an `__anext__()` function in the streaming object returned. This enables async iteration over the streaming object, so you can consume chunks with `async for`.
### Usage
```
import asyncio
import time
import traceback

from litellm import acompletion


def logger_fn(model_call_object: dict):
    """Print the object handed to the `logger_fn` callback."""
    print(f"LOGGER FUNCTION: {model_call_object}")


user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]


# test on ai21 completion call
async def ai21_async_completion_call():
    """Stream a j2-ultra completion asynchronously and print each chunk.

    Errors (including an empty overall response) are printed with their
    traceback rather than propagated, so the example always runs to the end.
    """
    try:
        # Use the async entry point and await it; the sync `completion`
        # would block the event loop while the request is made.
        response = await acompletion(
            model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
        )
        print(f"response: {response}")
        complete_response = ""
        start_time = time.time()
        # Change for loop to async for loop
        async for chunk in response:
            chunk_time = time.time()
            print(f"time since initial request: {chunk_time - start_time:.5f}")
            print(chunk["choices"][0]["delta"])
            complete_response += chunk["choices"][0]["delta"]["content"]
        if complete_response == "":
            raise Exception("Empty response received")
    except Exception:
        # Best-effort example: report the failure instead of crashing.
        print(f"error occurred: {traceback.format_exc()}")


asyncio.run(ai21_async_completion_call())
```

View file

@ -9,7 +9,7 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm import litellm
from litellm import completion from litellm import completion, acompletion
litellm.logging = False litellm.logging = False
litellm.set_verbose = False litellm.set_verbose = False
@ -217,3 +217,25 @@ def test_together_ai_completion_call_starcoder():
# except: # except:
# print(f"error occurred: {traceback.format_exc()}") # print(f"error occurred: {traceback.format_exc()}")
# pass # pass
#### Test Async streaming
# # test on ai21 completion call
async def ai21_async_completion_call():
    """Stream a j2-ultra completion via `acompletion` and print each chunk.

    Accumulates the streamed `delta` content and treats an empty overall
    response as an error. Any exception is printed with its traceback and
    swallowed, so a failing provider call does not abort the caller.
    """
    try:
        # Await the async entry point instead of calling the blocking
        # `completion` from inside a coroutine.
        response = await acompletion(
            model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
        )
        print(f"response: {response}")
        complete_response = ""
        start_time = time.time()
        # Change for loop to async for loop
        async for chunk in response:
            chunk_time = time.time()
            print(f"time since initial request: {chunk_time - start_time:.5f}")
            print(chunk["choices"][0]["delta"])
            complete_response += chunk["choices"][0]["delta"]["content"]
        if complete_response == "":
            raise Exception("Empty response received")
    except Exception:
        # Deliberately best-effort: report the failure rather than re-raise,
        # but let KeyboardInterrupt/SystemExit/CancelledError propagate.
        print(f"error occurred: {traceback.format_exc()}")