forked from phoenix/litellm-mirror
(feat) support logprobs, top_logprobs openai
This commit is contained in:
parent
871f207124
commit
7b097305c1
2 changed files with 24 additions and 2 deletions
|
@ -423,6 +423,8 @@ def completion(
|
||||||
"tools",
|
"tools",
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
"max_retries",
|
"max_retries",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
]
|
]
|
||||||
litellm_params = [
|
litellm_params = [
|
||||||
"metadata",
|
"metadata",
|
||||||
|
@ -572,6 +574,8 @@ def completion(
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
|
logprobs=logprobs,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
**non_default_params,
|
**non_default_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -240,7 +240,9 @@ class Delta(OpenAIObject):
|
||||||
|
|
||||||
|
|
||||||
class Choices(OpenAIObject):
|
class Choices(OpenAIObject):
|
||||||
def __init__(self, finish_reason=None, index=0, message=None, **params):
|
def __init__(
|
||||||
|
self, finish_reason=None, index=0, message=None, logprobs=None, **params
|
||||||
|
):
|
||||||
super(Choices, self).__init__(**params)
|
super(Choices, self).__init__(**params)
|
||||||
self.finish_reason = (
|
self.finish_reason = (
|
||||||
map_finish_reason(finish_reason) or "stop"
|
map_finish_reason(finish_reason) or "stop"
|
||||||
|
@ -250,6 +252,8 @@ class Choices(OpenAIObject):
|
||||||
self.message = Message(content=None)
|
self.message = Message(content=None)
|
||||||
else:
|
else:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
if logprobs is not None:
|
||||||
|
self.logprobs = logprobs
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# Define custom behavior for the 'in' operator
|
# Define custom behavior for the 'in' operator
|
||||||
|
@ -2840,6 +2844,8 @@ def get_optional_params( # use the openai defaults
|
||||||
tools=None,
|
tools=None,
|
||||||
tool_choice=None,
|
tool_choice=None,
|
||||||
max_retries=None,
|
max_retries=None,
|
||||||
|
logprobs=None,
|
||||||
|
top_logprobs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# retrieve all parameters passed to the function
|
# retrieve all parameters passed to the function
|
||||||
|
@ -2867,6 +2873,8 @@ def get_optional_params( # use the openai defaults
|
||||||
"tools": None,
|
"tools": None,
|
||||||
"tool_choice": None,
|
"tool_choice": None,
|
||||||
"max_retries": None,
|
"max_retries": None,
|
||||||
|
"logprobs": None,
|
||||||
|
"top_logprobs": None,
|
||||||
}
|
}
|
||||||
# filter out those parameters that were passed with non-default values
|
# filter out those parameters that were passed with non-default values
|
||||||
non_default_params = {
|
non_default_params = {
|
||||||
|
@ -3615,6 +3623,8 @@ def get_optional_params( # use the openai defaults
|
||||||
"tools",
|
"tools",
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
"max_retries",
|
"max_retries",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
]
|
]
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
if functions is not None:
|
if functions is not None:
|
||||||
|
@ -3651,6 +3661,10 @@ def get_optional_params( # use the openai defaults
|
||||||
optional_params["tool_choice"] = tool_choice
|
optional_params["tool_choice"] = tool_choice
|
||||||
if max_retries is not None:
|
if max_retries is not None:
|
||||||
optional_params["max_retries"] = max_retries
|
optional_params["max_retries"] = max_retries
|
||||||
|
if logprobs is not None:
|
||||||
|
optional_params["logprobs"] = logprobs
|
||||||
|
if top_logprobs is not None:
|
||||||
|
optional_params["top_logprobs"] = top_logprobs
|
||||||
optional_params = non_default_params
|
optional_params = non_default_params
|
||||||
# if user passed in non-default kwargs for specific providers/models, pass them along
|
# if user passed in non-default kwargs for specific providers/models, pass them along
|
||||||
for k in passed_params.keys():
|
for k in passed_params.keys():
|
||||||
|
@ -4703,8 +4717,12 @@ def convert_to_model_response_object(
|
||||||
if finish_reason == None:
|
if finish_reason == None:
|
||||||
# gpt-4 vision can return 'finish_reason' or 'finish_details'
|
# gpt-4 vision can return 'finish_reason' or 'finish_details'
|
||||||
finish_reason = choice.get("finish_details")
|
finish_reason = choice.get("finish_details")
|
||||||
|
logprobs = choice.get("logprobs", None)
|
||||||
choice = Choices(
|
choice = Choices(
|
||||||
finish_reason=finish_reason, index=idx, message=message
|
finish_reason=finish_reason,
|
||||||
|
index=idx,
|
||||||
|
message=message,
|
||||||
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
choice_list.append(choice)
|
choice_list.append(choice)
|
||||||
model_response_object.choices = choice_list
|
model_response_object.choices = choice_list
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue