(fix) hf don't fail when logprob is None

This commit is contained in:
ishaan-jaff 2023-11-06 14:22:09 -08:00
parent 8d9e2574cf
commit 3c67de7f04

View file

@ -270,6 +270,7 @@ def completion(
headers=headers, headers=headers,
data=json.dumps(data) data=json.dumps(data)
) )
print(response.text)
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten) ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False is_streamed = False
@ -330,7 +331,8 @@ def completion(
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"] if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]: if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
@ -338,7 +340,8 @@ def completion(
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]): for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
sum_logprob = 0 sum_logprob = 0
for token in item["tokens"]: for token in item["tokens"]:
sum_logprob += token["logprob"] if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0: if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob) message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
else: else: