fix(main.py): cover openai /v1/completions endpoint

Krrish Dholakia 2024-08-24 13:25:17 -07:00
parent de2373d52b
commit 87549a2391
4 changed files with 67 additions and 26 deletions

View file

@@ -1268,6 +1268,8 @@ class OpenAIChatCompletion(BaseLLM):
         except (
             Exception
         ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+            if isinstance(e, OpenAIError):
+                raise e
             if response is not None and hasattr(response, "text"):
                 raise OpenAIError(
                     status_code=500,
@@ -1975,7 +1977,7 @@ class OpenAITextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if optional_params.get("stream", False):
                 return self.async_streaming(
                     logging_obj=logging_obj,
@@ -2019,7 +2021,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client
-        response = openai_client.completions.create(**data)  # type: ignore
+        response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
         response_json = response.model_dump()
@@ -2067,7 +2069,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_aclient = client
-        response = await openai_aclient.completions.create(**data)
+        response = await openai_aclient.completions.with_raw_response.create(**data)
         response_json = response.model_dump()
         ## LOGGING
         logging_obj.post_call(
@@ -2100,6 +2102,7 @@ class OpenAITextCompletion(BaseLLM):
         client=None,
         organization=None,
     ):
         if client is None:
             openai_client = OpenAI(
                 api_key=api_key,
@@ -2111,7 +2114,15 @@ class OpenAITextCompletion(BaseLLM):
             )
         else:
             openai_client = client
-        response = openai_client.completions.create(**data)
+        try:
+            response = openai_client.completions.with_raw_response.create(**data)
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -2149,7 +2160,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client
-        response = await openai_client.completions.create(**data)
+        response = await openai_client.completions.with_raw_response.create(**data)
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
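
The handler changes above switch the /v1/completions calls to the OpenAI SDK's with_raw_response namespace and normalize SDK errors into OpenAIError with whatever status code and headers the exception happens to carry. A minimal sketch of that pattern follows; the OpenAIError class here is a stand-in (litellm's real constructor may differ) and create_text_completion is a hypothetical helper name.

from typing import Optional

from openai import OpenAI


class OpenAIError(Exception):
    # Stand-in for litellm's OpenAIError; the real constructor may differ.
    def __init__(self, status_code: int, message: str, headers: Optional[dict] = None):
        self.status_code = status_code
        self.headers = headers
        super().__init__(message)


def create_text_completion(client: OpenAI, **data):
    # Hypothetical helper showing the raw-response + error-normalization pattern.
    try:
        # with_raw_response keeps the HTTP headers (e.g. retry-after) accessible.
        raw = client.completions.with_raw_response.create(**data)
    except Exception as e:
        # Provider exceptions may or may not carry these attributes,
        # so fall back to defaults instead of risking an AttributeError.
        raise OpenAIError(
            status_code=getattr(e, "status_code", 500),
            message=str(e),
            headers=getattr(e, "headers", None),
        )
    return raw.parse(), dict(raw.headers)  # parsed body plus response headers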

View file

@@ -445,7 +445,12 @@ async def _async_streaming(response, model, custom_llm_provider, args):
             print_verbose(f"line in async streaming: {line}")
             yield line
     except Exception as e:
-        raise e
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+        )


 def mock_completion(
@@ -3736,7 +3741,7 @@ async def atext_completion(
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False) == True:  # return an async generator
+        if kwargs.get("stream", False) is True:  # return an async generator
             return TextCompletionStreamWrapper(
                 completion_stream=_async_streaming(
                     response=response,
@@ -3745,6 +3750,7 @@ async def atext_completion(
                     args=args,
                 ),
                 model=model,
+                custom_llm_provider=custom_llm_provider,
             )
         else:
             transformed_logprobs = None
@@ -4018,11 +4024,14 @@ def text_completion(
         **kwargs,
         **optional_params,
     )
-    if kwargs.get("acompletion", False) == True:
+    if kwargs.get("acompletion", False) is True:
         return response
-    if stream == True or kwargs.get("stream", False) == True:
+    if stream is True or kwargs.get("stream", False) is True:
         response = TextCompletionStreamWrapper(
-            completion_stream=response, model=model, stream_options=stream_options
+            completion_stream=response,
+            model=model,
+            stream_options=stream_options,
+            custom_llm_provider=custom_llm_provider,
         )
         return response
     transformed_logprobs = None
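
With _async_streaming and text_completion now routing errors through exception_type and passing custom_llm_provider into TextCompletionStreamWrapper, streaming failures surface as litellm exception types instead of raw SDK errors or a swallowed print. A rough usage sketch, assuming an OpenAI key is configured; the model, prompt, and handling below are illustrative only.

import litellm

# Arbitrary model/prompt; requires OPENAI_API_KEY in the environment.
try:
    resp = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say hello",
        stream=True,
    )
    for chunk in resp:
        print(chunk)
except litellm.RateLimitError as e:
    # Streaming errors are mapped to litellm exception types, so callers can
    # read fields like litellm_response_headers (asserted in the test below).
    print("rate limited:", getattr(e, "litellm_response_headers", None))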

View file

@@ -869,6 +869,15 @@ def _pre_call_utils(
             original_function = litellm.completion
         else:
             original_function = litellm.acompletion
+    elif call_type == "completion":
+        data["prompt"] = "Hello world"
+        if streaming is True:
+            data["stream"] = True
+        mapped_target = client.completions.with_raw_response
+        if sync_mode:
+            original_function = litellm.text_completion
+        else:
+            original_function = litellm.atext_completion

     return data, original_function, mapped_target
@@ -883,6 +892,7 @@ def _pre_call_utils(
         ("text-embedding-ada-002", "embedding", None),
         ("gpt-3.5-turbo", "chat_completion", False),
         ("gpt-3.5-turbo", "chat_completion", True),
+        ("gpt-3.5-turbo-instruct", "completion", True),
     ],
 )
 @pytest.mark.asyncio
@@ -933,27 +943,25 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
         new_retry_after_mock_client
     )

+    exception_raised = False
     try:
         if sync_mode:
-            resp = original_function(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            resp = original_function(**data, client=openai_client)

             if streaming:
                 for chunk in resp:
                     continue
         else:
-            resp = await original_function(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            resp = await original_function(**data, client=openai_client)

             if streaming:
                 async for chunk in resp:
                     continue
     except litellm.RateLimitError as e:
+        exception_raised = True
         assert e.litellm_response_headers is not None
         assert e.litellm_response_headers["retry-after"] == cooldown_time
+
+    if exception_raised is False:
+        print(resp)
+    assert exception_raised
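
The test now builds its request from the parametrized data dict and tracks whether the expected error actually fired. A tiny standalone illustration of that guard, with hypothetical names: without the flag, a test that expects an exception would silently pass whenever none is raised.

def flaky_call():
    # Hypothetical stand-in for the mocked client call in the test above.
    raise RuntimeError("boom")


def test_error_is_actually_raised():
    exception_raised = False
    try:
        flaky_call()
    except RuntimeError as e:
        exception_raised = True
        assert "boom" in str(e)
    # Fails loudly if the expected exception never fired.
    assert exception_raised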

View file

@@ -6833,7 +6833,7 @@ def exception_type(
                     message=f"{exception_provider} - {message}",
                     model=model,
                     llm_provider=custom_llm_provider,
-                    response=original_exception.response,
+                    response=getattr(original_exception, "response", None),
                     litellm_debug_info=extra_information,
                 )
             elif original_exception.status_code == 429:
@@ -6842,7 +6842,7 @@ def exception_type(
                     message=f"RateLimitError: {exception_provider} - {message}",
                     model=model,
                     llm_provider=custom_llm_provider,
-                    response=original_exception.response,
+                    response=getattr(original_exception, "response", None),
                     litellm_debug_info=extra_information,
                 )
             elif original_exception.status_code == 503:
@@ -6851,7 +6851,7 @@ def exception_type(
                     message=f"ServiceUnavailableError: {exception_provider} - {message}",
                     model=model,
                     llm_provider=custom_llm_provider,
-                    response=original_exception.response,
+                    response=getattr(original_exception, "response", None),
                     litellm_debug_info=extra_information,
                 )
             elif original_exception.status_code == 504:  # gateway timeout error
@@ -6869,7 +6869,7 @@ def exception_type(
                     message=f"APIError: {exception_provider} - {message}",
                     llm_provider=custom_llm_provider,
                     model=model,
-                    request=original_exception.request,
+                    request=getattr(original_exception, "request", None),
                     litellm_debug_info=extra_information,
                 )
             else:
@@ -10882,10 +10882,17 @@ class CustomStreamWrapper:
 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
+    def __init__(
+        self,
+        completion_stream,
+        model,
+        stream_options: Optional[dict] = None,
+        custom_llm_provider: Optional[str] = None,
+    ):
         self.completion_stream = completion_stream
         self.model = model
         self.stream_options = stream_options
+        self.custom_llm_provider = custom_llm_provider

     def __iter__(self):
         return self
@@ -10936,7 +10943,13 @@ class TextCompletionStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(f"got exception {e}")  # noqa
+            raise exception_type(
+                model=self.model,
+                custom_llm_provider=self.custom_llm_provider or "",
+                original_exception=e,
+                completion_kwargs={},
+                extra_kwargs={},
+            )

     async def __anext__(self):
         try:
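
The exception_type changes above swap direct attribute access for getattr with a None default, since not every exception that reaches the mapper carries .response or .request. A small sketch of that defensive pattern, using a hypothetical error type and helper rather than litellm's own classes:

class MappedProviderError(Exception):
    # Hypothetical error type for illustration, not litellm's own class.
    def __init__(self, message: str, response=None, request=None):
        self.response = response
        self.request = request
        super().__init__(message)


def map_provider_exception(original_exception: Exception) -> MappedProviderError:
    return MappedProviderError(
        message=str(original_exception),
        # httpx-based SDK errors expose .response / .request, but other
        # exceptions reaching the mapper may not, so default to None.
        response=getattr(original_exception, "response", None),
        request=getattr(original_exception, "request", None),
    )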