mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
refactor(openai.py): working openai chat + text completion for openai v1 sdk
This commit is contained in:
parent
6d815d98fe
commit
d0bd932b3c
6 changed files with 30 additions and 27 deletions
|
@ -245,11 +245,11 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: str,
|
||||
data: dict, headers: dict,
|
||||
model_response: ModelResponse):
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=600) as client:
|
||||
response = await client.post(api_base, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
if response.status != 200:
|
||||
raise OpenAIError(status_code=response.status, message=response.text)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
|
||||
## RESPONSE OBJECT
|
||||
|
|
|
@ -525,9 +525,6 @@ def completion(
|
|||
)
|
||||
raise e
|
||||
|
||||
if optional_params.get("stream", False) and acompletion is False:
|
||||
response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
|
||||
return response
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
|
|
@ -28,14 +28,13 @@ def test_async_response():
|
|||
user_message = "Hello, how are you?"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
try:
|
||||
response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
|
||||
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
|
||||
print(f"response: {response}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
response = asyncio.run(test_get_response())
|
||||
print(response)
|
||||
# test_async_response()
|
||||
asyncio.run(test_get_response())
|
||||
test_async_response()
|
||||
|
||||
def test_get_response_streaming():
|
||||
import asyncio
|
||||
|
@ -43,7 +42,7 @@ def test_get_response_streaming():
|
|||
user_message = "write a short poem in one sentence"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
try:
|
||||
response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
|
||||
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
|
||||
print(type(response))
|
||||
|
||||
import inspect
|
||||
|
|
|
@ -67,7 +67,7 @@ def test_context_window_with_fallbacks(model):
|
|||
|
||||
# for model in litellm.models_by_provider["bedrock"]:
|
||||
# test_context_window(model=model)
|
||||
# test_context_window(model="gpt-3.5-turbo")
|
||||
test_context_window(model="azure/chatgpt-v-2")
|
||||
# test_context_window_with_fallbacks(model="gpt-3.5-turbo")
|
||||
# Test 2: InvalidAuth Errors
|
||||
@pytest.mark.parametrize("model", models)
|
||||
|
|
|
@ -131,7 +131,6 @@ def streaming_format_tests(idx, chunk):
|
|||
if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk
|
||||
validate_last_format(chunk=chunk)
|
||||
finished = True
|
||||
print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
|
||||
if "content" in chunk["choices"][0]["delta"]:
|
||||
extracted_chunk = chunk["choices"][0]["delta"]["content"]
|
||||
print(f"extracted chunk: {extracted_chunk}")
|
||||
|
@ -837,6 +836,7 @@ def test_openai_chat_completion_call():
|
|||
start_time = time.time()
|
||||
for idx, chunk in enumerate(response):
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
print(f"outside chunk: {chunk}")
|
||||
if finished:
|
||||
break
|
||||
complete_response += chunk
|
||||
|
|
|
@ -4549,7 +4549,7 @@ class CustomStreamWrapper:
|
|||
except StopIteration:
|
||||
raise StopIteration
|
||||
except Exception as e:
|
||||
traceback_exception = traceback.print_exc()
|
||||
traceback_exception = traceback.format_exc()
|
||||
e.message = str(e)
|
||||
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
||||
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
|
||||
|
@ -4557,17 +4557,24 @@ class CustomStreamWrapper:
|
|||
|
||||
## needs to handle the empty string case (even starting chunk can be an empty string)
|
||||
def __next__(self):
|
||||
while True: # loop until a non-empty string is found
|
||||
try:
|
||||
# if isinstance(self.completion_stream, str):
|
||||
# chunk = self.completion_stream
|
||||
# else:
|
||||
chunk = next(self.completion_stream)
|
||||
response = self.chunk_creator(chunk=chunk)
|
||||
# if response is not None:
|
||||
return response
|
||||
except Exception as e:
|
||||
raise StopIteration
|
||||
try:
|
||||
while True:
|
||||
if isinstance(self.completion_stream, str):
|
||||
chunk = self.completion_stream
|
||||
else:
|
||||
chunk = next(self.completion_stream)
|
||||
|
||||
if chunk is not None:
|
||||
response = self.chunk_creator(chunk=chunk)
|
||||
if response is not None:
|
||||
return response
|
||||
except StopIteration:
|
||||
raise # Re-raise StopIteration
|
||||
except Exception as e:
|
||||
# Handle other exceptions if needed
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue