(feat) text_completion support batches for non openai llms

ishaan-jaff 2023-11-03 16:34:34 -07:00
parent b45d438e63
commit 0fa7c1ec3a


@@ -1861,34 +1861,71 @@ def embedding(
###### Text Completion ################
def text_completion(*args, **kwargs):
"""
This maps to the openai.Completion.create format, which has a different I/O (accepts a prompt, returns ["choices"][0]["text"]).
"""
import copy
if "engine" in kwargs:
kwargs["model"] = kwargs["engine"]
kwargs.pop("engine")
# if echo == True, for TGI llms we need to set top_n_tokens to 3
if kwargs.get("echo", False) == True and (kwargs["model"] not in litellm.open_ai_text_completion_models):
# for tgi llms
if "top_n_tokens" not in kwargs:
kwargs["top_n_tokens"] = 3
if "prompt" in kwargs:
if type(kwargs["prompt"]) == list:
new_prompt = ""
for chunk in kwargs["prompt"]:
decoded = litellm.utils.decode(model="gpt2", tokens=chunk)
new_prompt += decoded
kwargs["prompt"] = new_prompt
# input validation
if "prompt" not in kwargs:
raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")
model = kwargs["model"]
prompt = kwargs["prompt"]
# get custom_llm_provider
_, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model)
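# e.g. a model string like "huggingface/bigcode/starcoder" would typically resolve to
# custom_llm_provider="huggingface" (illustrative; routing is decided by get_llm_provider)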
if custom_llm_provider == "text-completion-openai":
# text-davinci-003 and openai text completion models
messages = [{"role": "system", "content": kwargs["prompt"]}]
kwargs["messages"] = messages
kwargs.pop("prompt")
response = completion(*args, **kwargs) # assume the response is the openai response object
# return raw response from openai
return response._hidden_params.get("original_response", None)
elif custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
if kwargs.get("echo", False) == True:
# for tgi llms
if "top_n_tokens" not in kwargs:
kwargs["top_n_tokens"] = 3
# processing prompt - users can pass raw tokens to OpenAI Completion()
if type(prompt) == list:
tokenizer = tiktoken.encoding_for_model("text-davinci-003")
## if it's a 2d list - each element in the list is a text_completion() request
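## e.g. prompt=[[31373], [995]] (hypothetical token ids) - each inner list is
## decoded and fanned out into its own text_completion() call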
if len(prompt) > 0 and type(prompt[0]) == list:
responses = [None for x in prompt] # init responses
for i, request in enumerate(prompt):
decoded_prompt = tokenizer.decode(request)
# print("\ndecoded\n", decoded_prompt)
# print("type decoded", type(decoded_prompt))
new_kwargs = copy.deepcopy(kwargs)
new_kwargs["prompt"] = decoded_prompt
# print("making new individual request", new_kwargs)
response = text_completion(**new_kwargs)
# print("assigning for ", i)
responses[i] = response["choices"][0]
print(responses)
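# note: the aggregated response below reuses id/created/model/usage from the last
# individual response; only "choices" is collected across the whole batch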
formatted_response_obj = {
"id": response["id"],
"object": "text_completion",
"created": response["created"],
"model": response["model"],
"choices": responses,
"usage": response["usage"]
}
return formatted_response_obj
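# Illustrative batch call (a sketch; the model name is hypothetical):
#   tokenizer = tiktoken.encoding_for_model("text-davinci-003")
#   batch = [tokenizer.encode("good morning"), tokenizer.encode("good night")]
#   response = text_completion(model="huggingface/bigcode/starcoder", prompt=batch)
#   assert len(response["choices"]) == len(batch)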
else:
messages = [{"role": "system", "content": kwargs["prompt"]}]
kwargs["messages"] = messages
kwargs.pop("prompt")
response = completion(*args, **kwargs) # assume the response is the openai response object
# if the model is text-davinci-003, return raw response from openai
if kwargs["model"] in litellm.open_ai_text_completion_models and response._hidden_params.get("original_response", None) != None:
return response._hidden_params.get("original_response", None)
transformed_logprobs = None
# only supported for TGI models
try:
@@ -1912,8 +1949,6 @@ def text_completion(*args, **kwargs):
"usage": response["usage"]
}
return formatted_response_obj
else:
raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")
##### Moderation #######################
def moderation(input: str, api_key: Optional[str]=None):