Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix(text_completion.py): fix routing logic

This commit is contained in:
parent 11d1651d36
commit 54b4130d54

1 changed file with 66 additions and 82 deletions:
litellm/main.py (148 lines changed)
@@ -1929,96 +1929,80 @@ def text_completion(

The hunk reworks the routing inside text_completion(): the old top-level `if custom_llm_provider == "huggingface": ... else: ...` structure becomes an explicit `text-completion-openai` branch, an `elif` for Hugging Face TGI models, and a dedented default path (the former `else:` survives only as a comment). The resulting code:

```python
    # get custom_llm_provider
    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
```
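For orientation, a minimal sketch (not part of the diff, and assuming `get_llm_provider` is importable from `litellm.utils` in this version of the library): the second element of the returned 4-tuple is the provider string that the branches below switch on.

```python
from litellm.utils import get_llm_provider

# Assumption: the exact model -> provider mappings depend on the litellm version.
model_name, provider, dynamic_api_key, api_base = get_llm_provider(model="text-davinci-003")
print(provider)  # expected to be "text-completion-openai" for OpenAI text-completion models
```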
```python
    if custom_llm_provider == "text-completion-openai":
        # text-davinci-003 and openai text completion models
        messages = [{"role": "system", "content": prompt}]
        kwargs.pop("prompt", None)
        response = completion(
            model = model,
            messages=messages,
            *args,
            **kwargs,
            **optional_params
        )
        # assume the response is the openai response object
        # return raw response from openai
        return response._hidden_params.get("original_response", None)
```
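A rough usage sketch for this branch (not from the diff; the model name and prompt are arbitrary): routing an OpenAI text-completion model through text_completion() should hand back whatever litellm stored as the provider's original response.

```python
import litellm

# Hypothetical example of hitting the "text-completion-openai" branch above.
# Assumes OPENAI_API_KEY is set; the exact shape of the returned object depends
# on how this litellm version stored the raw OpenAI response.
response = litellm.text_completion(
    model="text-davinci-003",
    prompt="Say hello in one word.",
    max_tokens=5,
)
print(response)
```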
```python
    elif custom_llm_provider == "huggingface":
        # if echo == True, for TGI llms we need to set top_n_tokens to 3
        if echo == True:
            # for tgi llms
            if "top_n_tokens" not in kwargs:
                kwargs["top_n_tokens"] = 3

        # processing prompt - users can pass raw tokens to OpenAI Completion()
        if type(prompt) == list:
            import concurrent.futures
            tokenizer = tiktoken.encoding_for_model("text-davinci-003")
            ## if it's a 2d list - each element in the list is a text_completion() request
            if len(prompt) > 0 and type(prompt[0]) == list:
                responses = [None for x in prompt] # init responses
                def process_prompt(i, individual_prompt):
                    decoded_prompt = tokenizer.decode(individual_prompt)
                    all_params = {**kwargs, **optional_params}
                    response = text_completion(
                        model=model,
                        prompt=decoded_prompt,
                        num_retries=3,  # ensure this does not fail for the batch
                        *args,
                        **all_params,
                    )
                    #print(response)
                    text_completion_response["id"] = response.get("id", None)
                    text_completion_response["object"] = "text_completion"
                    text_completion_response["created"] = response.get("created", None)
                    text_completion_response["model"] = response.get("model", None)
                    return response["choices"][0]
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)]
                    for i, future in enumerate(concurrent.futures.as_completed(futures)):
                        responses[i] = future.result()
                    text_completion_response["choices"] = responses

                return text_completion_response
```
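The batch path above decodes each inner list of token ids back to text with tiktoken and fans the requests out on a thread pool. Below is a standalone sketch of that pattern (not litellm code; assumes tiktoken is installed), with futures keyed by their original index so results stay aligned with the input order even though as_completed() yields them in completion order.

```python
import concurrent.futures

import tiktoken

# Stand-in for the per-prompt text_completion() call made in the branch above.
def fake_request(text: str) -> str:
    return f"echo: {text}"

tokenizer = tiktoken.encoding_for_model("text-davinci-003")
token_prompts = [tokenizer.encode("first prompt"), tokenizer.encode("second prompt")]

results = [None] * len(token_prompts)
with concurrent.futures.ThreadPoolExecutor() as executor:
    # Map each future back to the index of the prompt it came from.
    futures = {
        executor.submit(fake_request, tokenizer.decode(tokens)): i
        for i, tokens in enumerate(token_prompts)
    }
    for future in concurrent.futures.as_completed(futures):
        results[futures[future]] = future.result()

print(results)  # ['echo: first prompt', 'echo: second prompt']
```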
```python
    # else:
    # check if non default values passed in for best_of, echo, logprobs, suffix
    # these are the params supported by Completion() but not ChatCompletion

    # default case, non OpenAI requests go through here
    messages = [{"role": "system", "content": prompt}]
    kwargs.pop("prompt", None)
    response = completion(
        model = model,
        messages=messages,
        *args,
        **kwargs,
        **optional_params,
    )
    if stream == True or kwargs.get("stream", False) == True:
        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
        return response

    transformed_logprobs = None
    # only supported for TGI models
    try:
        raw_response = response._hidden_params.get("original_response", None)
        transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
    except Exception as e:
        print_verbose(f"LiteLLM non blocking exception: {e}")
    text_completion_response["id"] = response.get("id", None)
    text_completion_response["object"] = "text_completion"
    text_completion_response["created"] = response.get("created", None)
    text_completion_response["model"] = response.get("model", None)
    text_choices = TextChoices()
    text_choices["text"] = response["choices"][0]["message"]["content"]
    text_choices["index"] = response["choices"][0]["index"]
    text_choices["logprobs"] = transformed_logprobs
    text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
    text_completion_response["choices"] = [text_choices]
    text_completion_response["usage"] = response.get("usage", None)
    return text_completion_response
```
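A rough usage sketch for this default path (not from the diff; `ollama/llama2` is an arbitrary example model and assumes a locally running Ollama server): the chat-style completion is repackaged into the text-completion response shape built above.

```python
import litellm

# Hypothetical example of the default, non OpenAI-text-completion path above.
# Assumes a local Ollama server is serving llama2.
response = litellm.text_completion(
    model="ollama/llama2",
    prompt="Write a haiku about the sea.",
)
print(response["choices"][0]["text"])   # mapped from the chat message content above
print(response["usage"])                # copied from the underlying completion response
```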
The hunk's trailing context shows the start of the next section:

```python
##### Moderation #######################
def moderation(input: str, api_key: Optional[str]=None):
```