(feat) support logprobs, top_logprobs openai

This commit is contained in:
ishaan-jaff 2023-12-26 13:59:27 +05:30
parent 871f207124
commit 7b097305c1
2 changed files with 24 additions and 2 deletions

View file

@ -423,6 +423,8 @@ def completion(
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
]
litellm_params = [
"metadata",
@ -572,6 +574,8 @@ def completion(
tools=tools,
tool_choice=tool_choice,
max_retries=max_retries,
logprobs=logprobs,
top_logprobs=top_logprobs,
**non_default_params,
)

View file

@ -240,7 +240,9 @@ class Delta(OpenAIObject):
class Choices(OpenAIObject):
def __init__(self, finish_reason=None, index=0, message=None, **params):
def __init__(
self, finish_reason=None, index=0, message=None, logprobs=None, **params
):
super(Choices, self).__init__(**params)
self.finish_reason = (
map_finish_reason(finish_reason) or "stop"
@ -250,6 +252,8 @@ class Choices(OpenAIObject):
self.message = Message(content=None)
else:
self.message = message
if logprobs is not None:
self.logprobs = logprobs
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -2840,6 +2844,8 @@ def get_optional_params( # use the openai defaults
tools=None,
tool_choice=None,
max_retries=None,
logprobs=None,
top_logprobs=None,
**kwargs,
):
# retrieve all parameters passed to the function
@ -2867,6 +2873,8 @@ def get_optional_params( # use the openai defaults
"tools": None,
"tool_choice": None,
"max_retries": None,
"logprobs": None,
"top_logprobs": None,
}
# filter out those parameters that were passed with non-default values
non_default_params = {
@ -3615,6 +3623,8 @@ def get_optional_params( # use the openai defaults
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
]
_check_valid_arg(supported_params=supported_params)
if functions is not None:
@ -3651,6 +3661,10 @@ def get_optional_params( # use the openai defaults
optional_params["tool_choice"] = tool_choice
if max_retries is not None:
optional_params["max_retries"] = max_retries
if logprobs is not None:
optional_params["logprobs"] = logprobs
if top_logprobs is not None:
optional_params["top_logprobs"] = top_logprobs
optional_params = non_default_params
# if user passed in non-default kwargs for specific providers/models, pass them along
for k in passed_params.keys():
@ -4703,8 +4717,12 @@ def convert_to_model_response_object(
if finish_reason == None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
finish_reason = choice.get("finish_details")
logprobs = choice.get("logprobs", None)
choice = Choices(
finish_reason=finish_reason, index=idx, message=message
finish_reason=finish_reason,
index=idx,
message=message,
logprobs=logprobs,
)
choice_list.append(choice)
model_response_object.choices = choice_list