mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
Add gemini audio input support + handle special tokens in sagemaker response (#9640)
* fix(internal_user_endpoints.py): cleanup unused variables on beta endpoint; no team/org split on daily user endpoint
* build(model_prices_and_context_window.json): gemini-2.0-flash supports audio input
* feat(gemini/transformation.py): support passing audio input to gemini
* test: fix test
* fix(gemini/transformation.py): support audio input as a url; enables passing google cloud bucket urls
* fix(gemini/transformation.py): support explicitly passing format of file
* fix(gemini/transformation.py): expand support for inferred file types from url
* fix(sagemaker/completion/transformation.py): fix special token error when counting sagemaker tokens
* test: fix import
This commit is contained in:
parent
af885be743
commit
70f993d3d7
11 changed files with 191 additions and 58 deletions
|
@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
import os
|
||||
import uuid
|
||||
import time
|
||||
import base64
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC):
|
|||
|
||||
assert cost > 0
|
||||
|
||||
@pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
def test_supports_audio_input(self, input_type):
    """Check that audio supplied in a user message shows up in the raw request.

    Two input styles are exercised:

    * ``input_audio`` - base64-encoded WAV bytes sent inline.
    * ``audio_url``   - a ``file`` content part whose ``file_id`` is a
      ``gs://`` URL.

    The test is skipped when ``supports_audio_input`` reports that the model
    returned by ``self.get_base_completion_call_args()`` cannot take audio.
    """
    from litellm.utils import return_raw_request, supports_audio_input
    from litellm.types.utils import CallTypes

    # Use the locally bundled model cost map rather than fetching it.
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    litellm.drop_params = True
    base_completion_call_args = self.get_base_completion_call_args()
    if not supports_audio_input(base_completion_call_args["model"], None):
        print("Model does not support audio input")
        pytest.skip("Model does not support audio input")

    audio_format = "wav"
    test_file_id = "gs://bucket/file.wav"
    encoded_string = None

    audio_content = [
        {
            "type": "text",
            "text": "What is in this recording?",
        }
    ]

    if input_type == "input_audio":
        # Only the inline case needs the real audio bytes, so the download
        # is done here instead of unconditionally; the URL case then has no
        # dependency on this network fetch.
        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
        response = httpx.get(url)
        response.raise_for_status()
        encoded_string = base64.b64encode(response.content).decode("utf-8")
        audio_content.append(
            {
                "type": "input_audio",
                "input_audio": {"data": encoded_string, "format": audio_format},
            }
        )
    elif input_type == "audio_url":
        audio_content.append(
            {
                "type": "file",
                "file": {
                    "file_id": test_file_id,
                    "filename": "my-sample-audio-file",
                },
            }
        )

    raw_request = return_raw_request(
        endpoint=CallTypes.completion,
        kwargs={
            **base_completion_call_args,
            "modalities": ["text", "audio"],
            "audio": {"voice": "alloy", "format": audio_format},
            "messages": [
                {
                    "role": "user",
                    "content": audio_content,
                },
            ],
        },
    )
    print("raw_request: ", raw_request)

    # Serialize once; both branches search the same JSON text.
    serialized_request = json.dumps(raw_request)
    if input_type == "input_audio":
        assert encoded_string in serialized_request, "Audio data not sent to gemini"
    elif input_type == "audio_url":
        assert test_file_id in serialized_request, "Audio URL not sent to gemini"
|
||||
|
||||
class BaseOSeriesModelsTest(ABC): # test across azure/openai
|
||||
@abstractmethod
|
||||
|
@ -1089,3 +1158,5 @@ class BaseAnthropicChatTest(ABC):
|
|||
)
|
||||
|
||||
print(response)
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue