Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
refactor(openai.py): working openai chat + text completion for openai v1 sdk
parent 6d815d98fe
commit d0bd932b3c

6 changed files with 30 additions and 27 deletions
@@ -245,11 +245,11 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(timeout=600) as client:
             response = await client.post(api_base, json=data, headers=headers)
             response_json = response.json()

-            if response.status != 200:
-                raise OpenAIError(status_code=response.status, message=response.text)
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text)

             ## RESPONSE OBJECT
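Context for the two fixes above: httpx response objects expose status_code (there is no .status attribute, so the old check would raise AttributeError before it could report the HTTP error), and httpx's default timeout is only 5 seconds, which long chat completions can exceed. A minimal self-contained sketch of the corrected pattern, with a hypothetical endpoint and payload that are not from this repo:

    import asyncio
    import httpx

    async def post_chat(api_base: str, data: dict, headers: dict) -> dict:
        # 600-second timeout mirrors the change above; httpx defaults to 5 seconds
        async with httpx.AsyncClient(timeout=600) as client:
            response = await client.post(api_base, json=data, headers=headers)
            if response.status_code != 200:  # httpx uses .status_code, not .status
                raise RuntimeError(f"{response.status_code}: {response.text}")
            return response.json()

    # hypothetical usage:
    # asyncio.run(post_chat("https://api.example.com/v1/chat/completions",
    #                       data={"model": "gpt-3.5-turbo", "messages": []},
    #                       headers={"Authorization": "Bearer sk-..."}))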
@@ -524,10 +524,7 @@ def completion(
                     additional_args={"headers": headers},
                 )
                 raise e

-            if optional_params.get("stream", False) and acompletion is False:
-                response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
-                return response
             ## LOGGING
             logging.post_call(
                 input=messages,
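Note that the removed branch returned early for synchronous streaming requests, so the post_call logging below it never ran for them; with the branch gone, this path falls through to logging uniformly (the streaming wrap presumably happens inside the provider handler now). A rough sketch of the early-return shape that was removed, with simplified, hypothetical names:

    def finish_response(response, stream: bool, is_async: bool, post_call):
        # hypothetical helper illustrating the removed control flow: a sync
        # streaming request used to return early and skip the post-call hook
        if stream and not is_async:
            return iter(response)  # stand-in for CustomStreamWrapper(response, ...)
        post_call(response)        # reached for everything else on this path
        return response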
@@ -28,14 +28,13 @@ def test_async_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

-    response = asyncio.run(test_get_response())
-    print(response)
-# test_async_response()
+    asyncio.run(test_get_response())
+test_async_response()

 def test_get_response_streaming():
     import asyncio
@@ -43,7 +42,7 @@ def test_get_response_streaming():
         user_message = "write a short poem in one sentence"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))

             import inspect
@@ -67,7 +67,7 @@ def test_context_window_with_fallbacks(model):

 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="gpt-3.5-turbo")
+test_context_window(model="azure/chatgpt-v-2")
 # test_context_window_with_fallbacks(model="gpt-3.5-turbo")
 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)
@@ -131,7 +131,6 @@ def streaming_format_tests(idx, chunk):
     if chunk["choices"][0]["finish_reason"]:  # ensure finish reason is only in last chunk
         validate_last_format(chunk=chunk)
         finished = True
-    print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
     if "content" in chunk["choices"][0]["delta"]:
         extracted_chunk = chunk["choices"][0]["delta"]["content"]
         print(f"extracted chunk: {extracted_chunk}")
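The deleted print indexed chunk["choices"][0]["delta"]["content"] unconditionally, which can raise a KeyError on chunks whose delta carries no content (for example the final chunk that only sets finish_reason); the surviving lines read the content only after checking for the key. A small sketch of that guarded extraction, exercised with hand-built chunks rather than a live stream:

    def extract_content(chunk: dict) -> str:
        # return the streamed text for this chunk, or "" when the delta has no content
        delta = chunk["choices"][0]["delta"]
        return delta["content"] if "content" in delta else ""

    assert extract_content({"choices": [{"delta": {"content": "Hi"}, "finish_reason": None}]}) == "Hi"
    assert extract_content({"choices": [{"delta": {}, "finish_reason": "stop"}]}) == ""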
@@ -837,6 +836,7 @@ def test_openai_chat_completion_call():
         start_time = time.time()
         for idx, chunk in enumerate(response):
             chunk, finished = streaming_format_tests(idx, chunk)
+            print(f"outside chunk: {chunk}")
             if finished:
                 break
             complete_response += chunk
@@ -4549,7 +4549,7 @@ class CustomStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            traceback_exception = traceback.print_exc()
+            traceback_exception = traceback.format_exc()
             e.message = str(e)
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
             threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
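For reference, traceback.print_exc() writes the traceback to stderr and returns None, so the old code handed None to the failure handler; traceback.format_exc() returns the same text as a string. A quick standalone illustration:

    import traceback

    try:
        1 / 0
    except ZeroDivisionError:
        printed = traceback.print_exc()     # prints to stderr, returns None
        formatted = traceback.format_exc()  # returns the traceback as a str
        assert printed is None
        assert "ZeroDivisionError" in formatted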
@@ -4557,18 +4557,25 @@ class CustomStreamWrapper:

     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
-        while True: # loop until a non-empty string is found
-            try:
-                # if isinstance(self.completion_stream, str):
-                #     chunk = self.completion_stream
-                # else:
-                chunk = next(self.completion_stream)
-                response = self.chunk_creator(chunk=chunk)
-                # if response is not None:
-                return response
-            except Exception as e:
-                raise StopIteration
+        try:
+            while True:
+                if isinstance(self.completion_stream, str):
+                    chunk = self.completion_stream
+                else:
+                    chunk = next(self.completion_stream)
+
+                if chunk is not None:
+                    response = self.chunk_creator(chunk=chunk)
+                    if response is not None:
+                        return response
+        except StopIteration:
+            raise  # Re-raise StopIteration
+        except Exception as e:
+            # Handle other exceptions if needed
+            pass


     async def __anext__(self):
         try:
             if (self.custom_llm_provider == "openai"
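The rewritten __next__ moves the try outside the loop, handles the case where completion_stream is already a plain string, keeps looping until chunk_creator yields a non-None response, and re-raises StopIteration so iteration terminates cleanly; the old version caught every Exception (which includes StopIteration) and converted it into end-of-stream, masking real errors. A stripped-down sketch of the same iterator shape, using a hypothetical SkipEmptyStream class and a stand-in chunk_creator rather than the repo's real ones:

    class SkipEmptyStream:
        # minimal sketch of the rewritten __next__ shape: pull raw chunks,
        # transform them, and only return transformed chunks that are not None
        def __init__(self, completion_stream, chunk_creator):
            self.completion_stream = completion_stream
            self.chunk_creator = chunk_creator

        def __iter__(self):
            return self

        def __next__(self):
            try:
                while True:
                    chunk = next(self.completion_stream)
                    response = self.chunk_creator(chunk)
                    if response is not None:  # skip chunks with nothing to emit
                        return response
            except StopIteration:
                raise  # end of the underlying stream

    # usage: empty chunks are skipped, iteration ends when the source is exhausted
    stream = SkipEmptyStream(iter(["", "hello", "", "world"]),
                             chunk_creator=lambda c: c.upper() if c else None)
    print(list(stream))  # ['HELLO', 'WORLD']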