mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(gemini/transformation.py): handle file_data being passed in (#9786)
This commit is contained in:
parent
0307a0133b
commit
9a60cd9deb
2 changed files with 22 additions and 16 deletions
|
@@ -208,25 +208,24 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
|||
elif element["type"] == "input_audio":
|
||||
audio_element = cast(ChatCompletionAudioObject, element)
|
||||
if audio_element["input_audio"].get("data") is not None:
|
||||
_part = PartType(
|
||||
inline_data=BlobType(
|
||||
data=audio_element["input_audio"]["data"],
|
||||
mime_type="audio/{}".format(
|
||||
audio_element["input_audio"]["format"]
|
||||
),
|
||||
)
|
||||
_part = _process_gemini_image(
|
||||
image_url=audio_element["input_audio"]["data"],
|
||||
format=audio_element["input_audio"].get("format"),
|
||||
)
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "file":
|
||||
file_element = cast(ChatCompletionFileObject, element)
|
||||
file_id = file_element["file"].get("file_id")
|
||||
format = file_element["file"].get("format")
|
||||
|
||||
if not file_id:
|
||||
continue
|
||||
file_data = file_element["file"].get("file_data")
|
||||
passed_file = file_id or file_data
|
||||
if passed_file is None:
|
||||
raise Exception(
|
||||
"Unknown file type. Please pass in a file_id or file_data"
|
||||
)
|
||||
try:
|
||||
_part = _process_gemini_image(
|
||||
image_url=file_id, format=format
|
||||
image_url=passed_file, format=format
|
||||
)
|
||||
_parts.append(_part)
|
||||
except Exception:
|
||||
|
|
|
@@ -928,7 +928,8 @@ class BaseLLMChatTest(ABC):
|
|||
assert response._hidden_params["response_cost"] > 0
|
||||
|
||||
@pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
|
||||
def test_supports_audio_input(self, input_type):
|
||||
@pytest.mark.parametrize("format_specified", [True, False])
|
||||
def test_supports_audio_input(self, input_type, format_specified):
|
||||
from litellm.utils import return_raw_request, supports_audio_input
|
||||
from litellm.types.utils import CallTypes
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
|
@@ -958,10 +959,16 @@ class BaseLLMChatTest(ABC):
|
|||
test_file_id = "gs://bucket/file.wav"
|
||||
|
||||
if input_type == "input_audio":
|
||||
audio_content.append({
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": encoded_string, "format": audio_format},
|
||||
})
|
||||
if format_specified:
|
||||
audio_content.append({
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": encoded_string, "format": audio_format},
|
||||
})
|
||||
else:
|
||||
audio_content.append({
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": encoded_string},
|
||||
})
|
||||
elif input_type == "audio_url":
|
||||
audio_content.append(
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue