fix(openai.py): switch to using openai sdk for text completion calls

Krrish Dholakia 2024-04-02 15:08:12 -07:00
parent b07788d2a5
commit 919ec86b2b
2 changed files with 116 additions and 110 deletions
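For orientation before the diff: the change swaps litellm's hand-rolled httpx calls against the /completions endpoint for the official openai SDK client, which owns URL construction, auth headers, retries, and stream parsing. A minimal sketch of the before/after call pattern, assuming a placeholder key and the standard endpoint (not litellm's actual wiring):

import httpx
from openai import OpenAI

api_key = "sk-..."  # hypothetical placeholder
api_base = "https://api.openai.com/v1"
data = {"model": "gpt-3.5-turbo-instruct", "prompt": "Say hi", "max_tokens": 8}

# Before: manual HTTP call. The caller builds the /completions URL, sets
# auth headers, and checks status codes by hand.
resp = httpx.post(
    url=f"{api_base}/completions",
    json=data,
    headers={"Authorization": f"Bearer {api_key}"},
    timeout=600.0,
)
resp.raise_for_status()
response_json = resp.json()

# After: the SDK owns transport concerns. Retries are configured once on
# the client, so transport keys like max_retries must stay out of `data`.
client = OpenAI(api_key=api_key, base_url=api_base, timeout=600.0, max_retries=2)
response_json = client.completions.create(**data).model_dump()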

openai.py

@@ -1014,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
         headers: Optional[dict] = None,
     ):
         super().completion()
@@ -1024,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")
 
-        api_base = f"{api_base}/completions"
-
         if (
             len(messages) > 0
             and "content" in messages[0]
@@ -1036,9 +1036,9 @@ class OpenAITextCompletion(BaseLLM):
             prompt = " ".join([message["content"] for message in messages])  # type: ignore
 
         # don't send max retries to the api, if set
-        optional_params.pop("max_retries", None)
         data = {"model": model, "prompt": prompt, **optional_params}
+        max_retries = data.pop("max_retries", 2)
 
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
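Note the inversion on max_retries in the hunk above: the old code stripped it so it never reached the API, while the new code pops it out of the request body and hands it to the SDK client, which performs the retries itself. A small sketch of the new split, assuming OPENAI_API_KEY is set in the environment:

from openai import OpenAI

optional_params = {"max_tokens": 8, "max_retries": 5}
data = {"model": "gpt-3.5-turbo-instruct", "prompt": "Hello", **optional_params}

# max_retries is a transport setting, not an API field: pop it out of the
# request body and configure it on the client instead (the default of 2
# mirrors the diff).
max_retries = data.pop("max_retries", 2)
client = OpenAI(max_retries=max_retries)  # reads OPENAI_API_KEY from the env
print(client.completions.create(**data).choices[0].text)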
@@ -1054,40 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
                 return self.async_streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    api_key=api_key,
                     data=data,
                     headers=headers,
                     model_response=model_response,
                     model=model,
                     timeout=timeout,
+                    max_retries=max_retries,
+                    client=client,
+                    organization=organization,
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
                 api_base=api_base,
+                api_key=api_key,
                 data=data,
                 headers=headers,
                 model_response=model_response,
                 model=model,
                 timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                client=client,
+                organization=organization,
             )
         else:
-            response = httpx.post(
-                url=f"{api_base}", json=data, headers=headers, timeout=timeout
-            )
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
+            if client is None:
+                openai_client = OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    organization=organization,
+                )
+            else:
+                openai_client = client
 
-            response_json = response.json()
+            response = openai_client.completions.create(**data)  # type: ignore
+            response_json = response.model_dump()
 
             ## LOGGING
             logging_obj.post_call(
                 input=prompt,
                 api_key=api_key,
-                original_response=response,
+                original_response=response_json,
                 additional_args={
                     "headers": headers,
                     "api_base": api_base,
@@ -1110,100 +1123,110 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         timeout: float,
+        max_retries=None,
+        organization: Optional[str] = None,
+        client=None,
     ):
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            try:
-                response = await client.post(
-                    api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout,
-                )
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                ## LOGGING
-                logging_obj.post_call(
-                    input=prompt,
-                    api_key=api_key,
-                    original_response=response,
-                    additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
-                    },
-                )
-
-                ## RESPONSE OBJECT
-                return TextCompletionResponse(**response_json)
-            except Exception as e:
-                raise e
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                openai_aclient = client
+
+            response = await openai_aclient.completions.create(**data)
+            response_json = response.model_dump()
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+
+            ## RESPONSE OBJECT
+            return TextCompletionResponse(**response_json)
+        except Exception as e:
+            raise e
 
     def streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
     ):
-        with httpx.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
-
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response.iter_lines(),
-                model=model,
-                custom_llm_provider="text-completion-openai",
-                logging_obj=logging_obj,
-            )
-
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
+        response = openai_client.completions.create(**data)
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+
+        for chunk in streamwrapper:
+            yield chunk
 
     async def async_streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        organization=None,
     ):
-        client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            try:
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=response.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="text-completion-openai",
-                    logging_obj=logging_obj,
-                )
-
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
-            except Exception as e:
-                raise e
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
+        response = await openai_client.completions.create(**data)
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
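With the SDK in the transport path, streaming no longer parses raw "data: ..." SSE lines by hand; completions.create(stream=True) yields typed chunk objects that CustomStreamWrapper can iterate directly. A minimal standalone sketch of the async path these methods wrap, assuming OPENAI_API_KEY is set:

import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    stream = await client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt="Count to three:",
        stream=True,
    )
    # Each chunk is a typed Completion object, not a raw SSE line.
    async for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].text, end="", flush=True)

asyncio.run(main())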

utils.py

@@ -9005,37 +9005,20 @@ class CustomStreamWrapper:
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
-            str_line = chunk
             text = ""
             is_finished = False
             finish_reason = None
-            if "data: [DONE]" in str_line or self.sent_last_chunk == True:
-                raise StopIteration
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                print_verbose(f"delta content: {data_json}")
-                text = data_json["choices"][0].get("text", "")
-                if data_json["choices"][0].get("finish_reason", None):
-                    is_finished = True
-                    finish_reason = data_json["choices"][0]["finish_reason"]
-                print_verbose(
-                    f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}"
-                )
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            elif "error" in str_line:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {str_line}"
-                )
-            else:
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
+            choices = getattr(chunk, "choices", [])
+            if len(choices) > 0:
+                text = choices[0].text
+                if choices[0].finish_reason is not None:
+                    is_finished = True
+                    finish_reason = choices[0].finish_reason
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
         except Exception as e:
             raise e
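Matching the transport change, handle_openai_text_completion_chunk now reads attributes off a typed chunk object instead of JSON-decoding SSE lines. A quick sketch of the logic it implements, using stand-in dataclasses rather than the SDK's real models:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeChoice:  # stand-in for the SDK's CompletionChoice
    text: str
    finish_reason: Optional[str]

@dataclass
class FakeChunk:  # stand-in for a streamed Completion chunk
    choices: List[FakeChoice]

def handle_chunk(chunk) -> dict:
    # Mirrors the new handler: attribute access, no json.loads / "data:" parsing.
    text, is_finished, finish_reason = "", False, None
    choices = getattr(chunk, "choices", [])
    if len(choices) > 0:
        text = choices[0].text
        if choices[0].finish_reason is not None:
            is_finished = True
            finish_reason = choices[0].finish_reason
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

print(handle_chunk(FakeChunk(choices=[FakeChoice(text="hi", finish_reason=None)])))
# -> {'text': 'hi', 'is_finished': False, 'finish_reason': None}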