refactor(openai.py): moving openai text completion calls to http

Krrish Dholakia 2023-11-08 18:39:56 -08:00
parent 901b0e690e
commit e66373bd47
6 changed files with 211 additions and 66 deletions

litellm/main.py

@@ -49,7 +49,7 @@ from .llms import (
     palm,
     vertex_ai,
     maritalk)
-from .llms.openai import OpenAIChatCompletion
+from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
@@ -73,6 +73,7 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
+openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
 
 ####### COMPLETION ENDPOINTS ################
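The OpenAITextCompletion handler imported and instantiated above is defined in llms/openai.py, which is not part of this excerpt. As a rough sketch of what "moving the call to http" means, such a handler boils down to POSTing directly to OpenAI's /v1/completions endpoint instead of going through the openai SDK; the class name below is deliberately different, and the method signature, httpx usage, and body are assumptions for illustration, not the actual implementation:

import httpx

class OpenAITextCompletionSketch:
    """Illustrative only -- the real class is OpenAITextCompletion in llms/openai.py."""

    def completion(self, model: str, prompt: str, api_key: str,
                   api_base: str = "https://api.openai.com/v1", **optional_params):
        # plain HTTP call to the text-completion endpoint, no openai SDK involved
        resp = httpx.post(
            f"{api_base}/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": model, "prompt": prompt, **optional_params},
            timeout=600,
        )
        resp.raise_for_status()
        # the real handler also fills in the ModelResponse passed from main.py (see the call site below)
        return resp.json()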
@@ -498,14 +499,8 @@ def completion(
             )
         elif (
             custom_llm_provider == "text-completion-openai"
-            or model in litellm.open_ai_text_completion_models
             or "ft:babbage-002" in model
             or "ft:davinci-002" in model # support for finetuned completion models
-            # NOTE: Do NOT add custom_llm_provider == "openai".
-            # this will break hosted vllm/proxy calls.
-            # see: https://docs.litellm.ai/docs/providers/vllm#calling-hosted-vllm-server.
-            # VLLM expects requests to call openai.ChatCompletion we need those requests to always
-            # call openai.ChatCompletion
         ):
             # print("calling custom openai provider")
             openai.api_type = "openai"
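For context, a caller lands in this branch when the provider resolves to text-completion-openai or the model is an OpenAI fine-tuned completion model. A hedged usage sketch -- the model ids and fine-tune suffix are placeholders, and it assumes OPENAI_API_KEY is set in the environment:

import litellm

# an explicit provider prefix resolves custom_llm_provider to "text-completion-openai"
response = litellm.completion(
    model="text-completion-openai/gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "Say this is a test"}],
)

# fine-tuned completion models match the "ft:babbage-002" / "ft:davinci-002" checks in the condition
response_ft = litellm.completion(
    model="ft:babbage-002:my-org::example123",  # placeholder fine-tune id
    messages=[{"role": "user", "content": "Say this is a test"}],
)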
@@ -558,43 +553,22 @@ def completion(
                 },
             )
             ## COMPLETION CALL
-            response = openai.Completion.create(
-                model=model,
-                prompt=prompt,
-                headers=headers,
-                api_key = api_key,
-                api_base=api_base,
-                **optional_params
-            )
-            if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
-                return response
-            ## LOGGING
-            logging.post_call(
-                input=prompt,
+            model_response = openai_text_completions.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
                 api_key=api_key,
-                original_response=response,
-                additional_args={
-                    "openai_organization": litellm.organization,
-                    "headers": headers,
-                    "api_base": openai.api_base,
-                    "api_type": openai.api_type,
-                },
+                api_base=api_base,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn
             )
-            ## RESPONSE OBJECT
-            model_response._hidden_params["original_response"] = response # track original response, if users make a litellm.text_completion() request, we can return the original response
-            choices_list = []
-            for idx, item in enumerate(response["choices"]):
-                if len(item["text"]) > 0:
-                    message_obj = Message(content=item["text"])
-                else:
-                    message_obj = Message(content=None)
-                choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
-                choices_list.append(choice_obj)
-            model_response["choices"] = choices_list
-            model_response["created"] = response.get("created", time.time())
-            model_response["model"] = model
-            model_response["usage"] = response.get("usage", 0)
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
+                return response
             response = model_response
         elif (
             "replicate" in model or