Add gemini audio input support + handle special tokens in sagemaker response (#9640)

* fix(internal_user_endpoints.py): cleanup unused variables on beta endpoint

no team/org split on daily user endpoint

* build(model_prices_and_context_window.json): gemini-2.0-flash supports audio input

* feat(gemini/transformation.py): support passing audio input to gemini (usage sketch after this list)

* test: fix test

* fix(gemini/transformation.py): support audio input as a url

enables passing Google Cloud Storage bucket URLs

* fix(gemini/transformation.py): support explicitly passing format of file

* fix(gemini/transformation.py): expand support for inferred file types from url

* fix(sagemaker/completion/transformation.py): fix special token error when counting sagemaker tokens

* test: fix import
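
For reviewers, a minimal sketch of what the new Gemini audio path accepts (model name, local file path, and gs:// URI are illustrative; the `input_audio` and `file` content parts follow the OpenAI chat format handled in gemini/transformation.py below):

import base64
import litellm

# Inline audio: raw bytes sent as base64 via an `input_audio` part
with open("recording.wav", "rb") as f:  # hypothetical local file
    encoded = base64.b64encode(f.read()).decode("utf-8")

litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}},
        ],
    }],
)

# URL audio: a Google Cloud Storage URI via a `file` part; `format` is the
# explicit mime type, used when it can't be inferred from the URL
litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "file", "file": {"file_id": "gs://my-bucket/recording.wav", "format": "audio/wav"}},
        ],
    }],
)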
Krish Dholakia 2025-03-29 19:23:09 -07:00 committed by GitHub
parent af885be743
commit 70f993d3d7
11 changed files with 191 additions and 58 deletions

View file

@@ -19,6 +19,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import token_counter

 from ..common_utils import SagemakerError
@@ -238,9 +239,12 @@ class SagemakerConfig(BaseConfig):
             )
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        prompt_tokens = token_counter(
+            text=prompt, count_response_tokens=True
+        )  # doesn't apply any default token count from openai's chat template
+        completion_tokens = token_counter(
+            text=model_response["choices"][0]["message"].get("content", ""),
+            count_response_tokens=True,
         )
         model_response.created = int(time.time())
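
The switch from `encoding.encode` to `token_counter` is what fixes the special-token error: tiktoken's `encode` raises on special tokens such as `<|endoftext|>`, which can legitimately appear in a SageMaker response. A rough sketch of the new counting call (the example string is illustrative):

from litellm.utils import token_counter

# Counts special tokens like <|endoftext|> as ordinary text instead of raising;
# count_response_tokens=True also skips any default chat-template token overhead.
completion_tokens = token_counter(
    text="generated text ending in <|endoftext|>",
    count_response_tokens=True,
)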

View file

@@ -27,6 +27,8 @@ from litellm.types.files import (
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionAudioObject,
+    ChatCompletionFileObject,
     ChatCompletionImageObject,
     ChatCompletionTextObject,
 )
@@ -103,24 +105,53 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
     See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements

     Supported by Gemini:
-    - PNG (`image/png`)
-    - JPEG (`image/jpeg`)
-    - WebP (`image/webp`)
-    Example:
-        url = https://example.com/image.jpg
-    Returns: image/jpeg
+     application/pdf
+     audio/mpeg
+     audio/mp3
+     audio/wav
+     image/png
+     image/jpeg
+     image/webp
+     text/plain
+     video/mov
+     video/mpeg
+     video/mp4
+     video/mpg
+     video/avi
+     video/wmv
+     video/mpegps
+     video/flv
     """
     url = url.lower()
-    if url.endswith((".jpg", ".jpeg")):
-        return "image/jpeg"
-    elif url.endswith(".png"):
-        return "image/png"
-    elif url.endswith(".webp"):
-        return "image/webp"
-    elif url.endswith(".mp4"):
-        return "video/mp4"
-    elif url.endswith(".pdf"):
-        return "application/pdf"
+
+    # Map file extensions to mime types
+    mime_types = {
+        # Images
+        (".jpg", ".jpeg"): "image/jpeg",
+        (".png",): "image/png",
+        (".webp",): "image/webp",
+        # Videos
+        (".mp4",): "video/mp4",
+        (".mov",): "video/mov",
+        (".mpeg", ".mpg"): "video/mpeg",
+        (".avi",): "video/avi",
+        (".wmv",): "video/wmv",
+        (".mpegps",): "video/mpegps",
+        (".flv",): "video/flv",
+        # Audio
+        (".mp3",): "audio/mp3",
+        (".wav",): "audio/wav",
+        (".mpeg",): "audio/mpeg",
+        # Documents
+        (".pdf",): "application/pdf",
+        (".txt",): "text/plain",
+    }
+
+    # Check each extension group against the URL
+    for extensions, mime_type in mime_types.items():
+        if any(url.endswith(ext) for ext in extensions):
+            return mime_type
+
     return None
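
Illustrative lookups against the expanded map (return values follow directly from the extension table above):

_get_image_mime_type_from_url("https://example.com/clip.mp3")    # "audio/mp3"
_get_image_mime_type_from_url("gs://bucket/recording.wav")       # "audio/wav"
_get_image_mime_type_from_url("https://example.com/notes.txt")   # "text/plain"
_get_image_mime_type_from_url("https://example.com/data.xyz")    # None -> caller must pass format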
@@ -152,7 +183,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             _message_content = messages[msg_i].get("content")
             if _message_content is not None and isinstance(_message_content, list):
                 _parts: List[PartType] = []
-                for element in _message_content:
+                for element_idx, element in enumerate(_message_content):
                     if (
                         element["type"] == "text"
                         and "text" in element
@@ -174,6 +205,41 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                             image_url=image_url, format=format
                         )
                         _parts.append(_part)
+                    elif element["type"] == "input_audio":
+                        audio_element = cast(ChatCompletionAudioObject, element)
+                        if audio_element["input_audio"].get("data") is not None:
+                            _part = PartType(
+                                inline_data=BlobType(
+                                    data=audio_element["input_audio"]["data"],
+                                    mime_type="audio/{}".format(
+                                        audio_element["input_audio"]["format"]
+                                    ),
+                                )
+                            )
+                            _parts.append(_part)
+                    elif element["type"] == "file":
+                        file_element = cast(ChatCompletionFileObject, element)
+                        file_id = file_element["file"].get("file_id")
+                        format = file_element["file"].get("format")
+                        if not file_id:
+                            continue
+                        mime_type = format or _get_image_mime_type_from_url(file_id)
+                        if mime_type is not None:
+                            _part = PartType(
+                                file_data=FileDataType(
+                                    file_uri=file_id,
+                                    mime_type=mime_type,
+                                )
+                            )
+                            _parts.append(_part)
+                        else:
+                            raise Exception(
+                                "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
+                                    file_id, msg_i, element_idx
+                                )
+                            )
                 user_content.extend(_parts)
             elif (
                 _message_content is not None
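
In effect, the two new branches translate OpenAI-style content parts into Gemini parts roughly as follows (payload values are hypothetical):

# {"type": "input_audio", ...} becomes an inline_data blob:
PartType(inline_data=BlobType(data="<base64 wav bytes>", mime_type="audio/wav"))

# {"type": "file", ...} becomes a file_data reference; the mime type comes from
# file.format if given, else is inferred from the URL, else an error is raised:
PartType(file_data=FileDataType(file_uri="gs://bucket/file.wav", mime_type="audio/wav"))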

View file

@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },
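
Callers can gate on the new capability flag before attaching audio parts; a minimal check using litellm's helper (the same one the new unit test below relies on):

from litellm.utils import supports_audio_input

if supports_audio_input("gemini/gemini-2.0-flash", None):
    ...  # safe to send input_audio / file audio parts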

View file

@@ -16,14 +16,18 @@ model_list:
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
+  - model_name: "gemini-2.0-flash"
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY

 litellm_settings:
     num_retries: 0
     callbacks: ["prometheus"]
-    json_logs: true
+    # json_logs: true

-router_settings:
-    routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
-    redis_host: os.environ/REDIS_HOST
-    redis_password: os.environ/REDIS_PASSWORD
-    redis_port: os.environ/REDIS_PORT
+# router_settings:
+#     routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
+#     redis_host: os.environ/REDIS_HOST
+#     redis_password: os.environ/REDIS_PASSWORD
+#     redis_port: os.environ/REDIS_PORT
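
With this config loaded, the new deployment is reachable through the proxy like any other model; a sketch assuming the default local proxy address and a valid virtual key:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
resp = client.chat.completions.create(
    model="gemini-2.0-flash",  # matches model_name in the config above
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)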

View file

@@ -1370,22 +1370,6 @@ async def get_user_daily_activity(
         default=None,
         description="End date in YYYY-MM-DD format",
     ),
-    group_by: List[GroupByDimension] = fastapi.Query(
-        default=[GroupByDimension.DATE],
-        description="Dimensions to group by. Can combine multiple (e.g. date,team)",
-    ),
-    view_by: Literal["team", "organization", "user"] = fastapi.Query(
-        default="user",
-        description="View spend at team/org/user level",
-    ),
-    team_id: Optional[str] = fastapi.Query(
-        default=None,
-        description="Filter by specific team",
-    ),
-    org_id: Optional[str] = fastapi.Query(
-        default=None,
-        description="Filter by specific organization",
-    ),
     model: Optional[str] = fastapi.Query(
         default=None,
         description="Filter by specific model",
@@ -1408,13 +1392,13 @@ async def get_user_daily_activity(
     Meant to optimize querying spend data for analytics for a user.

     Returns:
-    (by date/team/org/user/model/api_key/model_group/provider)
+    (by date)
     - spend
     - prompt_tokens
     - completion_tokens
     - total_tokens
     - api_requests
-    - breakdown by team, organization, user, model, api_key, model_group, provider
+    - breakdown by model, api_key, provider
     """
     from litellm.proxy.proxy_server import prisma_client
@@ -1439,10 +1423,6 @@ async def get_user_daily_activity(
             }
         }

-    if team_id:
-        where_conditions["team_id"] = team_id
-    if org_id:
-        where_conditions["organization_id"] = org_id
     if model:
         where_conditions["model"] = model
     if api_key:
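
An illustrative call against the slimmed-down endpoint (base URL, key, and route path are assumptions inferred from the handler above; team/org filters are no longer accepted):

import httpx

resp = httpx.get(
    "http://0.0.0.0:4000/user/daily/activity",  # hypothetical proxy address
    params={"start_date": "2025-03-01", "end_date": "2025-03-29", "model": "gemini-2.0-flash"},
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())  # spend/token totals by date, broken down by model, api_key, provider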

View file

@@ -505,10 +505,11 @@ class ChatCompletionDocumentObject(TypedDict):
     citations: Optional[CitationsObject]


-class ChatCompletionFileObjectFile(TypedDict):
-    file_data: Optional[str]
-    file_id: Optional[str]
-    filename: Optional[str]
+class ChatCompletionFileObjectFile(TypedDict, total=False):
+    file_data: str
+    file_id: str
+    filename: str
+    format: str


 class ChatCompletionFileObject(TypedDict):
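
With `total=False`, every field is optional at type-check time, so a file part can carry just a `file_id`, or add an explicit `format` mime type; e.g.:

from litellm.types.llms.openai import ChatCompletionFileObjectFile

f1: ChatCompletionFileObjectFile = {"file_id": "gs://bucket/file.wav"}
f2: ChatCompletionFileObjectFile = {"file_id": "gs://bucket/file.wav", "format": "audio/wav"}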

View file

@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },

View file

@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch
 import os
 import uuid
 import time
+import base64

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC):
         assert cost > 0

+    @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
+    def test_supports_audio_input(self, input_type):
+        from litellm.utils import return_raw_request, supports_audio_input
+        from litellm.types.utils import CallTypes
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        litellm.drop_params = True
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_audio_input(base_completion_call_args["model"], None):
+            print("Model does not support audio input")
+            pytest.skip("Model does not support audio input")
+
+        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
+        response = httpx.get(url)
+        response.raise_for_status()
+        wav_data = response.content
+        audio_format = "wav"
+        encoded_string = base64.b64encode(wav_data).decode("utf-8")
+
+        audio_content = [
+            {
+                "type": "text",
+                "text": "What is in this recording?"
+            }
+        ]
+
+        test_file_id = "gs://bucket/file.wav"
+
+        if input_type == "input_audio":
+            audio_content.append({
+                "type": "input_audio",
+                "input_audio": {"data": encoded_string, "format": audio_format},
+            })
+        elif input_type == "audio_url":
+            audio_content.append(
+                {
+                    "type": "file",
+                    "file": {
+                        "file_id": test_file_id,
+                        "filename": "my-sample-audio-file",
+                    }
+                }
+            )
+
+        raw_request = return_raw_request(
+            endpoint=CallTypes.completion,
+            kwargs={
+                **base_completion_call_args,
+                "modalities": ["text", "audio"],
+                "audio": {"voice": "alloy", "format": audio_format},
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": audio_content,
+                    },
+                ]
+            }
+        )
+        print("raw_request: ", raw_request)
+
+        if input_type == "input_audio":
+            assert encoded_string in json.dumps(raw_request), "Audio data not sent to gemini"
+        elif input_type == "audio_url":
+            assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
+
+
 class BaseOSeriesModelsTest(ABC):  # test across azure/openai
     @abstractmethod
@@ -1089,3 +1158,5 @@ class BaseAnthropicChatTest(ABC):
         )

         print(response)

View file

@@ -15,7 +15,7 @@ from litellm.llms.vertex_ai.context_caching.transformation import (

 class TestGoogleAIStudioGemini(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        return {"model": "gemini/gemini-1.5-flash-002"}
+        return {"model": "gemini/gemini-2.0-flash"}

     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""

View file

@@ -84,12 +84,14 @@ async def test_audio_output_from_model(stream):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("stream", [True, False])
-async def test_audio_input_to_model(stream):
+@pytest.mark.parametrize("model", ["gpt-4o-audio-preview"])  # "gpt-4o-audio-preview",
+async def test_audio_input_to_model(stream, model):
     # Fetch the audio file and convert it to a base64 encoded string
     audio_format = "pcm16"
     if stream is False:
         audio_format = "wav"
     litellm._turn_on_debug()
+    litellm.drop_params = True
     url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
     response = requests.get(url)
     response.raise_for_status()
@@ -97,7 +99,7 @@ async def test_audio_input_to_model(stream):
     encoded_string = base64.b64encode(wav_data).decode("utf-8")

     try:
         completion = await litellm.acompletion(
-            model="gpt-4o-audio-preview",
+            model=model,
             modalities=["text", "audio"],
             audio={"voice": "alloy", "format": audio_format},
             stream=stream,
@@ -120,6 +122,7 @@ async def test_audio_input_to_model(stream):
     except Exception as e:
         if "openai-internal" in str(e):
             pytest.skip("Skipping test due to openai-internal error")
+        raise e

     if stream is True:
         await check_streaming_response(completion)
     else:

View file

@@ -2285,10 +2285,10 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     """
     Unit test for making sure spend logs list is still updated when url passed in
     """
-    from litellm.proxy.proxy_server import _set_spend_logs_payload
+    from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter

     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-    _set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
+    DBSpendUpdateWriter._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)

     assert len(prisma_client.spend_log_transactions) > 0
@@ -2296,7 +2296,7 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     spend_logs_url = ""
     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-    _set_spend_logs_payload(
+    DBSpendUpdateWriter._set_spend_logs_payload(
         payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
     )