(feat) add transform_logprobs for text_completion

This commit is contained in:
ishaan-jaff 2023-11-01 18:24:42 -07:00
parent 8ca7af3a63
commit 8b5ee89d82

View file

@ -4369,7 +4369,7 @@ def transform_logprobs(hf_response):
for response in hf_response:
# Extract the relevant information from the response
response_details = response['details']
tokens = response_details['prefill'] + response_details['tokens']
top_tokens = response_details['top_tokens']
# Initialize an empty list for the token information
token_info = {
@ -4379,11 +4379,7 @@ def transform_logprobs(hf_response):
'top_logprobs': [],
}
stub_top_logprobs = { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(tokens):
for i, token in enumerate(response_details['prefill']):
# Extract the text of the token
token_text = token['text']
@ -4393,12 +4389,37 @@ def transform_logprobs(hf_response):
# Add the token information to the 'token_info' list
token_info['tokens'].append(token_text)
token_info['token_logprobs'].append(token_logprob)
token_info['top_logprobs'].append(stub_top_logprobs)
# stub this to work with llm eval harness
top_alt_tokens = { "": -1, "": -2, "": -3 }
token_info['top_logprobs'].append(top_alt_tokens)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details['tokens']):
# Extract the text of the token
token_text = token['text']
# Extract the logprob of the token
token_logprob = token['logprob']
top_alt_tokens = {}
temp_top_logprobs = top_tokens[i]
# top_alt_tokens should look like this: { "alternative_1": -1, "alternative_2": -2, "alternative_3": -3 }
for elem in temp_top_logprobs:
text = elem["text"]
logprob = elem["logprob"]
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
token_info['tokens'].append(token_text)
token_info['token_logprobs'].append(token_logprob)
token_info['top_logprobs'].append(top_alt_tokens)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
token_info['text_offset'].append(sum(len(t['text']) for t in tokens[:i]))
token_info['text_offset'].append(sum(len(t['text']) for t in response_details['tokens'][:i]))
# Add the 'token_info' list to the 'transformed_logprobs' list
transformed_logprobs = token_info