mirror of https://github.com/BerriAI/litellm.git
(feat) text_completion support batches for non openai llms
This commit is contained in:
parent b45d438e63
commit 0fa7c1ec3a
1 changed file with 53 additions and 18 deletions
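What the change enables, as a minimal usage sketch (not part of the diff): a batch of prompts, each given as a list of raw token ids, sent through text_completion() to a non-OpenAI model. The Hugging Face model name, api_base, and token id values below are placeholders.

    import litellm

    # a 2d list: each inner list is one prompt, expressed as raw token ids
    batched_token_prompts = [
        [9906, 11, 1917],        # token ids for the first prompt (illustrative values)
        [3923, 374, 264, 5426],  # token ids for the second prompt (illustrative values)
    ]

    response = litellm.text_completion(
        model="huggingface/bigcode/starcoder",           # placeholder TGI-served model
        prompt=batched_token_prompts,
        api_base="https://my-tgi-endpoint.example.com",  # placeholder endpoint
        max_tokens=10,
    )

    # the aggregated response carries one choice per prompt in the batch
    for choice in response["choices"]:
        print(choice["text"])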
@@ -1861,34 +1861,71 @@ def embedding(

###### Text Completion ################

def text_completion(*args, **kwargs):
    import copy
    """
    This maps to the openai.Completion.create format, which has a different I/O (accepts a prompt, returns ["choices"]["text"]).
    """
    if "engine" in kwargs:
        kwargs["model"] = kwargs["engine"]
        kwargs.pop("engine")

    # if echo == True, for TGI llms we need to set top_n_tokens to 3
    if kwargs.get("echo", False) == True and (kwargs["model"] not in litellm.open_ai_text_completion_models):
        # for tgi llms
        if "top_n_tokens" not in kwargs:
            kwargs["top_n_tokens"] = 3
    if "prompt" in kwargs:
        if type(kwargs["prompt"]) == list:
            # prompt passed as raw token ids - decode back into a single text prompt
            new_prompt = ""
            for chunk in kwargs["prompt"]:
                decoded = litellm.utils.decode(model="gpt2", tokens=chunk)
                new_prompt += decoded
            kwargs["prompt"] = new_prompt

    # input validation
    if "prompt" not in kwargs:
        raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")

    model = kwargs["model"]
    prompt = kwargs["prompt"]
    # get custom_llm_provider
    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model)

    if custom_llm_provider == "text-completion-openai":
        # text-davinci-003 and openai text completion models
        messages = [{"role": "system", "content": kwargs["prompt"]}]
        kwargs["messages"] = messages
        kwargs.pop("prompt")
        response = completion(*args, **kwargs)  # assume the response is the openai response object
        # return raw response from openai
        return response._hidden_params.get("original_response", None)

    elif custom_llm_provider == "huggingface":
        # if echo == True, for TGI llms we need to set top_n_tokens to 3
        if kwargs.get("echo", False) == True:
            # for tgi llms
            if "top_n_tokens" not in kwargs:
                kwargs["top_n_tokens"] = 3

    # processing prompt - users can pass raw tokens to OpenAI Completion()
    if type(prompt) == list:
        tokenizer = tiktoken.encoding_for_model("text-davinci-003")
        ## if it's a 2d list - each element in the list is a text_completion() request
        if len(prompt) > 0 and type(prompt[0]) == list:
            responses = [None for x in prompt]  # init responses
            for i, request in enumerate(prompt):
                decoded_prompt = tokenizer.decode(request)
                # print("\ndecoded\n", decoded_prompt)
                # print("type decoded", type(decoded_prompt))
                new_kwargs = copy.deepcopy(kwargs)
                new_kwargs["prompt"] = decoded_prompt
                # print("making new individual request", new_kwargs)
                response = text_completion(**new_kwargs)
                # print("assigning for ", i)
                responses[i] = response["choices"][0]
            print(responses)
            formatted_response_obj = {
                "id": response["id"],
                "object": "text_completion",
                "created": response["created"],
                "model": response["model"],
                "choices": responses,
                "usage": response["usage"]
            }
            return formatted_response_obj
    else:
        messages = [{"role": "system", "content": kwargs["prompt"]}]
        kwargs["messages"] = messages
        kwargs.pop("prompt")
        response = completion(*args, **kwargs)  # assume the response is the openai response object

        # if the model is text-davinci-003, return raw response from openai
        if kwargs["model"] in litellm.open_ai_text_completion_models and response._hidden_params.get("original_response", None) != None:
            return response._hidden_params.get("original_response", None)
        transformed_logprobs = None
        # only supported for TGI models
        try:
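An aside on the batch path above (not part of the diff): each element of a 2d prompt is a list of token ids, which the loop decodes back to text with tiktoken's text-davinci-003 encoding before issuing the per-prompt request. A minimal round-trip sketch, with an illustrative input string:

    import tiktoken

    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    token_ids = tokenizer.encode("good morning")   # what a caller might pass as one batch element
    decoded_prompt = tokenizer.decode(token_ids)   # what the batch loop hands to the recursive text_completion() call
    assert decoded_prompt == "good morning"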
@@ -1912,8 +1949,6 @@ def text_completion(*args, **kwargs):

            "usage": response["usage"]
        }
        return formatted_response_obj
    else:
        raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")
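A small sketch of the validation path shown in this hunk: calling text_completion() without a prompt raises a ValueError. The model name below is a placeholder.

    import litellm

    try:
        litellm.text_completion(model="gpt-3.5-turbo-instruct")  # placeholder model; prompt deliberately omitted
    except ValueError as err:
        print(err)  # "please pass prompt into the `text_completion` endpoint - ..."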
##### Moderation #######################
def moderation(input: str, api_key: Optional[str]=None):
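The diff only shows the moderation() signature; a minimal call sketch, assuming an OpenAI API key is available in the environment (the input string is illustrative):

    import litellm

    # api_key is optional per the signature above; it falls back to the environment if not passed
    response = litellm.moderation(input="this is a test sentence")
    print(response)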