return logprobs for hf models

Krrish Dholakia 2023-08-30 15:16:24 -07:00
parent 271092f541
commit daa949a539
5 changed files with 46 additions and 20 deletions


@@ -68,11 +68,7 @@ class HuggingfaceRestAPILLM:
         for message in messages:
             prompt += f"{message['content']}"
         ### MAP INPUT PARAMS
-        # max tokens
-        if "max_tokens" in optional_params:
-            value = optional_params.pop("max_tokens")
-            optional_params["max_new_tokens"] = value
-        data = {"inputs": prompt, "parameters": optional_params}
+        data = {"inputs": prompt, "parameters": optional_params, "stream": True if "stream" in optional_params and optional_params["stream"] == True else False}
         ## LOGGING
         self.logging_obj.pre_call(
             input=prompt,
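
This hunk drops the local max_tokens-to-max_new_tokens mapping and adds a top-level "stream" flag to the request body. For context, a minimal sketch of the payload it builds; the prompt and options here are hypothetical, and the "inputs"/"parameters" keys follow the Hugging Face text-generation-inference request format:

    # Hypothetical inputs; optional_params is whatever the caller passed through.
    prompt = "Hello, how are you?"
    optional_params = {"max_new_tokens": 50, "stream": True}

    # Same logic as the conditional expression in the hunk above: a top-level
    # boolean "stream" next to the TGI "inputs"/"parameters" keys.
    data = {
        "inputs": prompt,
        "parameters": optional_params,
        "stream": optional_params.get("stream") == True,  # missing key -> False
    }
    assert data["stream"] is True
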
@@ -80,12 +76,18 @@
             additional_args={"complete_input_dict": data},
         )
         ## COMPLETION CALL
-        response = requests.post(
-            completion_url, headers=self.headers, data=json.dumps(data)
-        )
+        if "stream" in optional_params and optional_params["stream"] == True:
+            response = requests.post(
+                completion_url, headers=self.headers, data=json.dumps(data), stream=optional_params["stream"]
+            )
+            return response.iter_lines()
+        else:
+            print(f"completion url: {completion_url}")
+            print(f"headers: {self.headers}")
+            print(f"data: {data}")
+            response = requests.post(
+                completion_url, headers=self.headers, data=json.dumps(data)
+            )
         ## LOGGING
         self.logging_obj.post_call(
             input=prompt,
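
With stream=True the method now returns the raw requests line iterator instead of a parsed response. A minimal sketch of how a caller might consume it, assuming the endpoint emits server-sent-event lines of the form data:{...json...} as Hugging Face text-generation-inference does; completion_url, headers, and data stand in for the values used in the hunk above:

    import json
    import requests

    response = requests.post(
        completion_url, headers=headers, data=json.dumps(data), stream=True
    )
    for line in response.iter_lines():  # yields raw bytes, one SSE line at a time
        if not line:
            continue  # skip keep-alive blank lines between events
        decoded = line.decode("utf-8")
        if decoded.startswith("data:"):
            chunk = json.loads(decoded[len("data:"):])
            # TGI streams one generated token per event
            print(chunk["token"]["text"], end="", flush=True)
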
@@ -110,7 +112,13 @@
         model_response["choices"][0]["message"][
             "content"
         ] = completion_response[0]["generated_text"]
+        ## GETTING LOGPROBS
+        if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+            sum_logprob = 0
+            for token in completion_response[0]["details"]["tokens"]:
+                sum_logprob += token["logprob"]
+            model_response["choices"][0]["message"]["logprobs"] = sum_logprob
         ## CALCULATING USAGE
         prompt_tokens = len(
             self.encoding.encode(prompt)
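
The new logprobs block sums the per-token log-probabilities that text-generation-inference returns when token details are included in the response. A small self-contained sketch with a hypothetical response in that shape:

    # Hypothetical TGI-style completion: "details" carries one entry per
    # generated token, each with its log-probability.
    completion_response = [{
        "generated_text": "Hi there",
        "details": {
            "tokens": [
                {"text": "Hi", "logprob": -0.12},
                {"text": " there", "logprob": -0.34},
            ]
        },
    }]

    sum_logprob = 0
    for token in completion_response[0]["details"]["tokens"]:
        sum_logprob += token["logprob"]
    # sum_logprob is approximately -0.46: the log-probability of the whole
    # completion, which the commit stores under
    # model_response["choices"][0]["message"]["logprobs"]
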