Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(main.py): return async completion calls

Commit: 34509d8dda
Parent: 6edc7cc2b3
3 changed files with 65 additions and 38 deletions
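The upshot of this change is that an async caller now gets the completion back from litellm.atext_completion instead of the response being swallowed by the synchronous post-processing. A minimal usage sketch, mirroring the test added in this commit (it assumes litellm is installed and an OpenAI API key is configured in the environment):

import asyncio
import litellm

async def main():
    # awaits the response that text_completion now returns when acompletion=True
    response = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="good morning",
        max_tokens=10,
    )
    print(response)

asyncio.run(main())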
@@ -284,7 +284,7 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
             )
             return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
             raise e
 
     def streaming(self,
@@ -631,24 +631,27 @@ class OpenAITextCompletion(BaseLLM):
                           api_key: str,
                           model: str):
         async with httpx.AsyncClient() as client:
+            try:
                 response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise OpenAIError(status_code=response.status_code, message=response.text)
 
                 ## LOGGING
                 logging_obj.post_call(
                     input=prompt,
                     api_key=api_key,
                     original_response=response,
                     additional_args={
                         "headers": headers,
                         "api_base": api_base,
                     },
                 )
 
                 ## RESPONSE OBJECT
                 return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            except Exception as e:
+                raise e
 
     def streaming(self,
                   logging_obj,
@@ -687,9 +690,12 @@ class OpenAITextCompletion(BaseLLM):
                 method="POST",
                 timeout=litellm.request_timeout
             ) as response:
+                try:
                     if response.status_code != 200:
                         raise OpenAIError(status_code=response.status_code, message=response.text)
 
                     streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
                     async for transformed_chunk in streamwrapper:
                         yield transformed_chunk
+                except Exception as e:
+                    raise e
@@ -2205,7 +2205,8 @@ def text_completion(
     if stream == True or kwargs.get("stream", False) == True:
         response = TextCompletionStreamWrapper(completion_stream=response, model=model)
         return response
+    if kwargs.get("acompletion", False) == True:
+        return response
     transformed_logprobs = None
     # only supported for TGI models
     try:
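The early return above hands the raw response straight back when the call came in through the async path (acompletion=True in kwargs), instead of falling through to the synchronous logprobs post-processing below. A rough sketch of that dispatch pattern, using hypothetical names rather than litellm's actual implementation:

import asyncio

async def _async_backend(prompt: str) -> str:
    # stand-in for the provider's async completion call
    return f"completion for: {prompt}"

def _sync_backend(prompt: str) -> str:
    # stand-in for the provider's blocking completion call
    return f"completion for: {prompt}"

def complete(prompt: str, acompletion: bool = False, **kwargs):
    response = _async_backend(prompt) if acompletion else _sync_backend(prompt)
    if acompletion:
        # async caller: return the coroutine untouched so the wrapper can await it,
        # skipping synchronous post-processing that would choke on a coroutine
        return response
    # ... synchronous post-processing (e.g. logprobs transformation) would go here ...
    return response

async def acomplete(prompt: str, **kwargs):
    # async wrapper: flag the call and await the coroutine returned by complete()
    return await complete(prompt, acompletion=True, **kwargs)

print(asyncio.run(acomplete("good morning")))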
@@ -169,17 +169,37 @@ def test_text_completion_stream():
 
 # test_text_completion_stream()
 
-async def test_text_completion_async_stream():
-    try:
-        response = await atext_completion(
-            model="text-completion-openai/text-davinci-003",
-            prompt="good morning",
-            stream=True,
-            max_tokens=10,
-        )
-        async for chunk in response:
-            print(f"chunk: {chunk}")
-    except Exception as e:
-        pytest.fail(f"GOT exception for HF In streaming{e}")
-
-asyncio.run(test_text_completion_async_stream())
+# async def test_text_completion_async_stream():
+#     try:
+#         response = await atext_completion(
+#             model="text-completion-openai/text-davinci-003",
+#             prompt="good morning",
+#             stream=True,
+#             max_tokens=10,
+#         )
+#         async for chunk in response:
+#             print(f"chunk: {chunk}")
+#     except Exception as e:
+#         pytest.fail(f"GOT exception for HF In streaming{e}")
+
+# asyncio.run(test_text_completion_async_stream())
 
+def test_async_text_completion():
+    litellm.set_verbose = True
+    print('test_async_text_completion')
+    async def test_get_response():
+        try:
+            response = await litellm.atext_completion(
+                model="gpt-3.5-turbo-instruct",
+                prompt="good morning",
+                stream=False,
+                max_tokens=10
+            )
+            print(f"response: {response}")
+        except litellm.Timeout as e:
+            print(e)
+        except Exception as e:
+            print(e)
+
+    asyncio.run(test_get_response())
+
+test_async_text_completion()