From e89dcccdd9f8cf95b0519385c07e540b14c864ab Mon Sep 17 00:00:00 2001
From: Show <35062952+BrunooShow@users.noreply.github.com>
Date: Wed, 20 Nov 2024 00:04:33 +0100
Subject: [PATCH] (feat): Add timestamp_granularities parameter to transcription API (#6457)

* Add timestamp_granularities parameter to transcription API

* add param to the local test
---
 litellm/main.py                     | 2 ++
 litellm/utils.py                    | 1 +
 tests/local_testing/test_whisper.py | 4 +++-
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/litellm/main.py b/litellm/main.py
index 3b4a99413..0afce3db5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -4729,6 +4729,7 @@ def transcription(
     response_format: Optional[
         Literal["json", "text", "srt", "verbose_json", "vtt"]
     ] = None,
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None,
     temperature: Optional[int] = None,  # openai defaults this to 0
     ## LITELLM PARAMS ##
     user: Optional[str] = None,
@@ -4778,6 +4779,7 @@ def transcription(
         language=language,
         prompt=prompt,
         response_format=response_format,
+        timestamp_granularities=timestamp_granularities,
         temperature=temperature,
         custom_llm_provider=custom_llm_provider,
         drop_params=drop_params,
diff --git a/litellm/utils.py b/litellm/utils.py
index cb8a53354..2dce9db89 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2125,6 +2125,7 @@ def get_optional_params_transcription(
     prompt: Optional[str] = None,
     response_format: Optional[str] = None,
     temperature: Optional[int] = None,
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None,
     custom_llm_provider: Optional[str] = None,
     drop_params: Optional[bool] = None,
     **kwargs,
diff --git a/tests/local_testing/test_whisper.py b/tests/local_testing/test_whisper.py
index f66ad8b13..7d5d0d710 100644
--- a/tests/local_testing/test_whisper.py
+++ b/tests/local_testing/test_whisper.py
@@ -53,8 +53,9 @@ from litellm import Router
 )
 @pytest.mark.parametrize("response_format", ["json", "vtt"])
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("timestamp_granularities", [["word"], ["segment"]])
 @pytest.mark.asyncio
-async def test_transcription(model, api_key, api_base, response_format, sync_mode):
+async def test_transcription(model, api_key, api_base, response_format, sync_mode, timestamp_granularities):
     if sync_mode:
         transcript = litellm.transcription(
             model=model,
@@ -62,6 +63,7 @@ async def test_transcription(model, api_key, api_base, response_format, sync_mod
             api_key=api_key,
             api_base=api_base,
             response_format=response_format,
+            timestamp_granularities=timestamp_granularities,
             drop_params=True,
         )
     else:
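
Usage sketch (editor's addition, not part of the patch): a minimal example of how a caller could exercise the new timestamp_granularities parameter once this change lands. The audio file name is a placeholder, and per the OpenAI transcription API, word/segment timestamps are only populated when response_format is "verbose_json".

import litellm

# Placeholder file name -- substitute any local audio file.
with open("sample_audio.wav", "rb") as audio_file:
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        # OpenAI only returns timestamps when verbose_json is requested.
        response_format="verbose_json",
        timestamp_granularities=["word"],
        drop_params=True,  # drop the param for providers that don't support it
    )
    print(transcript)

With drop_params=True, litellm is expected to strip timestamp_granularities for providers that do not support it, so the call should still succeed against non-OpenAI transcription backends.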