mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge fcd2586909
into b82af5b826
This commit is contained in:
commit
2c5b9e07f0
5 changed files with 32 additions and 98 deletions
|
@ -828,7 +828,7 @@ from .llms.maritalk import MaritalkConfig
|
|||
from .llms.openrouter.chat.transformation import OpenrouterConfig
|
||||
from .llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from .llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from .llms.groq.stt.transformation import GroqSTTConfig
|
||||
from .llms.groq.stt.transformation import GroqAudioTranscriptionConfig
|
||||
from .llms.anthropic.completion.transformation import AnthropicTextConfig
|
||||
from .llms.triton.completion.transformation import TritonConfig
|
||||
from .llms.triton.completion.transformation import TritonGenerateConfig
|
||||
|
|
|
@ -79,7 +79,10 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
elif custom_llm_provider == "volcengine":
|
||||
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "groq":
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
if request_type == "transcription":
|
||||
return litellm.GroqAudioTranscriptionConfig().get_supported_openai_params(model=model)
|
||||
else:
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "vllm":
|
||||
|
|
|
@ -2,99 +2,16 @@
|
|||
Translate from OpenAI's `/v1/audio/transcriptions` to Groq's `/v1/audio/transcriptions`
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
from typing import List
|
||||
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
|
||||
|
||||
import litellm
|
||||
from ...openai.transcriptions.whisper_transformation import (
|
||||
OpenAIWhisperAudioTranscriptionConfig,
|
||||
)
|
||||
|
||||
class GroqAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
|
||||
|
||||
class GroqSTTConfig:
|
||||
frequency_penalty: Optional[int] = None
|
||||
function_call: Optional[Union[str, dict]] = None
|
||||
functions: Optional[list] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
response_format: Optional[dict] = None
|
||||
tools: Optional[list] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params_stt(self):
|
||||
return [
|
||||
"prompt",
|
||||
"response_format",
|
||||
"temperature",
|
||||
"language",
|
||||
]
|
||||
|
||||
def get_supported_openai_response_formats_stt(self) -> List[str]:
|
||||
return ["json", "verbose_json", "text"]
|
||||
|
||||
def map_openai_params_stt(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
response_formats = self.get_supported_openai_response_formats_stt()
|
||||
for param, value in non_default_params.items():
|
||||
if param == "response_format":
|
||||
if value in response_formats:
|
||||
optional_params[param] = value
|
||||
else:
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
pass
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
|
||||
value
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
else:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
return ["language", "prompt", "response_format", "temperature", "timestamp_granularities"]
|
||||
|
|
|
@ -2434,9 +2434,9 @@ def get_optional_params_transcription(
|
|||
if custom_llm_provider == "openai" or custom_llm_provider == "azure":
|
||||
optional_params = non_default_params
|
||||
elif custom_llm_provider == "groq":
|
||||
supported_params = litellm.GroqSTTConfig().get_supported_openai_params_stt()
|
||||
supported_params = litellm.GroqAudioTranscriptionConfig().get_supported_openai_params(model=model)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.GroqSTTConfig().map_openai_params_stt(
|
||||
optional_params = litellm.GroqAudioTranscriptionConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
|
@ -6629,6 +6629,8 @@ class ProviderConfigManager:
|
|||
return litellm.OpenAIGPTAudioTranscriptionConfig()
|
||||
else:
|
||||
return litellm.OpenAIWhisperAudioTranscriptionConfig()
|
||||
elif litellm.LlmProviders.GROQ == provider:
|
||||
return litellm.GroqAudioTranscriptionConfig()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
from base_audio_transcription_unit_tests import BaseLLMAudioTranscriptionTest
|
||||
import litellm
|
||||
|
||||
class TestGroq(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
|
@ -10,3 +11,14 @@ class TestGroq(BaseLLMChatTest):
|
|||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
|
||||
class TestGroqAudioTranscription(BaseLLMAudioTranscriptionTest):
|
||||
def get_base_audio_transcription_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "groq/whisper-large-v3",
|
||||
"api_base": "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
return litellm.LlmProviders.GROQ
|
Loading…
Add table
Add a link
Reference in a new issue