fix(openai.py): switch to using openai sdk for text completion calls

Krrish Dholakia 2024-04-02 15:08:12 -07:00
parent b07788d2a5
commit 919ec86b2b
2 changed files with 116 additions and 110 deletions
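The diff below swaps hand-rolled httpx requests for the OpenAI Python SDK (v1-style client) in litellm's text-completion path. A minimal sketch of the adopted pattern, assuming openai>=1.0 and placeholder credentials:

from openai import OpenAI

# Build a v1-style client; timeout and max_retries are client settings now,
# not fields smuggled into the request payload.
client = OpenAI(
    api_key="sk-...",                      # placeholder
    base_url="https://api.openai.com/v1",  # API root; no manual /completions suffix
    timeout=600.0,
    max_retries=2,
)
response = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # example model, not from the diff
    prompt="Say this is a test",
)
response_json = response.model_dump()  # plain dict, like response.json() before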


@@ -1014,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
         headers: Optional[dict] = None,
     ):
         super().completion()
@@ -1024,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")
 
-            api_base = f"{api_base}/completions"
-
             if (
                 len(messages) > 0
                 and "content" in messages[0]
@@ -1036,9 +1036,9 @@ class OpenAITextCompletion(BaseLLM):
             prompt = " ".join([message["content"] for message in messages])  # type: ignore
 
             # don't send max retries to the api, if set
             optional_params.pop("max_retries", None)
-
             data = {"model": model, "prompt": prompt, **optional_params}
+            max_retries = data.pop("max_retries", 2)
             ## LOGGING
             logging_obj.pre_call(
                 input=messages,
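The added line moves max_retries out of the request payload, since the /completions endpoint would reject it as an unknown parameter; it becomes a client-side retry setting instead. A small sketch with illustrative values:

# max_retries rides along in optional_params, so it lands in data first.
data = {"model": "gpt-3.5-turbo-instruct", "prompt": "hi", "max_retries": 5}
max_retries = data.pop("max_retries", 2)  # 2 mirrors the diff's default
# data now carries only API parameters; max_retries configures OpenAI(...).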
@@ -1054,40 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
                     return self.async_streaming(
                         logging_obj=logging_obj,
                         api_base=api_base,
+                        api_key=api_key,
                         data=data,
                         headers=headers,
                         model_response=model_response,
                         model=model,
                         timeout=timeout,
+                        max_retries=max_retries,
+                        client=client,
+                        organization=organization,
                     )
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
             elif optional_params.get("stream", False):
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    api_key=api_key,
                     data=data,
                     headers=headers,
                     model_response=model_response,
                     model=model,
                     timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    client=client,
+                    organization=organization,
                 )
             else:
-                response = httpx.post(
-                    url=f"{api_base}", json=data, headers=headers, timeout=timeout
-                )
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
+                if client is None:
+                    openai_client = OpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        http_client=litellm.client_session,
+                        timeout=timeout,
+                        max_retries=max_retries,  # type: ignore
+                        organization=organization,
                     )
+                else:
+                    openai_client = client
 
-                response_json = response.json()
+                response = openai_client.completions.create(**data)  # type: ignore
+                response_json = response.model_dump()
 
                 ## LOGGING
                 logging_obj.post_call(
                     input=prompt,
                     api_key=api_key,
-                    original_response=response,
+                    original_response=response_json,
                     additional_args={
                         "headers": headers,
                         "api_base": api_base,
@@ -1110,100 +1123,110 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         timeout: float,
+        max_retries=None,
+        organization: Optional[str] = None,
+        client=None,
     ):
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            try:
-                response = await client.post(
-                    api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout,
-                )
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-                ## LOGGING
-                logging_obj.post_call(
-                    input=prompt,
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
                     api_key=api_key,
-                    original_response=response,
-                    additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
-                    },
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
                 )
+            else:
+                openai_aclient = client
 
-                ## RESPONSE OBJECT
-                return TextCompletionResponse(**response_json)
-            except Exception as e:
-                raise e
+            response = await openai_aclient.completions.create(**data)
+            response_json = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+            ## RESPONSE OBJECT
+            return TextCompletionResponse(**response_json)
+        except Exception as e:
+            raise e
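acompletion mirrors the sync path with AsyncOpenAI. A hedged, runnable sketch of the same flow (key, base URL, and model are placeholders):

import asyncio
from openai import AsyncOpenAI

async def text_complete(api_key: str, api_base: str) -> dict:
    aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
    # Awaiting create() returns a typed Completion; model_dump() produces the
    # plain dict that TextCompletionResponse(**response_json) expects.
    response = await aclient.completions.create(
        model="gpt-3.5-turbo-instruct", prompt="Say this is a test"
    )
    return response.model_dump()

# asyncio.run(text_complete("sk-...", "https://api.openai.com/v1"))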
     def streaming(
         self,
         logging_obj,
-        api_base: str,
         api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
     ):
-        with httpx.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response.iter_lines(),
-                model=model,
-                custom_llm_provider="text-completion-openai",
-                logging_obj=logging_obj,
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
             )
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        else:
+            openai_client = client
+
+        response = openai_client.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+        for chunk in streamwrapper:
+            yield chunk
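With stream=True the SDK returns an iterator of typed completion chunks, which is what CustomStreamWrapper now consumes instead of raw SSE lines from httpx. A minimal sketch with placeholder values:

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key
stream = client.completions.create(
    model="gpt-3.5-turbo-instruct", prompt="Count to 3", stream=True
)
for chunk in stream:
    # Each chunk is a typed Completion object, not a bytes line to parse.
    print(chunk.choices[0].text, end="")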
     async def async_streaming(
         self,
         logging_obj,
-        api_base: str,
         api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        organization=None,
     ):
-        client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            try:
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
 
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=response.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="text-completion-openai",
-                    logging_obj=logging_obj,
-                )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
-            except Exception as e:
-                raise e
+        response = await openai_client.completions.create(**data)
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
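The async variant has the same shape: awaiting create(..., stream=True) yields an async iterator that CustomStreamWrapper can consume directly. Hedged sketch with placeholder values:

import asyncio
from openai import AsyncOpenAI

async def stream_demo(api_key: str) -> None:
    aclient = AsyncOpenAI(api_key=api_key)
    stream = await aclient.completions.create(
        model="gpt-3.5-turbo-instruct", prompt="Count to 3", stream=True
    )
    async for chunk in stream:
        print(chunk.choices[0].text, end="")

# asyncio.run(stream_demo("sk-..."))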