Repository: https://github.com/BerriAI/litellm.git (mirror, synced 2025-04-26 19:24:27 +00:00)

fix(main.py): cover openai /v1/completions endpoint

Commit 79bfdb83cc (parent a6c38e8bff), 4 changed files with 67 additions and 26 deletions
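At a high level, the change set extends LiteLLM's typed-exception handling (including surfaced response headers) to the OpenAI text-completions path, i.e. calls shaped like /v1/completions. A minimal, hedged sketch of the caller-side behaviour this is meant to cover; the model name and prompt are placeholders, and the header lookup only yields a value when the provider actually returns one:

import litellm

try:
    # legacy /v1/completions-style call, handled by OpenAITextCompletion
    resp = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",  # placeholder text-completion model
        prompt="Hello world",
    )
    print(resp.choices[0].text)
except litellm.RateLimitError as e:
    # after this change, rate-limit errors on this path should carry the provider's headers
    print(e.litellm_response_headers.get("retry-after"))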
@@ -1268,6 +1268,8 @@ class OpenAIChatCompletion(BaseLLM):
         except (
             Exception
         ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+            if isinstance(e, OpenAIError):
+                raise e
             if response is not None and hasattr(response, "text"):
                 raise OpenAIError(
                     status_code=500,
@@ -1975,7 +1977,7 @@ class OpenAITextCompletion(BaseLLM):
                     "complete_input_dict": data,
                 },
             )
-            if acompletion == True:
+            if acompletion is True:
                 if optional_params.get("stream", False):
                     return self.async_streaming(
                         logging_obj=logging_obj,
@@ -2019,7 +2021,7 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 openai_client = client

-            response = openai_client.completions.create(**data)  # type: ignore
+            response = openai_client.completions.with_raw_response.create(**data)  # type: ignore

             response_json = response.model_dump()

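The create -> with_raw_response.create switch follows the openai-python raw-response API, which keeps the HTTP headers available alongside the parsed completion. A minimal sketch of that SDK pattern in isolation (API key, model, and prompt are placeholders, not taken from this diff); note the unchanged context above still calls response.model_dump(), so how the raw wrapper is unpacked is not visible in this hunk:

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key

raw = client.completions.with_raw_response.create(
    model="gpt-3.5-turbo-instruct",  # placeholder model
    prompt="Say hi",
    max_tokens=5,
)
print(raw.headers.get("x-request-id"))  # HTTP headers live on the raw wrapper
completion = raw.parse()                # parse() returns the usual Completion object
print(completion.choices[0].text)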
@@ -2067,7 +2069,7 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 openai_aclient = client

-            response = await openai_aclient.completions.create(**data)
+            response = await openai_aclient.completions.with_raw_response.create(**data)
             response_json = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -2100,6 +2102,7 @@ class OpenAITextCompletion(BaseLLM):
         client=None,
         organization=None,
     ):
+
         if client is None:
             openai_client = OpenAI(
                 api_key=api_key,
@@ -2111,7 +2114,15 @@ class OpenAITextCompletion(BaseLLM):
             )
         else:
             openai_client = client
-        response = openai_client.completions.create(**data)
+
+        try:
+            response = openai_client.completions.with_raw_response.create(**data)
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
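The getattr calls keep the handler working whether or not the caught exception exposes status_code / headers, falling back to 500 and None. A small generic illustration of that defensive pattern (the error class here is a stand-in, not part of the diff):

class FakeProviderError(Exception):
    # stand-in error that happens to expose the attributes the handler looks for
    status_code = 429
    headers = {"retry-after": "7"}

def extract(e: Exception):
    # same pattern as the diff: default to 500 / None when the attribute is absent
    return getattr(e, "status_code", 500), getattr(e, "headers", None)

print(extract(FakeProviderError()))  # (429, {'retry-after': '7'})
print(extract(ValueError("boom")))   # (500, None)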
@@ -2149,7 +2160,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client

-        response = await openai_client.completions.create(**data)
+        response = await openai_client.completions.with_raw_response.create(**data)

         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -445,7 +445,12 @@ async def _async_streaming(response, model, custom_llm_provider, args):
             print_verbose(f"line in async streaming: {line}")
             yield line
     except Exception as e:
-        raise e
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+        )


 def mock_completion(
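Re-raising through exception_type means consumers of the async stream see litellm's typed exceptions rather than a bare provider error. A hedged caller-side sketch (model and prompt are placeholders):

import asyncio
import litellm

async def main():
    try:
        resp = await litellm.atext_completion(
            model="gpt-3.5-turbo-instruct",  # placeholder
            prompt="Hello world",
            stream=True,
        )
        async for chunk in resp:
            print(chunk)
    except litellm.RateLimitError as e:
        # the mapped exception keeps provider context (headers, debug info) attached
        print(type(e).__name__, getattr(e, "litellm_response_headers", None))

asyncio.run(main())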
@@ -3736,7 +3741,7 @@ async def atext_completion(
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False) == True:  # return an async generator
+        if kwargs.get("stream", False) is True:  # return an async generator
             return TextCompletionStreamWrapper(
                 completion_stream=_async_streaming(
                     response=response,
@@ -3745,6 +3750,7 @@ async def atext_completion(
                     args=args,
                 ),
                 model=model,
+                custom_llm_provider=custom_llm_provider,
             )
         else:
             transformed_logprobs = None
@@ -4018,11 +4024,14 @@ def text_completion(
         **kwargs,
         **optional_params,
     )
-    if kwargs.get("acompletion", False) == True:
+    if kwargs.get("acompletion", False) is True:
         return response
-    if stream == True or kwargs.get("stream", False) == True:
+    if stream is True or kwargs.get("stream", False) is True:
         response = TextCompletionStreamWrapper(
-            completion_stream=response, model=model, stream_options=stream_options
+            completion_stream=response,
+            model=model,
+            stream_options=stream_options,
+            custom_llm_provider=custom_llm_provider,
         )
         return response
     transformed_logprobs = None
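With custom_llm_provider now threaded into TextCompletionStreamWrapper, errors raised while iterating the sync stream can also be mapped per provider (see the wrapper changes further down). A short usage sketch of this sync streaming path; the chunk shape is assumed to follow the text-completion response format:

import litellm

stream = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",  # placeholder
    prompt="Hello world",
    stream=True,
)
for chunk in stream:  # iterates the TextCompletionStreamWrapper
    print(chunk.choices[0].text, end="")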
@@ -869,6 +869,15 @@ def _pre_call_utils(
             original_function = litellm.completion
         else:
             original_function = litellm.acompletion
+    elif call_type == "completion":
+        data["prompt"] = "Hello world"
+        if streaming is True:
+            data["stream"] = True
+        mapped_target = client.completions.with_raw_response
+        if sync_mode:
+            original_function = litellm.text_completion
+        else:
+            original_function = litellm.atext_completion

     return data, original_function, mapped_target

@@ -883,6 +892,7 @@ def _pre_call_utils(
         ("text-embedding-ada-002", "embedding", None),
         ("gpt-3.5-turbo", "chat_completion", False),
         ("gpt-3.5-turbo", "chat_completion", True),
+        ("gpt-3.5-turbo-instruct", "completion", True),
     ],
 )
 @pytest.mark.asyncio
@@ -933,27 +943,25 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
         new_retry_after_mock_client
     )

+    exception_raised = False
     try:
         if sync_mode:
-            resp = original_function(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            resp = original_function(**data, client=openai_client)
             if streaming:
                 for chunk in resp:
                     continue
         else:
-            resp = await original_function(
-                model="text-embedding-ada-002",
-                input="Hello world!",
-                client=openai_client,
-            )
+            resp = await original_function(**data, client=openai_client)

             if streaming:
                 async for chunk in resp:
                     continue

     except litellm.RateLimitError as e:
+        exception_raised = True
         assert e.litellm_response_headers is not None
         assert e.litellm_response_headers["retry-after"] == cooldown_time
+
+    if exception_raised is False:
+        print(resp)
+    assert exception_raised
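The tightened test fails loudly when no exception is raised at all (the exception_raised flag) and checks that the surfaced retry-after header matches the mocked cooldown. A small sketch of how a caller might act on that header; the single naive retry and the helper name are illustrative only:

import time
import litellm

def call_with_backoff(**data):
    try:
        return litellm.text_completion(**data)
    except litellm.RateLimitError as e:
        headers = e.litellm_response_headers or {}
        delay = float(headers.get("retry-after", 1))
        time.sleep(delay)  # naive single retry, purely for illustration
        return litellm.text_completion(**data)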
@@ -6833,7 +6833,7 @@ def exception_type(
                         message=f"{exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 429:
@@ -6842,7 +6842,7 @@ def exception_type(
                         message=f"RateLimitError: {exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 503:
@@ -6851,7 +6851,7 @@ def exception_type(
                         message=f"ServiceUnavailableError: {exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 504:  # gateway timeout error
@@ -6869,7 +6869,7 @@ def exception_type(
                         message=f"APIError: {exception_provider} - {message}",
                         llm_provider=custom_llm_provider,
                         model=model,
-                        request=original_exception.request,
+                        request=getattr(original_exception, "request", None),
                         litellm_debug_info=extra_information,
                     )
                 else:
@@ -10882,10 +10882,17 @@ class CustomStreamWrapper:


 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
+    def __init__(
+        self,
+        completion_stream,
+        model,
+        stream_options: Optional[dict] = None,
+        custom_llm_provider: Optional[str] = None,
+    ):
         self.completion_stream = completion_stream
         self.model = model
         self.stream_options = stream_options
+        self.custom_llm_provider = custom_llm_provider

     def __iter__(self):
         return self
@@ -10936,7 +10943,13 @@ class TextCompletionStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(f"got exception {e}")  # noqa
+            raise exception_type(
+                model=self.model,
+                custom_llm_provider=self.custom_llm_provider or "",
+                original_exception=e,
+                completion_kwargs={},
+                extra_kwargs={},
+            )

     async def __anext__(self):
         try:
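Taken together with the constructor change above, the __next__ error path now raises a provider-aware litellm exception instead of just printing the error. A minimal hedged sketch of exercising the wrapper directly; in normal use text_completion builds it for you, and the import path and resulting exception class are assumptions:

from litellm.utils import TextCompletionStreamWrapper  # assumed import location

def broken_stream():
    raise RuntimeError("provider blew up mid-stream")
    yield  # makes this a generator

wrapper = TextCompletionStreamWrapper(
    completion_stream=broken_stream(),
    model="gpt-3.5-turbo-instruct",  # placeholder
    custom_llm_provider="openai",
)
try:
    next(iter(wrapper))
except Exception as e:
    # expected to be a litellm-typed exception (e.g. APIConnectionError) after this change
    print(type(e).__name__)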