diff --git a/litellm/main.py b/litellm/main.py
index 55918f3a12..9aed35f833 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1861,34 +1861,71 @@ def embedding(
 ###### Text Completion ################
 def text_completion(*args, **kwargs):
+    import copy
     """
     This maps to the Openai.Completion.create format, which has a different I/O (accepts prompt, returning ["choices"]["text"].
     """
     if "engine" in kwargs:
         kwargs["model"] = kwargs["engine"]
         kwargs.pop("engine")
-
-    # if echo == True, for TGI llms we need to set top_n_tokens to 3
-    if kwargs.get("echo", False) == True and (kwargs["model"] not in litellm.open_ai_text_completion_models):
-        # for tgi llms
-        if "top_n_tokens" not in kwargs:
-            kwargs["top_n_tokens"] = 3
-    if "prompt" in kwargs:
-        if type(kwargs["prompt"]) == list:
-            new_prompt = ""
-            for chunk in kwargs["prompt"]:
-                decoded = litellm.utils.decode(model="gpt2", tokens=chunk)
-                new_prompt+= decoded
-            kwargs["prompt"] = new_prompt
+    # input validation
+    if "prompt" not in kwargs:
+        raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")
+
+    model = kwargs["model"]
+    prompt = kwargs["prompt"]
+    # get custom_llm_provider
+    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model)
+
+    if custom_llm_provider == "text-completion-openai":
+        # text-davinci-003 and openai text completion models
+        messages = [{"role": "system", "content": kwargs["prompt"]}]
+        kwargs["messages"] = messages
+        kwargs.pop("prompt")
+        response = completion(*args, **kwargs) # assume the response is the openai response object
+        # return raw response from openai
+        return response._hidden_params.get("original_response", None)
+
+    elif custom_llm_provider == "huggingface":
+        # if echo == True, for TGI llms we need to set top_n_tokens to 3
+        if kwargs.get("echo", False) == True:
+            # for tgi llms
+            if "top_n_tokens" not in kwargs:
+                kwargs["top_n_tokens"] = 3
+
+    # processing prompt - users can pass raw tokens to OpenAI Completion()
+    if type(prompt) == list:
+        tokenizer = tiktoken.encoding_for_model("text-davinci-003")
+        ## if it's a 2d list - each element in the list is a text_completion() request
+        if len(prompt) > 0 and type(prompt[0]) == list:
+            responses = [None for x in prompt] # init responses
+            for i, request in enumerate(prompt):
+                decoded_prompt = tokenizer.decode(request)
+                # print("\ndecoded\n", decoded_prompt)
+                # print("type decoded", type(decoded_prompt))
+                new_kwargs = copy.deepcopy(kwargs)
+                new_kwargs["prompt"] = decoded_prompt
+                # print("making new individual request", new_kwargs)
+                response = text_completion(**new_kwargs)
+                # print("assigning for ", i)
+                responses[i] = response["choices"][0]
+            print(responses)
+            formatted_response_obj = {
+                "id": response["id"],
+                "object": "text_completion",
+                "created": response["created"],
+                "model": response["model"],
+                "choices": responses,
+                "usage": response["usage"]
+            }
+            return formatted_response_obj
+    else:
         messages = [{"role": "system", "content": kwargs["prompt"]}]
         kwargs["messages"] = messages
         kwargs.pop("prompt")
         response = completion(*args, **kwargs) # assume the response is the openai response object
-        # if the model is text-davinci-003, return raw response from openai
-        if kwargs["model"] in litellm.open_ai_text_completion_models and response._hidden_params.get("original_response", None) != None:
-            return response._hidden_params.get("original_response", None)
         transformed_logprobs = None
         # only supported for TGI models
         try:
@@ -1912,8 +1949,6 @@ def text_completion(*args, **kwargs):
             "usage": response["usage"]
         }
         return formatted_response_obj
-    else:
-        raise ValueError("please pass prompt into the `text_completion` endpoint - `text_completion(model, prompt='hello world')`")
 
 ##### Moderation #######################
 def moderation(input: str, api_key: Optional[str]=None):
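
Reviewer note: a minimal usage sketch (not part of the diff) of the new 2d token-list path added above. The model name, prompts, and max_tokens are illustrative placeholders; it assumes tiktoken is installed and OPENAI_API_KEY is set in the environment. Depending on what get_llm_provider resolves for the model, the token lists are either passed through on the text-completion-openai branch or decoded with tiktoken and fanned out into one text_completion() call per inner list.

    import litellm
    import tiktoken

    # assumes OPENAI_API_KEY is exported in the environment (placeholder setup)
    enc = tiktoken.encoding_for_model("text-davinci-003")

    # a 2d list of raw token ids - each inner list is treated as its own prompt
    token_batches = [
        enc.encode("Say hello in French."),
        enc.encode("Say hello in Spanish."),
    ]

    response = litellm.text_completion(
        model="text-davinci-003",  # illustrative model name
        prompt=token_batches,
        max_tokens=20,             # forwarded through kwargs to completion()
    )

    # both paths return an OpenAI-style text_completion object; with the default
    # n=1 there is one entry in "choices" per inner token list, in request order
    for choice in response["choices"]:
        print(choice["text"])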