(fix) temp_top_logprobs

ishaan-jaff 2023-11-03 16:45:10 -07:00
parent e29b2e8ce4
commit f3dc06da04


@@ -4382,7 +4382,7 @@ def transform_logprobs(hf_response):
     for response in hf_response:
         # Extract the relevant information from the response
         response_details = response['details']
-        top_tokens = response_details['top_tokens']
+        top_tokens = response_details.get("top_tokens", {})
         # Initialize an empty list for the token information
         token_info = {
@@ -4417,7 +4417,9 @@ def transform_logprobs(hf_response):
             token_logprob = token['logprob']
             top_alt_tokens = {}
-            temp_top_logprobs = top_tokens[i]
+            temp_top_logprobs = []
+            if top_tokens != {}:
+                temp_top_logprobs = top_tokens[i]
             # top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
             for elem in temp_top_logprobs:
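
The change replaces direct indexing into response_details['top_tokens'] with a guarded lookup, so a Hugging Face response that omits top_tokens no longer raises a KeyError and the loop over alternatives simply runs zero times. A minimal sketch of that pattern on hypothetical sample data follows; the helper name extract_top_logprobs, the sample dicts, and the "text" field name are illustrative assumptions, not the actual litellm function.

def extract_top_logprobs(response_details, i):
    # Some responses omit "top_tokens" entirely, so fall back to an empty dict
    # instead of indexing directly (the fix in this commit).
    top_tokens = response_details.get("top_tokens", {})

    temp_top_logprobs = []
    if top_tokens != {}:
        temp_top_logprobs = top_tokens[i]

    # Build { alternative_token: logprob } only when alternatives are present.
    top_alt_tokens = {}
    for elem in temp_top_logprobs:
        top_alt_tokens[elem["text"]] = elem["logprob"]
    return top_alt_tokens


# A details dict without "top_tokens" now yields an empty mapping instead of crashing.
details_without_alternatives = {"tokens": [{"text": "hi", "logprob": -0.5}]}
print(extract_top_logprobs(details_without_alternatives, 0))  # -> {}

# When alternatives are present, they are collected as before.
details_with_alternatives = {
    "top_tokens": [[{"text": "hi", "logprob": -0.5}, {"text": "hey", "logprob": -1.2}]]
}
print(extract_top_logprobs(details_with_alternatives, 0))  # -> {'hi': -0.5, 'hey': -1.2}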