refactor(openai.py): working openai chat + text completion for openai v1 sdk

Krrish Dholakia 2023-11-11 16:25:02 -08:00
parent 6d815d98fe
commit d0bd932b3c
6 changed files with 30 additions and 27 deletions

@@ -245,11 +245,11 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=600) as client:
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()
-            if response.status != 200:
-                raise OpenAIError(status_code=response.status, message=response.text)
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text)
             ## RESPONSE OBJECT
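Note on the change above: httpx.Response exposes status_code (not status), and AsyncClient accepts a timeout in seconds, which is what the new right-hand side relies on. A minimal standalone sketch of that pattern, not taken from the repo (post_json and the example URL are illustrative names):

import asyncio
import httpx

async def post_json(api_base: str, data: dict, headers: dict) -> dict:
    # a 600-second timeout covers connect/read/write/pool, matching the diff above
    async with httpx.AsyncClient(timeout=600) as client:
        response = await client.post(api_base, json=data, headers=headers)
        if response.status_code != 200:
            # surface the body so callers see the provider's error message
            raise RuntimeError(f"{response.status_code}: {response.text}")
        return response.json()

# usage (illustrative endpoint):
# asyncio.run(post_json("https://example.com/v1/chat/completions", {"model": "gpt-3.5-turbo"}, {}))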

@@ -525,9 +525,6 @@ def completion(
             )
             raise e
-        if optional_params.get("stream", False) and acompletion is False:
-            response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
-            return response
         ## LOGGING
         logging.post_call(
             input=messages,
@@ -28,14 +28,13 @@ def test_async_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
-    response = asyncio.run(test_get_response())
-    print(response)
-# test_async_response()
+    asyncio.run(test_get_response())
+test_async_response()
 
 def test_get_response_streaming():
     import asyncio
@@ -43,7 +42,7 @@ def test_get_response_streaming():
         user_message = "write a short poem in one sentence"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
             import inspect
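The updated tests above drive the coroutine through asyncio.run and then invoke the test function at module level. For reference, a hedged standalone sketch of that pattern (the ask helper is an illustrative name; it assumes an OpenAI API key is configured for litellm):

import asyncio
import litellm

async def ask(model: str = "gpt-3.5-turbo") -> None:
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    # litellm.acompletion is the async counterpart of litellm.completion
    response = await litellm.acompletion(model=model, messages=messages)
    print(f"response: {response}")

if __name__ == "__main__":
    asyncio.run(ask())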

@@ -67,7 +67,7 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="gpt-3.5-turbo")
+test_context_window(model="azure/chatgpt-v-2")
 # test_context_window_with_fallbacks(model="gpt-3.5-turbo")
 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)

@@ -131,7 +131,6 @@ def streaming_format_tests(idx, chunk):
     if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk
         validate_last_format(chunk=chunk)
         finished = True
-    print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
     if "content" in chunk["choices"][0]["delta"]:
         extracted_chunk = chunk["choices"][0]["delta"]["content"]
         print(f"extracted chunk: {extracted_chunk}")
@@ -837,6 +836,7 @@ def test_openai_chat_completion_call():
         start_time = time.time()
         for idx, chunk in enumerate(response):
             chunk, finished = streaming_format_tests(idx, chunk)
+            print(f"outside chunk: {chunk}")
             if finished:
                 break
             complete_response += chunk

@@ -4549,7 +4549,7 @@ class CustomStreamWrapper:
             except StopIteration:
                 raise StopIteration
             except Exception as e:
-                traceback_exception = traceback.print_exc()
+                traceback_exception = traceback.format_exc()
                 e.message = str(e)
                 # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
                 threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
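The switch from traceback.print_exc() to traceback.format_exc() matters because print_exc() writes the traceback to stderr and returns None, while format_exc() returns it as a string, so traceback_exception actually carries data into the failure-handler thread. A quick illustration, not repo code:

import traceback

def capture_traceback() -> str:
    try:
        1 / 0
    except ZeroDivisionError:
        printed = traceback.print_exc()     # prints to stderr, returns None
        formatted = traceback.format_exc()  # returns the same traceback as a str
        assert printed is None
        return formatted

print(capture_traceback().splitlines()[-1])  # ZeroDivisionError: division by zero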
@@ -4557,17 +4557,24 @@ class CustomStreamWrapper:
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
-        while True: # loop until a non-empty string is found
-            try:
-                # if isinstance(self.completion_stream, str):
-                #     chunk = self.completion_stream
-                # else:
-                chunk = next(self.completion_stream)
-                response = self.chunk_creator(chunk=chunk)
-                # if response is not None:
-                return response
-            except Exception as e:
-                raise StopIteration
+        try:
+            while True:
+                if isinstance(self.completion_stream, str):
+                    chunk = self.completion_stream
+                else:
+                    chunk = next(self.completion_stream)
+
+                if chunk is not None:
+                    response = self.chunk_creator(chunk=chunk)
+                    if response is not None:
+                        return response
+        except StopIteration:
+            raise  # Re-raise StopIteration
+        except Exception as e:
+            # Handle other exceptions if needed
+            pass
 
     async def __anext__(self):
         try:
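The rewritten __next__ above keeps pulling chunks until chunk_creator produces a non-None response, handles the case where the stream is just a string, and re-raises StopIteration so iteration ends cleanly. A generic sketch of that skip-empty-chunks iterator pattern, using assumed names rather than litellm's API:

from typing import Iterator

class SkipEmptyStream:
    # Wraps a chunk stream and yields only chunks that carry content.
    def __init__(self, completion_stream: Iterator[str]):
        self.completion_stream = iter(completion_stream)

    def __iter__(self) -> "SkipEmptyStream":
        return self

    def __next__(self) -> str:
        while True:
            chunk = next(self.completion_stream)  # StopIteration propagates and ends for-loops
            if chunk:  # skip empty-string chunks
                return chunk

print(list(SkipEmptyStream(["", "Hello", "", " world"])))  # ['Hello', ' world']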