fix(main.py): return async completion calls

This commit is contained in:
Krrish Dholakia 2023-12-18 17:41:41 -08:00
parent 6edc7cc2b3
commit 34509d8dda
3 changed files with 65 additions and 38 deletions

View file

@ -284,7 +284,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e: except Exception as e:
raise e raise e
def streaming(self, def streaming(self,
@ -631,24 +631,27 @@ class OpenAITextCompletion(BaseLLM):
api_key: str, api_key: str,
model: str): model: str):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) try:
response_json = response.json() response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
if response.status_code != 200: response_json = response.json()
raise OpenAIError(status_code=response.status_code, message=response.text) if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call( ## LOGGING
input=prompt, logging_obj.post_call(
api_key=api_key, input=prompt,
original_response=response, api_key=api_key,
additional_args={ original_response=response,
"headers": headers, additional_args={
"api_base": api_base, "headers": headers,
}, "api_base": api_base,
) },
)
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except Exception as e:
raise e
def streaming(self, def streaming(self,
logging_obj, logging_obj,
@ -687,9 +690,12 @@ class OpenAITextCompletion(BaseLLM):
method="POST", method="POST",
timeout=litellm.request_timeout timeout=litellm.request_timeout
) as response: ) as response:
if response.status_code != 200: try:
raise OpenAIError(status_code=response.status_code, message=response.text) if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper: streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
yield transformed_chunk async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e

View file

@ -2205,7 +2205,8 @@ def text_completion(
if stream == True or kwargs.get("stream", False) == True: if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model) response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response return response
if kwargs.get("acompletion", False) == True:
return response
transformed_logprobs = None transformed_logprobs = None
# only supported for TGI models # only supported for TGI models
try: try:

View file

@ -169,17 +169,37 @@ def test_text_completion_stream():
# test_text_completion_stream() # test_text_completion_stream()
async def test_text_completion_async_stream(): # async def test_text_completion_async_stream():
try: # try:
response = await atext_completion( # response = await atext_completion(
model="text-completion-openai/text-davinci-003", # model="text-completion-openai/text-davinci-003",
prompt="good morning", # prompt="good morning",
stream=True, # stream=True,
max_tokens=10, # max_tokens=10,
) # )
async for chunk in response: # async for chunk in response:
print(f"chunk: {chunk}") # print(f"chunk: {chunk}")
except Exception as e: # except Exception as e:
pytest.fail(f"GOT exception for HF In streaming{e}") # pytest.fail(f"GOT exception for HF In streaming{e}")
asyncio.run(test_text_completion_async_stream()) # asyncio.run(test_text_completion_async_stream())
def test_async_text_completion():
litellm.set_verbose = True
print('test_async_text_completion')
async def test_get_response():
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo-instruct",
prompt="good morning",
stream=False,
max_tokens=10
)
print(f"response: {response}")
except litellm.Timeout as e:
print(e)
except Exception as e:
print(e)
asyncio.run(test_get_response())
test_async_text_completion()