Mirror of https://github.com/BerriAI/litellm.git
huggingface conversational task support

commit 5b6b9a9fab, parent 1913d36e05
7 changed files with 75 additions and 40 deletions
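The substance of this commit is mapping litellm's OpenAI-style messages list onto the payload shape Hugging Face's conversational inference task expects. Below is a minimal standalone sketch of that mapping; the helper name build_conversational_payload is illustrative only, since in the diff the logic sits inline in completion():

def build_conversational_payload(messages, inference_params):
    # Split chat history: the latest user turn goes in "text"; earlier user
    # turns and assistant/system turns go into two parallel history lists.
    past_user_inputs = []
    generated_responses = []
    text = ""
    for message in messages:
        if message["role"] == "user":
            if text != "":
                past_user_inputs.append(text)
            text = message["content"]
        elif message["role"] == "assistant" or message["role"] == "system":
            generated_responses.append(message["content"])
    return {
        "inputs": {
            "text": text,
            "past_user_inputs": past_user_inputs,
            "generated_responses": generated_responses,
        },
        "parameters": inference_params,
    }

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]
print(build_conversational_payload(messages, {}))
# {'inputs': {'text': 'Tell me a joke.', 'past_user_inputs': ['Hi'],
#  'generated_responses': ['Hello! How can I help?']}, 'parameters': {}}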
@@ -40,7 +40,9 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
+    task = optional_params.pop("task")
     completion_url = ""
+    input_text = None
     if "https" in model:
         completion_url = model
     elif api_base:
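Note on the hunk above: "task" is popped from optional_params, so whatever builds optional_params upstream must have set it; that caller is outside this diff. A minimal assumed shape, for orientation only:

# Assumption: upstream code placed "task" in optional_params before completion() runs.
optional_params = {"task": "conversational", "details": True, "max_new_tokens": 50}
task = optional_params.pop("task")  # "conversational" or "text-generation-inference"
# the remaining keys later become the request "parameters" (after "details" is popped too)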
@@ -49,37 +51,62 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
-    else:
-        prompt = prompt_factory(model=model, messages=messages)
+
     ### MAP INPUT PARAMS
-    if "https://api-inference.huggingface.co/models" in completion_url:
+    if task == "conversational":
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
+        past_user_inputs = []
+        generated_responses = []
+        text = ""
+        for message in messages:
+            if message["role"] == "user":
+                if text != "":
+                    past_user_inputs.append(text)
+                text = message["content"]
+            elif message["role"] == "assistant" or message["role"] == "system":
+                generated_responses.append(message["content"])
         data = {
-            "inputs": prompt,
-            "parameters": inference_params,
-            "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
-        }
-    else:
-        data = {
-            "inputs": prompt,
-            "parameters": optional_params,
-            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            "inputs": {
+                "text": text,
+                "past_user_inputs": past_user_inputs,
+                "generated_responses": generated_responses
+            },
+            "parameters": inference_params
         }
+        input_text = "".join(message["content"] for message in messages)
+    elif task == "text-generation-inference":
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+        if "https://api-inference.huggingface.co/models" in completion_url:
+            inference_params = copy.deepcopy(optional_params)
+            inference_params.pop("details")
+            data = {
+                "inputs": prompt,
+                "parameters": inference_params,
+                "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
+            }
+        else:
+            data = {
+                "inputs": prompt,
+                "parameters": optional_params,
+                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            }
+        input_text = prompt
     ## LOGGING
     logging_obj.pre_call(
-            input=prompt,
+            input=input_text,
             api_key=api_key,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "task": task},
         )
     ## COMPLETION CALL
     if "stream" in optional_params and optional_params["stream"] == True:
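For contrast, the two request bodies the rewritten branch builds, with made-up parameter values (a sketch, not captured traffic):

# task == "conversational": chat history split into parallel lists, no "stream" key
conversational_data = {
    "inputs": {
        "text": "Tell me a joke.",
        "past_user_inputs": ["Hi"],
        "generated_responses": ["Hello! How can I help?"],
    },
    "parameters": {"max_new_tokens": 50},  # "details" already popped off
}

# task == "text-generation-inference": one flattened prompt string
tgi_data = {
    "inputs": "### User:\nTell me a joke.\n### Assistant:\n",  # from prompt_factory/custom_prompt
    "parameters": {"max_new_tokens": 50},
    "stream": False,
}

Note that input_text diverges the same way: the concatenation of all message contents for conversational, the flattened prompt for text-generation-inference. It is what gets logged and token-counted later in the function.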
@@ -98,10 +125,10 @@ def completion(
         )
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=input_text,
             api_key=api_key,
             original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "task": task},
         )
         ## RESPONSE OBJECT
         try:
@@ -119,19 +146,23 @@ def completion(
                     status_code=response.status_code,
                 )
             else:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response[0]["generated_text"]
-
-                ## GETTING LOGPROBS
-                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                    sum_logprob = 0
-                    for token in completion_response[0]["details"]["tokens"]:
-                        sum_logprob += token["logprob"]
-                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+                if task == "conversational":
+                    model_response["choices"][0]["message"][
+                        "content"
+                    ] = completion_response["generated_text"]
+                elif task == "text-generation-inference":
+                    model_response["choices"][0]["message"][
+                        "content"
+                    ] = completion_response[0]["generated_text"]
+                    ## GETTING LOGPROBS
+                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                        sum_logprob = 0
+                        for token in completion_response[0]["details"]["tokens"]:
+                            sum_logprob += token["logprob"]
+                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
         ## CALCULATING USAGE
         prompt_tokens = len(
-            encoding.encode(prompt)
+            encoding.encode(input_text)
         ) ##[TODO] use the llama2 tokenizer here
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"]["content"])
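The response handling now branches on task as well: the conversational endpoint returns a single object read via completion_response["generated_text"], while text-generation-inference returns a list whose first element may carry token details. A runnable sketch of the logprob summing against a made-up TGI response:

completion_response = [{
    "generated_text": "Why did the chicken cross the road?",
    "details": {"tokens": [{"logprob": -0.51}, {"logprob": -1.20}]},
}]
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
    sum_logprob = sum(t["logprob"] for t in completion_response[0]["details"]["tokens"])
    print(sum_logprob)  # -1.71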