fix(openai.py): switch to using openai sdk for text completion calls

Krrish Dholakia 2024-04-02 15:08:12 -07:00
parent b07788d2a5
commit 919ec86b2b
2 changed files with 116 additions and 110 deletions
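
For reference, the call shape this commit moves to looks like the following in isolation (a minimal sketch against the openai v1 Python SDK; the API key, base URL, and model name are placeholders, not values taken from this diff):

    from openai import OpenAI

    # A configured SDK client replaces raw httpx POSTs to f"{api_base}/completions";
    # the SDK builds the /completions URL, auth headers, and retry logic itself.
    client = OpenAI(
        api_key="sk-...",                      # placeholder
        base_url="https://api.openai.com/v1",  # placeholder; any OpenAI-compatible base
        timeout=600.0,
        max_retries=2,                         # mirrors data.pop("max_retries", 2) below
    )

    response = client.completions.create(
        model="gpt-3.5-turbo-instruct",        # placeholder text-completion model
        prompt="Say this is a test",
    )

    # The SDK returns a pydantic model; model_dump() recovers the plain dict
    # that response.json() used to provide in the httpx-based code.
    response_json = response.model_dump()
    print(response_json["choices"][0]["text"])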

litellm/llms/openai.py

@@ -1014,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
         headers: Optional[dict] = None,
     ):
         super().completion()
@@ -1024,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")

-        api_base = f"{api_base}/completions"
-
         if (
             len(messages) > 0
             and "content" in messages[0]
@@ -1036,9 +1036,9 @@ class OpenAITextCompletion(BaseLLM):
             prompt = " ".join([message["content"] for message in messages])  # type: ignore

        # don't send max retries to the api, if set
-        optional_params.pop("max_retries", None)
         data = {"model": model, "prompt": prompt, **optional_params}
+        max_retries = data.pop("max_retries", 2)

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
@@ -1054,40 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
                 return self.async_streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    api_key=api_key,
                     data=data,
                     headers=headers,
                     model_response=model_response,
                     model=model,
                     timeout=timeout,
+                    max_retries=max_retries,
+                    client=client,
+                    organization=organization,
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
                 api_base=api_base,
+                api_key=api_key,
                 data=data,
                 headers=headers,
                 model_response=model_response,
                 model=model,
                 timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                client=client,
+                organization=organization,
             )
         else:
-            response = httpx.post(
-                url=f"{api_base}", json=data, headers=headers, timeout=timeout
-            )
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
-            response_json = response.json()
+            if client is None:
+                openai_client = OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    organization=organization,
+                )
+            else:
+                openai_client = client
+            response = openai_client.completions.create(**data)  # type: ignore
+            response_json = response.model_dump()

             ## LOGGING
             logging_obj.post_call(
                 input=prompt,
                 api_key=api_key,
-                original_response=response,
+                original_response=response_json,
                 additional_args={
                     "headers": headers,
                     "api_base": api_base,
@ -1110,22 +1123,25 @@ class OpenAITextCompletion(BaseLLM):
api_key: str, api_key: str,
model: str, model: str,
timeout: float, timeout: float,
max_retries=None,
organization: Optional[str] = None,
client=None,
): ):
async with httpx.AsyncClient(timeout=timeout) as client:
try: try:
response = await client.post( if client is None:
api_base, openai_aclient = AsyncOpenAI(
json=data, api_key=api_key,
headers=headers, base_url=api_base,
timeout=litellm.request_timeout, http_client=litellm.aclient_session,
) timeout=timeout,
response_json = response.json() max_retries=max_retries,
if response.status_code != 200: organization=organization,
raise OpenAIError(
status_code=response.status_code, message=response.text
) )
else:
openai_aclient = client
response = await openai_aclient.completions.create(**data)
response_json = response.model_dump()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@@ -1136,7 +1152,6 @@ class OpenAITextCompletion(BaseLLM):
                     "api_base": api_base,
                 },
             )
-
             ## RESPONSE OBJECT
             return TextCompletionResponse(**response_json)
         except Exception as e:
@@ -1145,65 +1160,73 @@ class OpenAITextCompletion(BaseLLM):
     def streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
     ):
-        with httpx.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
+            )
+        else:
+            openai_client = client
+        response = openai_client.completions.create(**data)
         streamwrapper = CustomStreamWrapper(
-            completion_stream=response.iter_lines(),
+            completion_stream=response,
             model=model,
             custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
         )
-        for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+
+        for chunk in streamwrapper:
+            yield chunk

     async def async_streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        organization=None,
     ):
-        client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            try:
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
+        response = await openai_client.completions.create(**data)
         streamwrapper = CustomStreamWrapper(
-            completion_stream=response.aiter_lines(),
+            completion_stream=response,
             model=model,
             custom_llm_provider="text-completion-openai",
             logging_obj=logging_obj,
         )
         async for transformed_chunk in streamwrapper:
             yield transformed_chunk
-        except Exception as e:
-            raise e
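
The streaming paths above now hand the SDK's stream object to CustomStreamWrapper instead of response.iter_lines(). In isolation, the sync flow is roughly this (a sketch with a placeholder key and model, independent of litellm's wrapper):

    from openai import OpenAI

    client = OpenAI(api_key="sk-...")  # placeholder key

    # stream=True makes completions.create() return an iterator of parsed
    # Completion chunks, so there are no raw "data: ..." SSE lines to split.
    stream = client.completions.create(
        model="gpt-3.5-turbo-instruct",  # placeholder model
        prompt="Count to three",
        stream=True,
    )

    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].text or "", end="")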

litellm/utils.py

@@ -9005,32 +9005,15 @@ class CustomStreamWrapper:
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
-            str_line = chunk
             text = ""
             is_finished = False
             finish_reason = None
-            if "data: [DONE]" in str_line or self.sent_last_chunk == True:
-                raise StopIteration
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                print_verbose(f"delta content: {data_json}")
-                text = data_json["choices"][0].get("text", "")
-                if data_json["choices"][0].get("finish_reason", None):
-                    is_finished = True
-                    finish_reason = data_json["choices"][0]["finish_reason"]
-                print_verbose(
-                    f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}"
-                )
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            elif "error" in str_line:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {str_line}"
-                )
-            else:
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
+            choices = getattr(chunk, "choices", [])
+            if len(choices) > 0:
+                text = choices[0].text
+                if choices[0].finish_reason is not None:
+                    is_finished = True
+                    finish_reason = choices[0].finish_reason
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }