forked from phoenix/litellm-mirror
fix(openai.py): switch to using openai sdk for text completion calls
commit 919ec86b2b
parent b07788d2a5
2 changed files with 116 additions and 110 deletions
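
In short, the text-completion path stops POSTing to `{api_base}/completions` with raw httpx and instead builds (or reuses) an OpenAI SDK client and calls `completions.create`. A condensed sketch of the new synchronous flow, distilled from the hunks below — the helper name `_text_completion_via_sdk` is made up for illustration, and the surrounding litellm plumbing (logging, response conversion, error wrapping) is omitted:

import litellm
from openai import OpenAI


def _text_completion_via_sdk(api_key, api_base, organization, timeout, data, client=None):
    # max_retries now rides on the SDK client, not in the request payload,
    # so it is popped out of `data` before the call (see the hunk at -1036).
    max_retries = data.pop("max_retries", 2)
    # Reuse a caller-supplied client when given; otherwise build one that
    # shares litellm's configured httpx session.
    openai_client = client or OpenAI(
        api_key=api_key,
        base_url=api_base,
        http_client=litellm.client_session,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    response = openai_client.completions.create(**data)  # data = {"model", "prompt", ...}
    return response.model_dump()
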
@@ -1014,6 +1014,8 @@ class OpenAITextCompletion(BaseLLM):
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
         headers: Optional[dict] = None,
     ):
         super().completion()
@@ -1024,8 +1026,6 @@ class OpenAITextCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")
 
-        api_base = f"{api_base}/completions"
-
         if (
             len(messages) > 0
             and "content" in messages[0]
@@ -1036,9 +1036,9 @@ class OpenAITextCompletion(BaseLLM):
             prompt = " ".join([message["content"] for message in messages])  # type: ignore
 
         # don't send max retries to the api, if set
-        optional_params.pop("max_retries", None)
 
         data = {"model": model, "prompt": prompt, **optional_params}
+        max_retries = data.pop("max_retries", 2)
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
@@ -1054,40 +1054,53 @@ class OpenAITextCompletion(BaseLLM):
                 return self.async_streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    api_key=api_key,
                     data=data,
                     headers=headers,
                     model_response=model_response,
                     model=model,
                     timeout=timeout,
+                    max_retries=max_retries,
+                    client=client,
+                    organization=organization,
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
                 api_base=api_base,
+                api_key=api_key,
                 data=data,
                 headers=headers,
                 model_response=model_response,
                 model=model,
                 timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                client=client,
+                organization=organization,
             )
         else:
-            response = httpx.post(
-                url=f"{api_base}", json=data, headers=headers, timeout=timeout
-            )
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
+            if client is None:
+                openai_client = OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.client_session,
+                    timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    organization=organization,
+                )
+            else:
+                openai_client = client
 
-            response_json = response.json()
+            response = openai_client.completions.create(**data)  # type: ignore
 
+            response_json = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
                 input=prompt,
                 api_key=api_key,
-                original_response=response,
+                original_response=response_json,
                 additional_args={
                     "headers": headers,
                     "api_base": api_base,
@@ -1110,100 +1123,110 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         timeout: float,
+        max_retries=None,
+        organization: Optional[str] = None,
+        client=None,
     ):
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            try:
-                response = await client.post(
-                    api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout,
-                )
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                ## LOGGING
-                logging_obj.post_call(
-                    input=prompt,
-                    api_key=api_key,
-                    original_response=response,
-                    additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
-                    },
-                )
-
-                ## RESPONSE OBJECT
-                return TextCompletionResponse(**response_json)
-            except Exception as e:
-                raise e
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                openai_aclient = client
+
+            response = await openai_aclient.completions.create(**data)
+            response_json = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+            ## RESPONSE OBJECT
+            return TextCompletionResponse(**response_json)
+        except Exception as e:
+            raise e
 
     def streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
         data: dict,
         headers: dict,
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
     ):
-        with httpx.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(
-                    status_code=response.status_code, message=response.text
-                )
-
-            streamwrapper = CustomStreamWrapper(
-                completion_stream=response.iter_lines(),
-                model=model,
-                custom_llm_provider="text-completion-openai",
-                logging_obj=logging_obj,
-            )
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
+            )
+        else:
+            openai_client = client
+        response = openai_client.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+
+        for chunk in streamwrapper:
+            yield chunk
 
     async def async_streaming(
         self,
         logging_obj,
-        api_base: str,
+        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
+        api_base: Optional[str] = None,
+        client=None,
+        max_retries=None,
+        organization=None,
     ):
-        client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=timeout,
-        ) as response:
-            try:
-                if response.status_code != 200:
-                    raise OpenAIError(
-                        status_code=response.status_code, message=response.text
-                    )
-
-                streamwrapper = CustomStreamWrapper(
-                    completion_stream=response.aiter_lines(),
-                    model=model,
-                    custom_llm_provider="text-completion-openai",
-                    logging_obj=logging_obj,
-                )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
-            except Exception as e:
-                raise e
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
+        response = await openai_client.completions.create(**data)
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+        )
+
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
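
Because acompletion, streaming, and async_streaming now accept an optional `client`, a caller can construct one SDK client and reuse it across requests instead of paying connection setup per call. A minimal, illustrative sketch of the async reuse pattern these signatures enable (the model name, key, and prompt are placeholders; error handling is omitted):

import asyncio

from openai import AsyncOpenAI


async def main():
    # One client, shared across calls: pooling, timeouts, and retries are
    # configured here once rather than per request.
    aclient = AsyncOpenAI(api_key="sk-...", timeout=600.0, max_retries=2)
    response = await aclient.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt="Say hello in one word.",
    )
    print(response.model_dump()["choices"][0]["text"])


asyncio.run(main())
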
@@ -9005,37 +9005,20 @@ class CustomStreamWrapper:
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
-            str_line = chunk
             text = ""
             is_finished = False
             finish_reason = None
-            if "data: [DONE]" in str_line or self.sent_last_chunk == True:
-                raise StopIteration
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                print_verbose(f"delta content: {data_json}")
-                text = data_json["choices"][0].get("text", "")
-                if data_json["choices"][0].get("finish_reason", None):
-                    is_finished = True
-                    finish_reason = data_json["choices"][0]["finish_reason"]
-                print_verbose(
-                    f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}"
-                )
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            elif "error" in str_line:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {str_line}"
-                )
-            else:
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
+            choices = getattr(chunk, "choices", [])
+            if len(choices) > 0:
+                text = choices[0].text
+                if choices[0].finish_reason is not None:
+                    is_finished = True
+                    finish_reason = choices[0].finish_reason
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
 
         except Exception as e:
             raise e
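
With the SDK in place, handle_openai_text_completion_chunk receives typed chunk objects from the streamed completions.create call rather than raw `data: ...` SSE lines, so the json.loads / "[DONE]" parsing drops out. A small, illustrative sketch of the objects the new branch expects (assumes OPENAI_API_KEY is set in the environment; the model name is a placeholder):

from openai import OpenAI

client = OpenAI()
stream = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="Count to three.",
    stream=True,
)
for chunk in stream:
    # Mirrors the handler above: attribute access instead of string parsing.
    choices = getattr(chunk, "choices", [])
    if len(choices) > 0:
        print(choices[0].text, end="")
        if choices[0].finish_reason is not None:
            print(f"\n[finish_reason: {choices[0].finish_reason}]")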