fix(main.py): fix async text completion streaming + add new tests

Krrish Dholakia 2023-12-29 11:33:28 +05:30
parent 2b8e2bd937
commit 6f2734100f
2 changed files with 41 additions and 14 deletions


@@ -2472,22 +2472,22 @@ async def atext_completion(*args, **kwargs):
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
-            if kwargs.get("stream", False):
-                response = text_completion(*args, **kwargs)
-            else:
-                # Await normally
-                response = await loop.run_in_executor(None, func_with_context)
-                if asyncio.iscoroutine(response):
-                    response = await response
+            # Await normally
+            response = await loop.run_in_executor(None, func_with_context)
+            if asyncio.iscoroutine(response):
+                response = await response
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False):  # return an async generator
-            return _async_streaming(
-                response=response,
-                model=model,
-                custom_llm_provider=custom_llm_provider,
-                args=args,
+        if kwargs.get("stream", False) == True:  # return an async generator
+            return TextCompletionStreamWrapper(
+                completion_stream=_async_streaming(
+                    response=response,
+                    model=model,
+                    custom_llm_provider=custom_llm_provider,
+                    args=args,
+                ),
+                model=model,
             )
         else:
             return response
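The net effect of this hunk: with stream=True, atext_completion now returns a TextCompletionStreamWrapper around the _async_streaming generator instead of the bare response, so callers can consume it with async for. A minimal caller-side sketch, not part of the commit; the model/prompt values simply mirror the new test added below, and it assumes an OpenAI API key is configured:

import asyncio

import litellm


async def stream_demo():
    # After this fix, atext_completion(stream=True) resolves to an async-iterable
    # wrapper that yields text-completion style chunks.
    response = await litellm.atext_completion(
        model="gpt-3.5-turbo",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk["choices"][0])


# asyncio.run(stream_demo())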
@@ -2691,11 +2691,11 @@ def text_completion(
             **kwargs,
             **optional_params,
         )
+        if kwargs.get("acompletion", False) == True:
+            return response
         if stream == True or kwargs.get("stream", False) == True:
             response = TextCompletionStreamWrapper(completion_stream=response, model=model)
             return response
-        if kwargs.get("acompletion", False) == True:
-            return response
     transformed_logprobs = None
     # only supported for TGI models
     try:
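Moving the acompletion early-return above the stream check is what lets the async path receive the unwrapped response: text_completion is the function atext_completion runs in the executor, and the async caller applies its own stream wrapper afterwards (see the hunk above), so wrapping here as well would leave the async caller re-wrapping an already-wrapped stream. Plain synchronous streaming still goes through the TextCompletionStreamWrapper branch; a usage sketch of that path, illustrative values only and again assuming an OpenAI API key is configured:

import litellm


def sync_stream_demo():
    # stream=True without acompletion still returns TextCompletionStreamWrapper,
    # which is consumed with a regular for-loop.
    response = litellm.text_completion(
        model="gpt-3.5-turbo",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    for chunk in response:
        print(chunk["choices"][0])


# sync_stream_demo()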


@@ -215,3 +215,30 @@ def test_get_response_non_openai_streaming():
 # test_get_response_non_openai_streaming()
+
+
+async def test_get_response():
+    try:
+        response = await litellm.atext_completion(
+            model="gpt-3.5-turbo",
+            prompt="good morning",
+            stream=True,
+            max_tokens=10,
+        )
+        print(f"response: {response}")
+        num_finish_reason = 0
+        async for chunk in response:
+            print(chunk)
+            if chunk["choices"][0].get("finish_reason") is not None:
+                num_finish_reason += 1
+                print("finish_reason", chunk["choices"][0].get("finish_reason"))
+        assert (
+            num_finish_reason == 1
+        ), f"expected only one finish reason. Got {num_finish_reason}"
+    except Exception as e:
+        pytest.fail(f"GOT exception for gpt-3.5 instruct In streaming{e}")
+
+
+# asyncio.run(test_get_response())
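An async def test is only awaited when an asyncio-aware pytest plugin (e.g. pytest-asyncio) picks it up; the commented-out asyncio.run line hints at the manual alternative. A standalone sketch of that manual run, assuming the test module's existing imports and an OPENAI_API_KEY in the environment:

import asyncio

# Hypothetical manual runner, roughly equivalent to uncommenting the line above.
if __name__ == "__main__":
    asyncio.run(test_get_response())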