fix: fix ValueError when model returns an empty completion

This commit is contained in:
Krrish Dholakia 2023-10-10 10:11:21 -07:00
parent 6d81bcc248
commit af2fd0e0de
21 changed files with 84 additions and 50 deletions

View file

@ -268,13 +268,15 @@ def completion(
)
else:
if task == "conversational":
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
elif task == "text-generation-inference":
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
@ -289,18 +291,24 @@ def completion(
sum_logprob = 0
for token in item["tokens"]:
sum_logprob += token["logprob"]
message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
if len(item["generated_text"]) > 0:
message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
else:
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(input_text)
) ##[TODO] use the llama2 tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the llama2 tokenizer here
model_response["created"] = time.time()