commit 2c5b9e07f0 by Hugo Liu, 2025-04-24 00:56:39 -07:00 (committed by GitHub)
5 changed files with 32 additions and 98 deletions

litellm/__init__.py

@@ -828,7 +828,7 @@ from .llms.maritalk import MaritalkConfig
 from .llms.openrouter.chat.transformation import OpenrouterConfig
 from .llms.anthropic.chat.transformation import AnthropicConfig
 from .llms.anthropic.common_utils import AnthropicModelInfo
-from .llms.groq.stt.transformation import GroqSTTConfig
+from .llms.groq.stt.transformation import GroqAudioTranscriptionConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.triton.completion.transformation import TritonConfig
 from .llms.triton.completion.transformation import TritonGenerateConfig
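The changed line re-exports the class at the package root under its new name, so it is reachable directly off litellm. A quick smoke check of that, using the supported-params list the new config returns (see the transformation.py hunk below):

    import litellm

    # The re-export above makes the renamed config available at the top level.
    config = litellm.GroqAudioTranscriptionConfig()
    print(config.get_supported_openai_params(model="whisper-large-v3"))
    # expected: ['language', 'prompt', 'response_format', 'temperature', 'timestamp_granularities']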

litellm/litellm_core_utils/get_supported_openai_params.py

@@ -79,7 +79,10 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "volcengine":
         return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "groq":
-        return litellm.GroqChatConfig().get_supported_openai_params(model=model)
+        if request_type == "transcription":
+            return litellm.GroqAudioTranscriptionConfig().get_supported_openai_params(model=model)
+        else:
+            return litellm.GroqChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "hosted_vllm":
         return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "vllm":
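With the new branch, the same helper answers differently for chat and transcription requests against Groq. A minimal sketch of both paths (model names are illustrative; the request_type keyword is taken from the hunk above):

    import litellm

    # Chat requests still resolve through GroqChatConfig ...
    chat_params = litellm.get_supported_openai_params(
        model="llama-3.1-8b-instant",
        custom_llm_provider="groq",
    )

    # ... while transcription requests now resolve through the new config.
    stt_params = litellm.get_supported_openai_params(
        model="whisper-large-v3",
        custom_llm_provider="groq",
        request_type="transcription",
    )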

litellm/llms/groq/stt/transformation.py

@@ -2,99 +2,16 @@
 Translate from OpenAI's `/v1/audio/transcriptions` to Groq's `/v1/audio/transcriptions`
 """
-import types
-from typing import List, Optional, Union
+from typing import List
 
-import litellm
+from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
+
+from ...openai.transcriptions.whisper_transformation import (
+    OpenAIWhisperAudioTranscriptionConfig,
+)
 
 
-class GroqSTTConfig:
-    frequency_penalty: Optional[int] = None
-    function_call: Optional[Union[str, dict]] = None
-    functions: Optional[list] = None
-    logit_bias: Optional[dict] = None
-    max_tokens: Optional[int] = None
-    n: Optional[int] = None
-    presence_penalty: Optional[int] = None
-    stop: Optional[Union[str, list]] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    response_format: Optional[dict] = None
-    tools: Optional[list] = None
-    tool_choice: Optional[Union[str, dict]] = None
-
-    def __init__(
-        self,
-        frequency_penalty: Optional[int] = None,
-        function_call: Optional[Union[str, dict]] = None,
-        functions: Optional[list] = None,
-        logit_bias: Optional[dict] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[int] = None,
-        stop: Optional[Union[str, list]] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        response_format: Optional[dict] = None,
-        tools: Optional[list] = None,
-        tool_choice: Optional[Union[str, dict]] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params_stt(self):
-        return [
-            "prompt",
-            "response_format",
-            "temperature",
-            "language",
-        ]
-
-    def get_supported_openai_response_formats_stt(self) -> List[str]:
-        return ["json", "verbose_json", "text"]
-
-    def map_openai_params_stt(
-        self,
-        non_default_params: dict,
-        optional_params: dict,
-        model: str,
-        drop_params: bool,
-    ) -> dict:
-        response_formats = self.get_supported_openai_response_formats_stt()
-        for param, value in non_default_params.items():
-            if param == "response_format":
-                if value in response_formats:
-                    optional_params[param] = value
-                else:
-                    if litellm.drop_params is True or drop_params is True:
-                        pass
-                    else:
-                        raise litellm.utils.UnsupportedParamsError(
-                            message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
-                                value
-                            ),
-                            status_code=400,
-                        )
-            else:
-                optional_params[param] = value
-        return optional_params
+class GroqAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
+    def get_supported_openai_params(
+        self, model: str
+    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
+        return ["language", "prompt", "response_format", "temperature", "timestamp_granularities"]

litellm/utils.py

@@ -2434,9 +2434,9 @@ def get_optional_params_transcription(
     if custom_llm_provider == "openai" or custom_llm_provider == "azure":
         optional_params = non_default_params
     elif custom_llm_provider == "groq":
-        supported_params = litellm.GroqSTTConfig().get_supported_openai_params_stt()
+        supported_params = litellm.GroqAudioTranscriptionConfig().get_supported_openai_params(model=model)
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.GroqSTTConfig().map_openai_params_stt(
+        optional_params = litellm.GroqAudioTranscriptionConfig().map_openai_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
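The inherited map_openai_params has the same shape as the old map_openai_params_stt, so this is a drop-in swap. A sketch of what the call yields, assuming the inherited whisper implementation passes supported params through unchanged (the input dict is illustrative, and the drop_params argument continues past the end of this hunk):

    import litellm

    optional_params = litellm.GroqAudioTranscriptionConfig().map_openai_params(
        non_default_params={"language": "en", "response_format": "json"},
        optional_params={},
        model="whisper-large-v3",
        drop_params=False,
    )
    # expected: {"language": "en", "response_format": "json"}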
@@ -6629,6 +6629,8 @@ class ProviderConfigManager:
                 return litellm.OpenAIGPTAudioTranscriptionConfig()
             else:
                 return litellm.OpenAIWhisperAudioTranscriptionConfig()
+        elif litellm.LlmProviders.GROQ == provider:
+            return litellm.GroqAudioTranscriptionConfig()
         return None
 
     @staticmethod
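With this branch in place, ProviderConfigManager can hand back the Groq config when routing transcription requests. A minimal sketch; the getter name below is inferred from the surrounding class and should be treated as an assumption:

    import litellm
    from litellm.utils import ProviderConfigManager

    # Hypothetical lookup mirroring the hunk above.
    config = ProviderConfigManager.get_provider_audio_transcription_config(
        model="whisper-large-v3",
        provider=litellm.LlmProviders.GROQ,
    )
    assert isinstance(config, litellm.GroqAudioTranscriptionConfig)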

tests/llm_translation/test_groq.py

@@ -1,5 +1,6 @@
 from base_llm_unit_tests import BaseLLMChatTest
+from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest
 import litellm
 
 class TestGroq(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
@@ -10,3 +11,14 @@ class TestGroq(BaseLLMChatTest):
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
+
+
+class TestGroqAudioTranscription(BaseLLMAudioTranscriptionTest):
+    def get_base_audio_transcription_call_args(self) -> dict:
+        return {
+            "model": "groq/whisper-large-v3",
+            "api_base": "https://api.groq.com/openai/v1"
+        }
+
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.GROQ