Add Gemini audio input support + handle special tokens in SageMaker response (#9640)

* fix(internal_user_endpoints.py): clean up unused variables on beta endpoint

no team/org split on the daily user endpoint

* build(model_prices_and_context_window.json): gemini-2.0-flash supports audio input

* feat(gemini/transformation.py): support passing audio input to gemini

* test: fix test

* fix(gemini/transformation.py): support audio input as a URL

enables passing Google Cloud Storage bucket URLs (see the sketch below the commit metadata)

* fix(gemini/transformation.py): support explicitly passing the format of a file

* fix(gemini/transformation.py): expand support for file types inferred from the URL

* fix(sagemaker/completion/transformation.py): fix special-token error when counting SageMaker tokens

* test: fix import
Krish Dholakia 2025-03-29 19:23:09 -07:00 committed by GitHub
parent af885be743
commit 70f993d3d7
11 changed files with 191 additions and 58 deletions
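
For reference, a minimal sketch of the two audio-input shapes this commit wires up for Gemini; the model name, file names, and gs:// URI are illustrative, and a valid GEMINI_API_KEY is assumed:

import base64

import litellm

with open("recording.wav", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# Inline base64 audio via the OpenAI-style "input_audio" content block.
litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}},
        ],
    }],
)

# File reference, e.g. a Google Cloud Storage URI; "format" pins the mime type
# explicitly when it can't be inferred from the URL.
litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {
                "type": "file",
                "file": {"file_id": "gs://my-bucket/recording.wav", "format": "audio/wav"},
            },
        ],
    }],
)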

View file

@@ -19,6 +19,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import token_counter

 from ..common_utils import SagemakerError
@@ -238,9 +239,12 @@ class SagemakerConfig(BaseConfig):
             )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        prompt_tokens = token_counter(
+            text=prompt, count_response_tokens=True
+        )  # doesn't apply any default token count from openai's chat template
+        completion_tokens = token_counter(
+            text=model_response["choices"][0]["message"].get("content", ""),
+            count_response_tokens=True,
         )

         model_response.created = int(time.time())
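
In effect, SageMaker responses are now counted the same way litellm counts other raw response text. A minimal sketch of the helper, assuming a litellm install (the sample string is illustrative):

from litellm.utils import token_counter

# count_response_tokens=True tokenizes the raw text directly, skipping OpenAI's
# chat-template token overhead and tolerating special tokens such as
# "<|endoftext|>" that a plain encoding.encode() call can reject.
n = token_counter(text="<|endoftext|> hello world", count_response_tokens=True)
print(n)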

View file

@@ -27,6 +27,8 @@ from litellm.types.files import (
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionAudioObject,
+    ChatCompletionFileObject,
     ChatCompletionImageObject,
     ChatCompletionTextObject,
 )
@@ -103,24 +105,53 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
     See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements

     Supported by Gemini:
-     - PNG (`image/png`)
-     - JPEG (`image/jpeg`)
-     - WebP (`image/webp`)
-    Example:
-        url = https://example.com/image.jpg
-        Returns: image/jpeg
+        application/pdf
+        audio/mpeg
+        audio/mp3
+        audio/wav
+        image/png
+        image/jpeg
+        image/webp
+        text/plain
+        video/mov
+        video/mpeg
+        video/mp4
+        video/mpg
+        video/avi
+        video/wmv
+        video/mpegps
+        video/flv
     """
     url = url.lower()
-    if url.endswith((".jpg", ".jpeg")):
-        return "image/jpeg"
-    elif url.endswith(".png"):
-        return "image/png"
-    elif url.endswith(".webp"):
-        return "image/webp"
-    elif url.endswith(".mp4"):
-        return "video/mp4"
-    elif url.endswith(".pdf"):
-        return "application/pdf"
+
+    # Map file extensions to mime types
+    mime_types = {
+        # Images
+        (".jpg", ".jpeg"): "image/jpeg",
+        (".png",): "image/png",
+        (".webp",): "image/webp",
+        # Videos
+        (".mp4",): "video/mp4",
+        (".mov",): "video/mov",
+        (".mpeg", ".mpg"): "video/mpeg",
+        (".avi",): "video/avi",
+        (".wmv",): "video/wmv",
+        (".mpegps",): "video/mpegps",
+        (".flv",): "video/flv",
+        # Audio
+        (".mp3",): "audio/mp3",
+        (".wav",): "audio/wav",
+        (".mpeg",): "audio/mpeg",
+        # Documents
+        (".pdf",): "application/pdf",
+        (".txt",): "text/plain",
+    }
+
+    # Check each extension group against the URL
+    for extensions, mime_type in mime_types.items():
+        if any(url.endswith(ext) for ext in extensions):
+            return mime_type
+
+    return None
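
One detail worth noting: the groups are checked in dict-insertion order, so a ".mpeg" URL resolves to video/mpeg and the later audio entry never fires. A standalone illustration of that first-match-wins behavior (a reduced copy of the logic, not an import of the private helper):

mime_types = {
    (".mpeg", ".mpg"): "video/mpeg",  # video group is checked first
    (".mpeg",): "audio/mpeg",         # unreachable for ".mpeg" URLs as written
}

def mime_for(url: str):
    url = url.lower()
    for extensions, mime_type in mime_types.items():
        if any(url.endswith(ext) for ext in extensions):
            return mime_type
    return None

print(mime_for("gs://bucket/clip.mpeg"))  # video/mpeg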
@@ -152,7 +183,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             _message_content = messages[msg_i].get("content")
             if _message_content is not None and isinstance(_message_content, list):
                 _parts: List[PartType] = []
-                for element in _message_content:
+                for element_idx, element in enumerate(_message_content):
                     if (
                         element["type"] == "text"
                         and "text" in element
@@ -174,6 +205,41 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                             image_url=image_url, format=format
                         )
                         _parts.append(_part)
+                    elif element["type"] == "input_audio":
+                        audio_element = cast(ChatCompletionAudioObject, element)
+                        if audio_element["input_audio"].get("data") is not None:
+                            _part = PartType(
+                                inline_data=BlobType(
+                                    data=audio_element["input_audio"]["data"],
+                                    mime_type="audio/{}".format(
+                                        audio_element["input_audio"]["format"]
+                                    ),
+                                )
+                            )
+                            _parts.append(_part)
+                    elif element["type"] == "file":
+                        file_element = cast(ChatCompletionFileObject, element)
+                        file_id = file_element["file"].get("file_id")
+                        format = file_element["file"].get("format")
+
+                        if not file_id:
+                            continue
+                        mime_type = format or _get_image_mime_type_from_url(file_id)
+
+                        if mime_type is not None:
+                            _part = PartType(
+                                file_data=FileDataType(
+                                    file_uri=file_id,
+                                    mime_type=mime_type,
+                                )
+                            )
+                            _parts.append(_part)
+                        else:
+                            raise Exception(
+                                "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
+                                    file_id, msg_i, element_idx
+                                )
+                            )
                 user_content.extend(_parts)
             elif (
                 _message_content is not None
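
A sketch of the failure mode the new error message covers: a file reference whose extension isn't in the mapping and no explicit format. The model name and URI are illustrative, GEMINI_API_KEY is assumed set, and litellm's exception mapping may wrap the raw message:

import litellm

try:
    litellm.completion(
        model="gemini/gemini-2.0-flash",
        messages=[{
            "role": "user",
            "content": [{"type": "file", "file": {"file_id": "gs://bucket/blob.bin"}}],
        }],
    )
except Exception as e:
    # Expected to mention: Unable to determine mime type for file_id:
    # gs://bucket/blob.bin, set this explicitly using message[0].content[0].file.format
    print(e)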

View file

@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },
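
These flags are queryable through litellm's capability helpers; a quick check mirroring how the test later in this commit uses them, reading the local cost map so no network call is needed:

import os

import litellm
from litellm.utils import supports_audio_input

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

print(supports_audio_input("gemini/gemini-2.0-flash", None))  # True with this entry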

View file

@@ -16,14 +16,18 @@ model_list:
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
+  - model_name: "gemini-2.0-flash"
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY

 litellm_settings:
   num_retries: 0
   callbacks: ["prometheus"]
-  json_logs: true
+  # json_logs: true

-router_settings:
-  routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
+# router_settings:
+#   routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
+#   redis_host: os.environ/REDIS_HOST
+#   redis_password: os.environ/REDIS_PASSWORD
+#   redis_port: os.environ/REDIS_PORT
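
With this config, the new deployment is addressable by its model_name through any OpenAI client pointed at the proxy; the host, port, and key below are illustrative defaults, not part of the diff:

import openai

client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
resp = client.chat.completions.create(
    model="gemini-2.0-flash",  # matches model_name in the config above
    messages=[{"role": "user", "content": "hi"}],
)
print(resp.choices[0].message.content)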

View file

@@ -1370,22 +1370,6 @@ async def get_user_daily_activity(
         default=None,
         description="End date in YYYY-MM-DD format",
     ),
-    group_by: List[GroupByDimension] = fastapi.Query(
-        default=[GroupByDimension.DATE],
-        description="Dimensions to group by. Can combine multiple (e.g. date,team)",
-    ),
-    view_by: Literal["team", "organization", "user"] = fastapi.Query(
-        default="user",
-        description="View spend at team/org/user level",
-    ),
-    team_id: Optional[str] = fastapi.Query(
-        default=None,
-        description="Filter by specific team",
-    ),
-    org_id: Optional[str] = fastapi.Query(
-        default=None,
-        description="Filter by specific organization",
-    ),
     model: Optional[str] = fastapi.Query(
         default=None,
         description="Filter by specific model",
@@ -1408,13 +1392,13 @@ async def get_user_daily_activity(
     Meant to optimize querying spend data for analytics for a user.

     Returns:
-    (by date/team/org/user/model/api_key/model_group/provider)
+    (by date)
     - spend
     - prompt_tokens
     - completion_tokens
     - total_tokens
     - api_requests
-    - breakdown by team, organization, user, model, api_key, model_group, provider
+    - breakdown by model, api_key, provider
     """
     from litellm.proxy.proxy_server import prisma_client
@@ -1439,10 +1423,6 @@ async def get_user_daily_activity(
             }
         }

-        if team_id:
-            where_conditions["team_id"] = team_id
-        if org_id:
-            where_conditions["organization_id"] = org_id
         if model:
             where_conditions["model"] = model
         if api_key:
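
After this change the endpoint only accepts date, model, and api_key-style filters. A hedged sketch of a call; the route path, host, and key are assumptions, not shown in the diff:

import httpx

resp = httpx.get(
    "http://0.0.0.0:4000/user/daily/activity",  # assumed route for get_user_daily_activity
    params={"start_date": "2025-03-01", "end_date": "2025-03-29", "model": "gemini-2.0-flash"},
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())  # per-date spend/tokens with model, api_key, provider breakdowns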

View file

@@ -505,10 +505,11 @@ class ChatCompletionDocumentObject(TypedDict):
     citations: Optional[CitationsObject]


-class ChatCompletionFileObjectFile(TypedDict):
-    file_data: Optional[str]
-    file_id: Optional[str]
-    filename: Optional[str]
+class ChatCompletionFileObjectFile(TypedDict, total=False):
+    file_data: str
+    file_id: str
+    filename: str
+    format: str


 class ChatCompletionFileObject(TypedDict):
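
Since the TypedDict is now total=False, every key is optional, and the new format field lets callers pin the mime type; a minimal usage sketch:

from litellm.types.llms.openai import ChatCompletionFileObject, ChatCompletionFileObjectFile

file_part: ChatCompletionFileObjectFile = {
    "file_id": "gs://bucket/recording.wav",
    "format": "audio/wav",  # bypasses URL-extension inference entirely
}
content_block: ChatCompletionFileObject = {"type": "file", "file": file_part}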

View file

@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },

View file

@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch
 import os
 import uuid
 import time
+import base64

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC):

         assert cost > 0

+    @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
+    def test_supports_audio_input(self, input_type):
+        from litellm.utils import return_raw_request, supports_audio_input
+        from litellm.types.utils import CallTypes
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        litellm.drop_params = True
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_audio_input(base_completion_call_args["model"], None):
+            print("Model does not support audio input")
+            pytest.skip("Model does not support audio input")
+
+        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
+        response = httpx.get(url)
+        response.raise_for_status()
+        wav_data = response.content
+        audio_format = "wav"
+        encoded_string = base64.b64encode(wav_data).decode("utf-8")
+
+        audio_content = [
+            {
+                "type": "text",
+                "text": "What is in this recording?"
+            }
+        ]
+
+        test_file_id = "gs://bucket/file.wav"
+        if input_type == "input_audio":
+            audio_content.append({
+                "type": "input_audio",
+                "input_audio": {"data": encoded_string, "format": audio_format},
+            })
+        elif input_type == "audio_url":
+            audio_content.append(
+                {
+                    "type": "file",
+                    "file": {
+                        "file_id": test_file_id,
+                        "filename": "my-sample-audio-file",
+                    }
+                }
+            )
+
+        raw_request = return_raw_request(
+            endpoint=CallTypes.completion,
+            kwargs={
+                **base_completion_call_args,
+                "modalities": ["text", "audio"],
+                "audio": {"voice": "alloy", "format": audio_format},
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": audio_content,
+                    },
+                ]
+            }
+        )
+        print("raw_request: ", raw_request)
+        if input_type == "input_audio":
+            assert encoded_string in json.dumps(raw_request), "Audio data not sent to gemini"
+        elif input_type == "audio_url":
+            assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
+

 class BaseOSeriesModelsTest(ABC):  # test across azure/openai
     @abstractmethod
@@ -1089,3 +1158,5 @@ class BaseAnthropicChatTest(ABC):
         )

         print(response)

View file

@@ -15,7 +15,7 @@ from litellm.llms.vertex_ai.context_caching.transformation import (
 class TestGoogleAIStudioGemini(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        return {"model": "gemini/gemini-1.5-flash-002"}
+        return {"model": "gemini/gemini-2.0-flash"}

     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""

View file

@@ -84,12 +84,14 @@ async def test_audio_output_from_model(stream):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("stream", [True, False])
-async def test_audio_input_to_model(stream):
+@pytest.mark.parametrize("model", ["gpt-4o-audio-preview"])  # "gpt-4o-audio-preview",
+async def test_audio_input_to_model(stream, model):
     # Fetch the audio file and convert it to a base64 encoded string
     audio_format = "pcm16"
     if stream is False:
         audio_format = "wav"
+    litellm._turn_on_debug()
+    litellm.drop_params = True
     url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
     response = requests.get(url)
     response.raise_for_status()
@@ -97,7 +99,7 @@ async def test_audio_input_to_model(stream):
     encoded_string = base64.b64encode(wav_data).decode("utf-8")

     try:
         completion = await litellm.acompletion(
-            model="gpt-4o-audio-preview",
+            model=model,
             modalities=["text", "audio"],
             audio={"voice": "alloy", "format": audio_format},
             stream=stream,
@@ -120,6 +122,7 @@ async def test_audio_input_to_model(stream):
     except Exception as e:
         if "openai-internal" in str(e):
             pytest.skip("Skipping test due to openai-internal error")
+        raise e

     if stream is True:
         await check_streaming_response(completion)
     else:

View file

@@ -2285,10 +2285,10 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     """
     Unit test for making sure spend logs list is still updated when url passed in
     """
-    from litellm.proxy.proxy_server import _set_spend_logs_payload
+    from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter

     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-    _set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
+    DBSpendUpdateWriter._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)

     assert len(prisma_client.spend_log_transactions) > 0
@@ -2296,7 +2296,7 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     spend_logs_url = ""
     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-    _set_spend_logs_payload(
+    DBSpendUpdateWriter._set_spend_logs_payload(
         payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
     )