refactor(openai.py): working openai chat + text completion for openai v1 sdk

Krrish Dholakia 2023-11-11 16:25:02 -08:00
parent 6d815d98fe
commit d0bd932b3c
6 changed files with 30 additions and 27 deletions


@@ -245,11 +245,11 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=600) as client:
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()
-            if response.status != 200:
-                raise OpenAIError(status_code=response.status, message=response.text)
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text)
             ## RESPONSE OBJECT
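
Note: httpx's Response exposes the HTTP status as status_code; .status is aiohttp's attribute name and does not exist on an httpx response, so the old check raised AttributeError on every call. The explicit timeout=600 also overrides httpx's 5-second default, which long completions easily exceed. A minimal standalone sketch of the fixed call path (the endpoint, payload, and error type below are placeholders, not litellm's actual code):

    import asyncio
    import httpx

    async def post_completion(api_base: str, data: dict, headers: dict) -> dict:
        # httpx defaults to a 5s timeout; long completions need more headroom.
        async with httpx.AsyncClient(timeout=600) as client:
            response = await client.post(api_base, json=data, headers=headers)
            if response.status_code != 200:  # httpx uses .status_code, not .status
                raise RuntimeError(f"{response.status_code}: {response.text}")
            return response.json()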


@@ -525,9 +525,6 @@ def completion(
             )
             raise e
-        if optional_params.get("stream", False) and acompletion is False:
-            response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
-            return response
         ## LOGGING
         logging.post_call(
             input=messages,
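
Note: the deleted block short-circuited completion() for synchronous streaming, handing the raw provider stream to CustomStreamWrapper and returning before the post-call logging below; after this refactor the wrapping presumably happens inside the provider-specific handlers instead. For reference, a minimal sketch of the wrapper pattern involved (illustrative names, not litellm's actual implementation):

    # Illustrative stream-wrapper pattern: wrap a raw chunk iterator and
    # normalize each chunk on the way out; normalize_chunk is hypothetical.
    class StreamWrapper:
        def __init__(self, completion_stream, normalize_chunk):
            self.completion_stream = completion_stream
            self.normalize_chunk = normalize_chunk

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.completion_stream)  # StopIteration ends the loop
            return self.normalize_chunk(chunk)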


@@ -28,14 +28,13 @@ def test_async_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
-    response = asyncio.run(test_get_response())
-    print(response)
-# test_async_response()
+    asyncio.run(test_get_response())
+test_async_response()

 def test_get_response_streaming():
     import asyncio

@@ -43,7 +42,7 @@ def test_get_response_streaming():
         user_message = "write a short poem in one sentence"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
             import inspect
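
Note: both tests switch from gpt-3.5-turbo-instruct (served by the text-completions endpoint) to gpt-3.5-turbo (a chat model), matching the chat path this commit refactors, and the previously commented-out module-level test_async_response() call is now executed. The asyncio.run pattern as a standalone sketch (assuming litellm's acompletion, the async entry point used above):

    import asyncio
    from litellm import acompletion

    def test_async_response():
        async def get_response():
            messages = [{"content": "Hello, how are you?", "role": "user"}]
            return await acompletion(model="gpt-3.5-turbo", messages=messages)

        # asyncio.run drives the coroutine to completion inside a sync test.
        print(asyncio.run(get_response()))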


@@ -67,7 +67,7 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="gpt-3.5-turbo")
+test_context_window(model="azure/chatgpt-v-2")
 # test_context_window_with_fallbacks(model="gpt-3.5-turbo")

 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)


@@ -131,7 +131,6 @@ def streaming_format_tests(idx, chunk):
     if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk
         validate_last_format(chunk=chunk)
         finished = True
-    print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
     if "content" in chunk["choices"][0]["delta"]:
         extracted_chunk = chunk["choices"][0]["delta"]["content"]
         print(f"extracted chunk: {extracted_chunk}")

@@ -837,6 +836,7 @@ def test_openai_chat_completion_call():
         start_time = time.time()
         for idx, chunk in enumerate(response):
             chunk, finished = streaming_format_tests(idx, chunk)
+            print(f"outside chunk: {chunk}")
             if finished:
                 break
             complete_response += chunk
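
Note: these tests encode OpenAI's streaming chunk shape: each chunk carries choices[0].delta, content arrives incrementally in delta["content"], and finish_reason is set only on the final chunk. A minimal consumer of such a stream (response is assumed to be an iterable of OpenAI-style chunk dicts, as exercised above):

    # Accumulate a streamed chat completion from OpenAI-style chunks.
    complete_response = ""
    for chunk in response:
        choice = chunk["choices"][0]
        if choice.get("finish_reason"):  # only the last chunk sets this
            break
        delta = choice["delta"]
        if "content" in delta:
            complete_response += delta["content"]
    print(complete_response)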


@@ -4549,7 +4549,7 @@ class CustomStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            traceback_exception = traceback.print_exc()
+            traceback_exception = traceback.format_exc()
             e.message = str(e)
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
             threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
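
Note: traceback.print_exc() writes the traceback to stderr and returns None, so the old code passed None to the failure handler; traceback.format_exc() returns the traceback as a string. A quick demonstration:

    import traceback

    try:
        1 / 0
    except ZeroDivisionError:
        tb_none = traceback.print_exc()   # prints to stderr, returns None
        tb_text = traceback.format_exc()  # returns the traceback as a str
        assert tb_none is None and tb_text.startswith("Traceback")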
@@ -4557,17 +4557,24 @@ class CustomStreamWrapper:
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
-        while True: # loop until a non-empty string is found
-            try:
-                # if isinstance(self.completion_stream, str):
-                #     chunk = self.completion_stream
-                # else:
-                chunk = next(self.completion_stream)
-                response = self.chunk_creator(chunk=chunk)
-                # if response is not None:
-                return response
-            except Exception as e:
-                raise StopIteration
+        try:
+            while True:
+                if isinstance(self.completion_stream, str):
+                    chunk = self.completion_stream
+                else:
+                    chunk = next(self.completion_stream)
+                if chunk is not None:
+                    response = self.chunk_creator(chunk=chunk)
+                    if response is not None:
+                        return response
+        except StopIteration:
+            raise  # Re-raise StopIteration
+        except Exception as e:
+            # Handle other exceptions if needed
+            pass

     async def __anext__(self):
         try:
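
Note: the rewritten __next__ re-raises StopIteration so iteration ends through the normal iterator protocol, and it only returns once chunk_creator yields a non-None response, which covers streams that open with empty chunks. The old version converted every exception into StopIteration, which silently truncates a stream on a real error; a hypothetical illustration:

    # Masking all errors as StopIteration (the old behavior) makes a provider
    # failure look like a normal end of stream.
    def chunks():
        yield "a"
        raise ValueError("provider error")

    class OldWrapper:
        def __init__(self, stream):
            self.stream = stream

        def __iter__(self):
            return self

        def __next__(self):
            try:
                return next(self.stream)
            except Exception:
                raise StopIteration  # the ValueError is silently dropped

    print(list(OldWrapper(chunks())))  # ['a'] -- no error surfaces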