fix(text_completion.py): fix routing logic

Krrish Dholakia 2023-11-10 15:46:37 -08:00
parent 11d1651d36
commit 54b4130d54


@@ -1929,96 +1929,80 @@ def text_completion(
    # get custom_llm_provider
    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
if custom_llm_provider == "huggingface":
if custom_llm_provider == "text-completion-openai":
# text-davinci-003 and openai text completion models
messages = [{"role": "system", "content": prompt}]
kwargs.pop("prompt", None)
response = completion(
model = model,
messages=messages,
*args,
**kwargs,
**optional_params
)
# assume the response is the openai response object
# return raw response from openai
return response._hidden_params.get("original_response", None)
elif custom_llm_provider == "huggingface":
        # if echo == True, for TGI llms we need to set top_n_tokens to 3
        if echo == True:
            # for tgi llms
            if "top_n_tokens" not in kwargs:
                kwargs["top_n_tokens"] = 3
        # processing prompt - users can pass raw tokens to OpenAI Completion()
        if type(prompt) == list:
            import concurrent.futures
            tokenizer = tiktoken.encoding_for_model("text-davinci-003")
            ## if it's a 2d list - each element in the list is a text_completion() request
            if len(prompt) > 0 and type(prompt[0]) == list:
                responses = [None for x in prompt] # init responses
                def process_prompt(i, individual_prompt):
                    decoded_prompt = tokenizer.decode(individual_prompt)
                    all_params = {**kwargs, **optional_params}
                    response = text_completion(
                        model=model,
                        prompt=decoded_prompt,
                        num_retries=3,# ensure this does not fail for the batch
                        *args,
                        **all_params,
                    )
                    #print(response)
                    text_completion_response["id"] = response.get("id", None)
                    text_completion_response["object"] = "text_completion"
                    text_completion_response["created"] = response.get("created", None)
                    text_completion_response["model"] = response.get("model", None)
                    return response["choices"][0]
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)]
                    for i, future in enumerate(concurrent.futures.as_completed(futures)):
                        responses[i] = future.result()
                    text_completion_response["choices"] = responses
                return text_completion_response
-    else:
+    # else:
    # check if non default values passed in for best_of, echo, logprobs, suffix
    # these are the params supported by Completion() but not ChatCompletion
    # default case, non OpenAI requests go through here
    messages = [{"role": "system", "content": prompt}]
    kwargs.pop("prompt", None)
    response = completion(
        model = model,
        messages=messages,
        *args,
        **kwargs,
        **optional_params,
    )
    if stream == True or kwargs.get("stream", False) == True:
        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
        return response
    transformed_logprobs = None
    # only supported for TGI models
    try:
        raw_response = response._hidden_params.get("original_response", None)
        transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
    except Exception as e:
        print_verbose(f"LiteLLM non blocking exception: {e}")
    text_completion_response["id"] = response.get("id", None)
    text_completion_response["object"] = "text_completion"
    text_completion_response["created"] = response.get("created", None)
    text_completion_response["model"] = response.get("model", None)
    text_choices = TextChoices()
    text_choices["text"] = response["choices"][0]["message"]["content"]
    text_choices["index"] = response["choices"][0]["index"]
    text_choices["logprobs"] = transformed_logprobs
    text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
    text_completion_response["choices"] = [text_choices]
    text_completion_response["usage"] = response.get("usage", None)
    return text_completion_response
##### Moderation #######################
def moderation(input: str, api_key: Optional[str]=None):
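
A minimal usage sketch of the two routes the fixed logic dispatches to. This is not part of the commit; the model names, token ids, and parameters are illustrative, and it assumes litellm is installed with the matching provider API keys exported in the environment.

import litellm

# "text-davinci-003" resolves to the new "text-completion-openai" branch,
# which returns the raw OpenAI text-completion response untouched.
openai_response = litellm.text_completion(
    model="text-davinci-003",
    prompt="Say this is a test",
    max_tokens=16,
)
print(openai_response)

# A Hugging Face TGI model with a 2D list of token ids goes through the
# "huggingface" branch: each inner list is decoded with tiktoken and fanned
# out to its own text_completion() call on the thread pool shown in the diff.
hf_response = litellm.text_completion(
    model="huggingface/bigcode/starcoder",  # illustrative model name
    prompt=[[9906, 1917], [9906, 4435]],    # illustrative raw token ids
)
print(hf_response["choices"])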