huggingface conversational task support

Krrish Dholakia 2023-09-13 13:44:46 -07:00
parent 1913d36e05
commit 5b6b9a9fab
7 changed files with 75 additions and 40 deletions


@@ -40,7 +40,9 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
+    task = optional_params.pop("task")
     completion_url = ""
+    input_text = None
     if "https" in model:
         completion_url = model
     elif api_base:
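
A side note on the routing key introduced in this hunk: "task" is popped off optional_params so the remaining params can be forwarded to the endpoint untouched. A minimal sketch of that pattern follows; the default value is an assumption (the commit's bare pop("task") raises KeyError if the caller never set a task, and whether a default is guaranteed elsewhere is not visible in this diff).

# Hypothetical sketch, not part of the commit: pop the routing key with a
# fallback so a missing "task" does not raise KeyError.
optional_params = {"task": "conversational", "max_new_tokens": 64}
task = optional_params.pop("task", "text-generation-inference")  # assumed default
print(task)             # conversational
print(optional_params)  # {'max_new_tokens': 64}, safe to forward as-is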
@@ -49,37 +51,62 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
-    else:
-        prompt = prompt_factory(model=model, messages=messages)
     ### MAP INPUT PARAMS
-    if "https://api-inference.huggingface.co/models" in completion_url:
+    if task == "conversational":
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
+        past_user_inputs = []
+        generated_responses = []
+        text = ""
+        for message in messages:
+            if message["role"] == "user":
+                if text != "":
+                    past_user_inputs.append(text)
+                text = message["content"]
+            elif message["role"] == "assistant" or message["role"] == "system":
+                generated_responses.append(message["content"])
         data = {
-            "inputs": prompt,
-            "parameters": inference_params,
-            "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
-        }
-    else:
-        data = {
-            "inputs": prompt,
-            "parameters": optional_params,
-            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            "inputs": {
+                "text": text,
+                "past_user_inputs": past_user_inputs,
+                "generated_responses": generated_responses
+            },
+            "parameters": inference_params
         }
+        input_text = "".join(message["content"] for message in messages)
+    elif task == "text-generation-inference":
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+        if "https://api-inference.huggingface.co/models" in completion_url:
+            inference_params = copy.deepcopy(optional_params)
+            inference_params.pop("details")
+            data = {
+                "inputs": prompt,
+                "parameters": inference_params,
+                "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
+            }
+        else:
+            data = {
+                "inputs": prompt,
+                "parameters": optional_params,
+                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            }
+        input_text = prompt
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=input_text,
         api_key=api_key,
-        additional_args={"complete_input_dict": data},
+        additional_args={"complete_input_dict": data, "task": task},
     )
     ## COMPLETION CALL
     if "stream" in optional_params and optional_params["stream"] == True:
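
For reference, the message-mapping loop added above can be read in isolation: it folds an OpenAI-style messages list into the text / past_user_inputs / generated_responses shape the Hugging Face conversational endpoint expects, where the last user turn becomes the live input and earlier turns become history. A standalone sketch (the helper name and sample messages are illustrative, not from the commit):

# Standalone sketch of the mapping added in this hunk.
from typing import Dict, List

def to_conversational_inputs(messages: List[Dict[str, str]]) -> dict:
    past_user_inputs = []
    generated_responses = []
    text = ""
    for message in messages:
        if message["role"] == "user":
            if text != "":
                past_user_inputs.append(text)  # earlier user turn becomes history
            text = message["content"]  # latest user turn is the live input
        elif message["role"] == "assistant" or message["role"] == "system":
            # note: mirroring the diff, system messages also land here
            generated_responses.append(message["content"])
    return {
        "text": text,
        "past_user_inputs": past_user_inputs,
        "generated_responses": generated_responses,
    }

print(to_conversational_inputs([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What's new?"},
]))
# {'text': "What's new?", 'past_user_inputs': ['Hi'], 'generated_responses': ['Hello!']}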
@@ -98,10 +125,10 @@ def completion(
         )
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=input_text,
             api_key=api_key,
             original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "task": task},
         )
         ## RESPONSE OBJECT
         try:
@@ -119,19 +146,23 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            model_response["choices"][0]["message"][
-                "content"
-            ] = completion_response[0]["generated_text"]
-        ## GETTING LOGPROBS
-        if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-            sum_logprob = 0
-            for token in completion_response[0]["details"]["tokens"]:
-                sum_logprob += token["logprob"]
-            model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+            if task == "conversational":
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response["generated_text"]
+            elif task == "text-generation-inference":
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
         ## CALCULATING USAGE
         prompt_tokens = len(
-            encoding.encode(prompt)
+            encoding.encode(input_text)
         ) ##[TODO] use the llama2 tokenizer here
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"]["content"])
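
The response handling added in this last hunk branches on the same task flag: a conversational response arrives as a dict with generated_text, while a text-generation-inference response is a list whose first element may carry per-token details, and those token logprobs are summed into a single sequence score. A standalone sketch under those assumed shapes (the helper name and sample payloads are illustrative, not captured from the API):

# Standalone sketch of the task-conditional response parsing in this hunk.
from typing import Optional, Tuple

def parse_hf_response(task: str, completion_response) -> Tuple[str, Optional[float]]:
    if task == "conversational":
        # dict-shaped response, no token details to aggregate
        return completion_response["generated_text"], None
    elif task == "text-generation-inference":
        first = completion_response[0]  # list-shaped response
        sum_logprob = None
        if "details" in first and "tokens" in first["details"]:
            sum_logprob = sum(t["logprob"] for t in first["details"]["tokens"])
        return first["generated_text"], sum_logprob
    raise ValueError(f"unsupported task: {task}")

print(parse_hf_response("conversational", {"generated_text": "Hello!"}))
print(parse_hf_response(
    "text-generation-inference",
    [{"generated_text": "Hi there",
      "details": {"tokens": [{"logprob": -0.3}, {"logprob": -1.1}]}}],
))
# ('Hello!', None)
# ('Hi there', ~-1.4)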