diff --git a/litellm/main.py b/litellm/main.py
index 9a1e83459..8dd9f83b9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -423,6 +423,8 @@ def completion(
         "tools",
         "tool_choice",
         "max_retries",
+        "logprobs",
+        "top_logprobs",
     ]
     litellm_params = [
         "metadata",
@@ -572,6 +574,8 @@ def completion(
         tools=tools,
         tool_choice=tool_choice,
         max_retries=max_retries,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         **non_default_params,
     )

diff --git a/litellm/utils.py b/litellm/utils.py
index c142296fb..e5378a60d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -240,7 +240,9 @@ class Delta(OpenAIObject):


 class Choices(OpenAIObject):
-    def __init__(self, finish_reason=None, index=0, message=None, **params):
+    def __init__(
+        self, finish_reason=None, index=0, message=None, logprobs=None, **params
+    ):
         super(Choices, self).__init__(**params)
         self.finish_reason = (
             map_finish_reason(finish_reason) or "stop"
@@ -250,6 +252,8 @@ class Choices(OpenAIObject):
             self.message = Message(content=None)
         else:
             self.message = message
+        if logprobs is not None:
+            self.logprobs = logprobs

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -2840,6 +2844,8 @@ def get_optional_params(  # use the openai defaults
     tools=None,
     tool_choice=None,
     max_retries=None,
+    logprobs=None,
+    top_logprobs=None,
     **kwargs,
 ):
     # retrieve all parameters passed to the function
@@ -2867,6 +2873,8 @@ def get_optional_params(  # use the openai defaults
         "tools": None,
         "tool_choice": None,
         "max_retries": None,
+        "logprobs": None,
+        "top_logprobs": None,
     }
     # filter out those parameters that were passed with non-default values
     non_default_params = {
@@ -3615,6 +3623,8 @@ def get_optional_params(  # use the openai defaults
             "tools",
             "tool_choice",
             "max_retries",
+            "logprobs",
+            "top_logprobs",
         ]
         _check_valid_arg(supported_params=supported_params)
         if functions is not None:
@@ -3651,6 +3661,10 @@ def get_optional_params(  # use the openai defaults
             optional_params["tool_choice"] = tool_choice
         if max_retries is not None:
             optional_params["max_retries"] = max_retries
+        if logprobs is not None:
+            optional_params["logprobs"] = logprobs
+        if top_logprobs is not None:
+            optional_params["top_logprobs"] = top_logprobs
     optional_params = non_default_params
     # if user passed in non-default kwargs for specific providers/models, pass them along
     for k in passed_params.keys():
@@ -4703,8 +4717,12 @@ def convert_to_model_response_object(
                 if finish_reason == None:
                     # gpt-4 vision can return 'finish_reason' or 'finish_details'
                     finish_reason = choice.get("finish_details")
+                logprobs = choice.get("logprobs", None)
                 choice = Choices(
-                    finish_reason=finish_reason, index=idx, message=message
+                    finish_reason=finish_reason,
+                    index=idx,
+                    message=message,
+                    logprobs=logprobs,
                 )
                 choice_list.append(choice)
             model_response_object.choices = choice_list
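
For reference, a minimal usage sketch (not part of the diff) of how the new parameters are expected to flow through completion() and back out on the response; the model name, prompt, and printed attribute access are placeholder assumptions based on the changes above, not verified output.

    # Assumes an OpenAI-compatible provider is configured; "gpt-3.5-turbo" and the
    # prompt are placeholder values chosen for illustration only.
    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        logprobs=True,    # now accepted by completion() and forwarded via get_optional_params()
        top_logprobs=2,   # number of most-likely alternatives requested per token position
    )

    # Per the Choices/convert_to_model_response_object changes, each choice should
    # carry the provider's logprobs block when one is returned.
    print(response.choices[0].logprobs)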