From b2f1e471047b602f0ec8e3ff91f864c91c499bf0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 13 Nov 2024 12:26:06 +0530
Subject: [PATCH] fix(utils.py): add logprobs support for together ai

Fixes https://github.com/BerriAI/litellm/issues/6724
---
 .../get_supported_openai_params.py            | 12 +-------
 litellm/llms/together_ai/chat.py              |  4 +--
 litellm/utils.py                              | 28 +++++++------------
 tests/llm_translation/test_optional_params.py |  8 ++++++
 4 files changed, 21 insertions(+), 31 deletions(-)

diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index bb94d54d5..05b4b9c48 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -161,17 +161,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "huggingface":
         return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
-        return [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stop",
-            "frequency_penalty",
-            "tools",
-            "tool_choice",
-            "response_format",
-        ]
+        return litellm.TogetherAIConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ai21":
         return [
             "stream",
diff --git a/litellm/llms/together_ai/chat.py b/litellm/llms/together_ai/chat.py
index 398bc489c..cb12d6147 100644
--- a/litellm/llms/together_ai/chat.py
+++ b/litellm/llms/together_ai/chat.py
@@ -6,8 +6,8 @@ Calls done in OpenAI/openai.py as TogetherAI is openai-compatible.
 Docs: https://docs.together.ai/reference/completions-1
 """
 
-from ..OpenAI.openai import OpenAIConfig
+from ..OpenAI.chat.gpt_transformation import OpenAIGPTConfig
 
 
-class TogetherAIConfig(OpenAIConfig):
+class TogetherAIConfig(OpenAIGPTConfig):
     pass
diff --git a/litellm/utils.py b/litellm/utils.py
index 802bcfc04..a0f544312 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2900,24 +2900,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
 
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if stop is not None:
-            optional_params["stop"] = stop
-        if tools is not None:
-            optional_params["tools"] = tools
-        if tool_choice is not None:
-            optional_params["tool_choice"] = tool_choice
-        if response_format is not None:
-            optional_params["response_format"] = response_format
+        optional_params = litellm.TogetherAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "ai21":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 7283e9a39..8677d6b73 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -921,3 +921,11 @@ def test_watsonx_text_top_k():
     )
     print(optional_params)
     assert optional_params["top_k"] == 10
+
+
+def test_together_ai_model_params():
+    optional_params = get_optional_params(
+        model="together_ai", custom_llm_provider="together_ai", logprobs=1
+    )
+    print(optional_params)
+    assert optional_params["logprobs"] == 1