Mirror of https://github.com/BerriAI/litellm.git
Commit c63db48652: return all best of sequences
Parent: 93c41e8f6d
5 changed files with 41 additions and 12 deletions
@@ -5,7 +5,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Choices, Message
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
 
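The widened import pulls in litellm's response wrapper classes, which the hunk below uses to return one choice per generated sequence. A minimal sketch of how they compose, keeping to the constructor fields this diff itself uses (the values are illustrative):

from litellm.utils import Choices, Message

# Field names mirror their use in the hunk below; the values are made up.
message_obj = Message(content="hello world", logprobs=-0.42)
choice_obj = Choices(finish_reason="length", index=0, message=message_obj)

A list of such Choices objects can then be assigned to model_response["choices"], which is exactly what the next hunk does.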
|
@@ -173,16 +173,28 @@ def completion(
                 "content"
             ] = completion_response["generated_text"]
         elif task == "text-generation-inference":
-            model_response["choices"][0]["message"][
-                "content"
-            ] = completion_response[0]["generated_text"]
-            ## GETTING LOGPROBS + FINISH REASON
-            if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
-                sum_logprob = 0
-                for token in completion_response[0]["details"]["tokens"]:
-                    sum_logprob += token["logprob"]
-                model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+            if "best_of" in optional_params and optional_params["best_of"] > 1:
+                if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
+                    choices_list = []
+                    for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
+                        sum_logprob = 0
+                        for token in item["tokens"]:
+                            sum_logprob += token["logprob"]
+                        message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
+                        choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj)
+                        choices_list.append(choice_obj)
+                    model_response["choices"] = choices_list
+            else:
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS + FINISH REASON
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
         else:
             model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
         ## CALCULATING USAGE
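A self-contained sketch of the parsing path this hunk adds: when best_of > 1, the text-generation-inference payload is assumed to carry a details.best_of_sequences list, each entry with its own generated_text, finish_reason, and per-token logprobs. The mock payload below is invented for illustration; only the field names come from the diff.

mock_tgi_response = [
    {
        "generated_text": "Hello there!",
        "details": {
            "finish_reason": "length",
            "tokens": [{"logprob": -0.1}, {"logprob": -0.3}],
            "best_of_sequences": [
                {
                    "generated_text": "Hello there!",
                    "finish_reason": "length",
                    "tokens": [{"logprob": -0.1}, {"logprob": -0.3}],
                },
                {
                    "generated_text": "Hi!",
                    "finish_reason": "eos_token",
                    "tokens": [{"logprob": -0.5}],
                },
            ],
        },
    }
]

choices_list = []
for idx, item in enumerate(mock_tgi_response[0]["details"]["best_of_sequences"]):
    # Sequence-level logprob is the sum of per-token logprobs, as in the diff.
    sum_logprob = sum(token["logprob"] for token in item["tokens"])
    choices_list.append(
        {
            "index": idx,
            "finish_reason": item["finish_reason"],
            "message": {"content": item["generated_text"], "logprobs": sum_logprob},
        }
    )

print(choices_list)  # two entries, one per best_of sequence

Before this change, only the top sequence's generated_text made it into the response; a caller who set best_of > 1 generated the extra sequences but never saw them. With the new branch, passing best_of=2 through litellm's optional params should surface every sequence as a separate entry in response["choices"].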
|