fix hf tgi best of bug

This commit is contained in:
Krrish Dholakia 2023-09-20 20:53:32 -07:00
parent 35fda2cd05
commit a8711dc5c2
2 changed files with 12 additions and 13 deletions

View file

@ -173,6 +173,16 @@ def completion(
"content"
] = completion_response["generated_text"]
elif task == "text-generation-inference":
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]["logprobs"] = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
choices_list = []
@ -183,18 +193,7 @@ def completion(
message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
else:
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]["logprobs"] = sum_logprob
model_response["choices"].extend(choices_list)
else:
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
## CALCULATING USAGE

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.719"
version = "0.1.720"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"