return all best of sequences

Krrish Dholakia 2023-09-20 14:43:25 -07:00
parent 93c41e8f6d
commit c63db48652
5 changed files with 41 additions and 12 deletions

View file

@@ -5,7 +5,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable
-from litellm.utils import ModelResponse
+from litellm.utils import ModelResponse, Choices, Message
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -173,16 +173,28 @@ def completion(
                     "content"
                 ] = completion_response["generated_text"]
             elif task == "text-generation-inference":
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response[0]["generated_text"]
-                ## GETTING LOGPROBS + FINISH REASON
-                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                    model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
-                    sum_logprob = 0
-                    for token in completion_response[0]["details"]["tokens"]:
-                        sum_logprob += token["logprob"]
-                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+                if "best_of" in optional_params and optional_params["best_of"] > 1:
+                    if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
+                        choices_list = []
+                        for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
+                            sum_logprob = 0
+                            for token in item["tokens"]:
+                                sum_logprob += token["logprob"]
+                            message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
+                            choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj)
+                            choices_list.append(choice_obj)
+                        model_response["choices"] = choices_list
+                else:
+                    model_response["choices"][0]["message"][
+                        "content"
+                    ] = completion_response[0]["generated_text"]
+                    ## GETTING LOGPROBS + FINISH REASON
+                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                        model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+                        sum_logprob = 0
+                        for token in completion_response[0]["details"]["tokens"]:
+                            sum_logprob += token["logprob"]
+                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
             else:
                 model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
         ## CALCULATING USAGE
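
Note: as a rough illustration of what the new text-generation-inference branch does, here is a small, dependency-free sketch that mirrors the parsing above, assuming a TGI-style payload where details.best_of_sequences holds one entry per returned sequence. The sample payload and the helper name sequences_to_choices are made up for illustration; in the commit itself the choices are built with litellm.utils.Choices and Message.

def sequences_to_choices(completion_response):
    # Build one OpenAI-style choice per best_of sequence, summing the
    # per-token logprobs for each sequence (same as the loop in the diff).
    choices_list = []
    details = completion_response[0].get("details", {})
    for idx, item in enumerate(details.get("best_of_sequences", [])):
        sum_logprob = sum(token["logprob"] for token in item["tokens"])
        choices_list.append({
            "index": idx,
            "finish_reason": item["finish_reason"],
            "message": {
                "role": "assistant",
                "content": item["generated_text"],
                "logprobs": sum_logprob,
            },
        })
    return choices_list

# Illustrative payload only -- not captured from a real endpoint.
sample_response = [{
    "generated_text": "def add(a, b):\n    return a + b",
    "details": {
        "finish_reason": "eos_token",
        "best_of_sequences": [
            {"generated_text": "def add(a, b):\n    return a + b",
             "finish_reason": "eos_token",
             "tokens": [{"logprob": -0.11}, {"logprob": -0.27}]},
            {"generated_text": "add = lambda a, b: a + b",
             "finish_reason": "length",
             "tokens": [{"logprob": -0.42}, {"logprob": -0.93}]},
        ],
    },
}]

for choice in sequences_to_choices(sample_response):
    print(choice["index"], choice["finish_reason"], choice["message"]["logprobs"])
# 0 eos_token -0.38
# 1 length -1.35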

View file

@@ -135,6 +135,22 @@ def test_completion_with_litellm_call_id():
 
 # test_completion_hf_api()
 
+# def test_completion_hf_api_best_of():
+# # failing on circle ci commenting out
+#     try:
+#         user_message = "write some code to find the sum of two numbers"
+#         messages = [{ "content": user_message,"role": "user"}]
+#         api_base = "https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud"
+#         response = completion(model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base=api_base, n=2)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         if "loading" in str(e):
+#             pass
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_hf_api_best_of()
+
 # def test_completion_hf_deployed_api():
 #     try:
 #         user_message = "There's a llama in my garden 😱 What should I do?"
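
For reference, a usage sketch of the new behavior (assumes litellm is installed and that the api_base above points at a live text-generation-inference deployment; the model and URL are taken from the commented-out test, so treat this as illustrative rather than a guaranteed-working call):

from litellm import completion

messages = [{"role": "user", "content": "write some code to find the sum of two numbers"}]
response = completion(
    model="huggingface/meta-llama/Llama-2-7b-chat-hf",
    messages=messages,
    api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud",
    n=2,  # mapped to best_of=2 + do_sample=True for huggingface (see the get_optional_params hunk below)
)

# With best_of_sequences returned by TGI, each sequence becomes its own choice.
print(len(response.choices))  # expected: 2
for choice in response.choices:
    print(choice.message.content)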

View file

@@ -902,7 +902,8 @@ def get_optional_params(  # use the openai defaults
         if top_p != 1:
             optional_params["top_p"] = top_p
         if n != 1:
-            optional_params["n"] = n
+            optional_params["best_of"] = n
+            optional_params["do_sample"] = True  # need to sample if you want best of for hf inference endpoints
         if stream:
             optional_params["stream"] = stream
         if stop != None: