fix(utils.py): add logprobs support for together ai

Fixes https://github.com/BerriAI/litellm/issues/6724
Author: Krrish Dholakia
Date: 2024-11-13 12:26:06 +05:30
Parent: 70c8be59d7
Commit: b2f1e47104

4 changed files with 21 additions and 31 deletions


```diff
@@ -161,17 +161,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "huggingface":
         return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
-        return [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "frequency_penalty",
-            "tools",
-            "tool_choice",
-            "response_format",
-        ]
+        return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ai21":
         return [
             "stream",
```


```diff
@@ -6,8 +6,8 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
 Docs: https://docs.together.ai/reference/completions-1
 """
 
-from ..OpenAI.openai import OpenAIConfig
+from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig
 
 
-class TogetherAIConfig(OpenAIConfig):
+class TogetherAIConfig(OpenAIGPTConfig):
     pass
```
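With the base class swapped to OpenAIGPTConfig, TogetherAIConfig inherits that class's parameter handling, logprobs included. A quick check through the public API, assuming a litellm build that contains this commit (the model name is illustrative):

```python
import litellm

# Assumes a litellm version that includes this commit; model name is
# illustrative only.
params = litellm.get_supported_openai_params(
    model="meta-llama/Llama-3-8b-chat-hf",
    custom_llm_provider="together_ai",
)
assert "logprobs" in params
```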


```diff
@@ -2900,24 +2900,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if stop is not None:
-            optional_params["stop"] = stop
-        if tools is not None:
-            optional_params["tools"] = tools
-        if tool_choice is not None:
-            optional_params["tool_choice"] = tool_choice
-        if response_format is not None:
-            optional_params["response_format"] = response_format
+        optional_params = litellm.TogetherAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
 
     elif custom_llm_provider == "ai21":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
```
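The nine hand-written if blocks collapse into a single map_openai_params call. Conceptually, that method copies each user-supplied parameter into optional_params when the provider supports it, and drops or rejects unsupported ones depending on drop_params. A standalone sketch of that contract (an assumption from the diff, not litellm's exact implementation; the real method lives on the config class and is inherited by TogetherAIConfig):

```python
# Sketch of the map_openai_params contract implied by this diff;
# not litellm's exact code.
def map_openai_params(non_default_params, optional_params, supported_params,
                      drop_params=False):
    for key, value in non_default_params.items():
        if key in supported_params:
            optional_params[key] = value
        elif not drop_params:
            raise ValueError(f"together_ai does not support parameter: {key}")
        # else: silently drop the unsupported parameter
    return optional_params


# With "logprobs" in the supported list, it now flows through:
print(map_openai_params({"logprobs": 1}, {}, ["logprobs", "temperature"]))
# -> {'logprobs': 1}
```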


```diff
@@ -921,3 +921,11 @@ def test_watsonx_text_top_k():
     )
     print(optional_params)
     assert optional_params["top_k"] == 10
+
+
+def test_together_ai_model_params():
+    optional_params = get_optional_params(
+        model="together_ai", custom_llm_provider="together_ai", logprobs=1
+    )
+    print(optional_params)
+    assert optional_params["logprobs"] == 1
```
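End to end, the fix means a logprobs argument is no longer stripped from Together AI requests. A hedged usage sketch, assuming a litellm build with this fix, TOGETHERAI_API_KEY set in the environment, and an illustrative model name:

```python
import litellm

# Assumes a litellm build with this fix and TOGETHERAI_API_KEY in the env;
# the model name below is illustrative.
response = litellm.completion(
    model="together_ai/meta-llama/Llama-3-8b-chat-hf",
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=1,  # previously absent from together_ai's supported-params list
)
print(response.choices[0])
```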