fix(utils.py): add logprobs support for together ai

Fixes https://github.com/BerriAI/litellm/issues/6724
Krrish Dholakia 2024-11-13 12:26:06 +05:30
parent 70c8be59d7
commit b2f1e47104
4 changed files with 21 additions and 31 deletions
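
A minimal end-to-end sketch (not part of the commit) of what the fix enables; the model string is illustrative and a configured Together AI API key is assumed:

import litellm

# Before this change, logprobs was not in Together AI's supported-params list;
# it is now mapped through like any other OpenAI-compatible parameter.
response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=1,
)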

@@ -161,17 +161,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "huggingface":
         return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
-        return [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "frequency_penalty",
-            "tools",
-            "tool_choice",
-            "response_format",
-        ]
+        return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ai21":
         return [
             "stream",

@@ -6,8 +6,8 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
 Docs: https://docs.together.ai/reference/completions-1
 """
-from ..OpenAI.openai import OpenAIConfig
+from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig


-class TogetherAIConfig(OpenAIConfig):
+class TogetherAIConfig(OpenAIGPTConfig):
     pass
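
The effect of the new base class is that Together AI inherits the OpenAI-compatible param mapping. A sketch of the delegated call (mirroring the get_optional_params change below; values are illustrative):

import litellm

# Mirrors the call added in get_optional_params below; values are illustrative.
mapped = litellm.TogetherAIConfig().map_openai_params(
    non_default_params={"logprobs": 1, "temperature": 0.7},
    optional_params={},
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    drop_params=False,
)
print(mapped)  # expected to include {"logprobs": 1, "temperature": 0.7}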

@@ -2900,24 +2900,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if stop is not None:
-            optional_params["stop"] = stop
-        if tools is not None:
-            optional_params["tools"] = tools
-        if tool_choice is not None:
-            optional_params["tool_choice"] = tool_choice
-        if response_format is not None:
-            optional_params["response_format"] = response_format
+        optional_params = litellm.TogetherAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "ai21":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
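
One detail in the hunk above: drop_params is coerced to a strict bool before delegation. A tiny sketch of that expression in isolation:

# The guard maps None (the default) and any non-bool value to False, so the
# mapper always receives a real bool.
for drop_params in (None, True, False, "yes"):
    coerced = (
        drop_params
        if drop_params is not None and isinstance(drop_params, bool)
        else False
    )
    print(repr(drop_params), "->", coerced)
# None -> False, True -> True, False -> False, 'yes' -> False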

@@ -921,3 +921,11 @@ def test_watsonx_text_top_k():
     )
     print(optional_params)
     assert optional_params["top_k"] == 10
+
+
+def test_together_ai_model_params():
+    optional_params = get_optional_params(
+        model="together_ai", custom_llm_provider="together_ai", logprobs=1
+    )
+    print(optional_params)
+    assert optional_params["logprobs"] == 1