Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
Add gemini audio input support + handle special tokens in sagemaker response (#9640)
* fix(internal_user_endpoints.py): cleanup unused variables on beta endpoint; no team/org split on daily user endpoint
* build(model_prices_and_context_window.json): gemini-2.0-flash supports audio input
* feat(gemini/transformation.py): support passing audio input to gemini
* test: fix test
* fix(gemini/transformation.py): support audio input as a url; enables passing google cloud bucket urls
* fix(gemini/transformation.py): support explicitly passing format of file
* fix(gemini/transformation.py): expand support for inferred file types from url
* fix(sagemaker/completion/transformation.py): fix special token error when counting sagemaker tokens
* test: fix import
This commit is contained in: parent af885be743, commit 70f993d3d7
11 changed files with 191 additions and 58 deletions
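
For reviewers, a minimal usage sketch of the two audio-input shapes this commit wires up for Gemini. The payload shapes mirror the tests below; the model name and bucket URL are illustrative:

import base64
import litellm

# 1) Inline base64 audio via the OpenAI-style `input_audio` content part.
with open("sample.wav", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}},
        ],
    }],
)

# 2) Audio by URL via the `file` content part, e.g. a Google Cloud Storage object.
#    `format` may be set explicitly when the mime type can't be inferred from the URL.
litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "file", "file": {"file_id": "gs://my-bucket/sample.wav", "format": "audio/wav"}},
        ],
    }],
)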
@@ -19,6 +19,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import token_counter

 from ..common_utils import SagemakerError
@@ -238,9 +239,12 @@ class SagemakerConfig(BaseConfig):
         )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
+        prompt_tokens = token_counter(
+            text=prompt, count_response_tokens=True
+        )  # doesn't apply any default token count from openai's chat template
+        completion_tokens = token_counter(
+            text=model_response["choices"][0]["message"].get("content", ""),
+            count_response_tokens=True,
+        )

         model_response.created = int(time.time())
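
Why this matters: encoding.encode() can raise on tokenizer special tokens (e.g. "<|endoftext|>") that a SageMaker response may contain, while token_counter() with count_response_tokens=True counts them as plain text and skips any chat-template token overhead. A small sketch of the new call, with illustrative text:

from litellm.utils import token_counter

# Special tokens in the model output no longer raise during usage counting.
completion_tokens = token_counter(
    text="<|endoftext|> some sagemaker output",
    count_response_tokens=True,  # don't add default chat-template tokens
)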
@@ -27,6 +27,8 @@ from litellm.types.files import (
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionAudioObject,
+    ChatCompletionFileObject,
     ChatCompletionImageObject,
     ChatCompletionTextObject,
 )
@@ -103,24 +105,53 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
     See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements

     Supported by Gemini:
-    - PNG (`image/png`)
-    - JPEG (`image/jpeg`)
-    - WebP (`image/webp`)
-    Example:
-        url = https://example.com/image.jpg
-        Returns: image/jpeg
+    application/pdf
+    audio/mpeg
+    audio/mp3
+    audio/wav
+    image/png
+    image/jpeg
+    image/webp
+    text/plain
+    video/mov
+    video/mpeg
+    video/mp4
+    video/mpg
+    video/avi
+    video/wmv
+    video/mpegps
+    video/flv
     """
     url = url.lower()
-    if url.endswith((".jpg", ".jpeg")):
-        return "image/jpeg"
-    elif url.endswith(".png"):
-        return "image/png"
-    elif url.endswith(".webp"):
-        return "image/webp"
-    elif url.endswith(".mp4"):
-        return "video/mp4"
-    elif url.endswith(".pdf"):
-        return "application/pdf"
+
+    # Map file extensions to mime types
+    mime_types = {
+        # Images
+        (".jpg", ".jpeg"): "image/jpeg",
+        (".png",): "image/png",
+        (".webp",): "image/webp",
+        # Videos
+        (".mp4",): "video/mp4",
+        (".mov",): "video/mov",
+        (".mpeg", ".mpg"): "video/mpeg",
+        (".avi",): "video/avi",
+        (".wmv",): "video/wmv",
+        (".mpegps",): "video/mpegps",
+        (".flv",): "video/flv",
+        # Audio
+        (".mp3",): "audio/mp3",
+        (".wav",): "audio/wav",
+        (".mpeg",): "audio/mpeg",
+        # Documents
+        (".pdf",): "application/pdf",
+        (".txt",): "text/plain",
+    }
+
+    # Check each extension group against the URL
+    for extensions, mime_type in mime_types.items():
+        if any(url.endswith(ext) for ext in extensions):
+            return mime_type

     return None
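
Illustrative lookups against the rewritten helper; note it matches on the lowercased URL suffix, so GCS URIs and uppercase extensions both work:

# Hypothetical inputs; behavior follows the extension table above.
_get_image_mime_type_from_url("https://example.com/CLIP.MP4")  # -> "video/mp4"
_get_image_mime_type_from_url("gs://bucket/recording.wav")     # -> "audio/wav"
_get_image_mime_type_from_url("gs://bucket/unknown.bin")       # -> None (caller must pass `format`)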
@@ -152,7 +183,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             _message_content = messages[msg_i].get("content")
             if _message_content is not None and isinstance(_message_content, list):
                 _parts: List[PartType] = []
-                for element in _message_content:
+                for element_idx, element in enumerate(_message_content):
                     if (
                         element["type"] == "text"
                         and "text" in element
@@ -174,6 +205,41 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                                 image_url=image_url, format=format
                             )
                             _parts.append(_part)
+                    elif element["type"] == "input_audio":
+                        audio_element = cast(ChatCompletionAudioObject, element)
+                        if audio_element["input_audio"].get("data") is not None:
+                            _part = PartType(
+                                inline_data=BlobType(
+                                    data=audio_element["input_audio"]["data"],
+                                    mime_type="audio/{}".format(
+                                        audio_element["input_audio"]["format"]
+                                    ),
+                                )
+                            )
+                            _parts.append(_part)
+                    elif element["type"] == "file":
+                        file_element = cast(ChatCompletionFileObject, element)
+                        file_id = file_element["file"].get("file_id")
+                        format = file_element["file"].get("format")
+
+                        if not file_id:
+                            continue
+                        mime_type = format or _get_image_mime_type_from_url(file_id)
+
+                        if mime_type is not None:
+                            _part = PartType(
+                                file_data=FileDataType(
+                                    file_uri=file_id,
+                                    mime_type=mime_type,
+                                )
+                            )
+                            _parts.append(_part)
+                        else:
+                            raise Exception(
+                                "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
+                                    file_id, msg_i, element_idx
+                                )
+                            )
                 user_content.extend(_parts)
             elif (
                 _message_content is not None
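
For reference, the OpenAI-format content parts the two new branches consume look like the following (values illustrative). `input_audio` becomes an inline_data Blob part; `file` becomes a file_data part referencing the URI:

# Inline audio: base64 data plus a short format name ("wav", "mp3", ...).
audio_part = {"type": "input_audio", "input_audio": {"data": "<base64 wav bytes>", "format": "wav"}}

# File by URI: mime type taken from `format` when given, else inferred from the URL;
# otherwise the Exception above is raised, pointing at the message/element index.
file_part = {"type": "file", "file": {"file_id": "gs://bucket/file.wav", "format": "audio/wav"}}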
@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },
@@ -16,14 +16,18 @@ model_list:
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
+  - model_name: "gemini-2.0-flash"
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY

 litellm_settings:
   num_retries: 0
   callbacks: ["prometheus"]
-  json_logs: true
+  # json_logs: true

-router_settings:
-  routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
+# router_settings:
+#   routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
+#   redis_host: os.environ/REDIS_HOST
+#   redis_password: os.environ/REDIS_PASSWORD
+#   redis_port: os.environ/REDIS_PORT
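
With the new gemini-2.0-flash entry, a client pointed at the proxy can exercise the same audio path. A sketch assuming the proxy runs locally on its default port 4000 with a dummy key; both values are assumptions, not part of this diff:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")
client.chat.completions.create(
    model="gemini-2.0-flash",  # the model_name registered above
    messages=[{"role": "user", "content": "hello"}],
)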
@@ -1370,22 +1370,6 @@ async def get_user_daily_activity(
        default=None,
        description="End date in YYYY-MM-DD format",
    ),
-   group_by: List[GroupByDimension] = fastapi.Query(
-       default=[GroupByDimension.DATE],
-       description="Dimensions to group by. Can combine multiple (e.g. date,team)",
-   ),
-   view_by: Literal["team", "organization", "user"] = fastapi.Query(
-       default="user",
-       description="View spend at team/org/user level",
-   ),
-   team_id: Optional[str] = fastapi.Query(
-       default=None,
-       description="Filter by specific team",
-   ),
-   org_id: Optional[str] = fastapi.Query(
-       default=None,
-       description="Filter by specific organization",
-   ),
    model: Optional[str] = fastapi.Query(
        default=None,
        description="Filter by specific model",
@@ -1408,13 +1392,13 @@ async def get_user_daily_activity(
    Meant to optimize querying spend data for analytics for a user.

    Returns:
-   (by date/team/org/user/model/api_key/model_group/provider)
+   (by date)
    - spend
    - prompt_tokens
    - completion_tokens
    - total_tokens
    - api_requests
-   - breakdown by team, organization, user, model, api_key, model_group, provider
+   - breakdown by model, api_key, provider
    """
    from litellm.proxy.proxy_server import prisma_client
@@ -1439,10 +1423,6 @@ async def get_user_daily_activity(
            }
        }

-       if team_id:
-           where_conditions["team_id"] = team_id
-       if org_id:
-           where_conditions["organization_id"] = org_id
        if model:
            where_conditions["model"] = model
        if api_key:
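
Downstream effect: the beta daily-activity endpoint now always aggregates at the user level, with breakdowns limited to model/api_key/provider. A hedged sketch of a client call; the route path, base URL, and auth header are assumptions based on typical litellm proxy routes, not shown in this diff:

import requests

# Assumed route and credentials; adjust to your deployment.
resp = requests.get(
    "http://localhost:4000/user/daily/activity",
    headers={"Authorization": "Bearer sk-1234"},
    params={"start_date": "2025-03-01", "end_date": "2025-03-31", "model": "gpt-4o"},
)
print(resp.json())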
@@ -505,10 +505,11 @@ class ChatCompletionDocumentObject(TypedDict):
     citations: Optional[CitationsObject]


-class ChatCompletionFileObjectFile(TypedDict):
-    file_data: Optional[str]
-    file_id: Optional[str]
-    filename: Optional[str]
+class ChatCompletionFileObjectFile(TypedDict, total=False):
+    file_data: str
+    file_id: str
+    filename: str
+    format: str


 class ChatCompletionFileObject(TypedDict):
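
Typing note on the change above: switching from required Optional[...] keys to total=False means every key may be omitted entirely, which is what the .get("file_id") / .get("format") calls in the Gemini transformation rely on. A minimal illustration:

from typing import TypedDict

class ChatCompletionFileObjectFile(TypedDict, total=False):
    file_data: str
    file_id: str
    filename: str
    format: str

# Valid now: pass only the keys you have; absent keys read as None via .get().
file_obj: ChatCompletionFileObjectFile = {"file_id": "gs://bucket/file.wav"}
file_obj.get("format")  # -> None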
@@ -4696,6 +4696,8 @@
         "supports_vision": true,
         "supports_response_schema": true,
         "supports_audio_output": true,
+        "supports_audio_input": true,
+        "supported_modalities": ["text", "image", "audio", "video"],
         "supports_tool_choice": true,
         "source": "https://ai.google.dev/pricing#2_0flash"
     },
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch
 import os
 import uuid
 import time
+import base64

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC):

         assert cost > 0

+    @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
+    def test_supports_audio_input(self, input_type):
+        from litellm.utils import return_raw_request, supports_audio_input
+        from litellm.types.utils import CallTypes
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        litellm.drop_params = True
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_audio_input(base_completion_call_args["model"], None):
+            print("Model does not support audio input")
+            pytest.skip("Model does not support audio input")
+
+        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
+        response = httpx.get(url)
+        response.raise_for_status()
+        wav_data = response.content
+        audio_format = "wav"
+        encoded_string = base64.b64encode(wav_data).decode("utf-8")
+
+        audio_content = [
+            {
+                "type": "text",
+                "text": "What is in this recording?"
+            }
+        ]
+
+        test_file_id = "gs://bucket/file.wav"
+
+        if input_type == "input_audio":
+            audio_content.append({
+                "type": "input_audio",
+                "input_audio": {"data": encoded_string, "format": audio_format},
+            })
+        elif input_type == "audio_url":
+            audio_content.append(
+                {
+                    "type": "file",
+                    "file": {
+                        "file_id": test_file_id,
+                        "filename": "my-sample-audio-file",
+                    }
+                }
+            )
+
+        raw_request = return_raw_request(
+            endpoint=CallTypes.completion,
+            kwargs={
+                **base_completion_call_args,
+                "modalities": ["text", "audio"],
+                "audio": {"voice": "alloy", "format": audio_format},
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": audio_content,
+                    },
+                ]
+            }
+        )
+        print("raw_request: ", raw_request)
+
+        if input_type == "input_audio":
+            assert encoded_string in json.dumps(raw_request), "Audio data not sent to gemini"
+        elif input_type == "audio_url":
+            assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
+
 class BaseOSeriesModelsTest(ABC):  # test across azure/openai
     @abstractmethod
@@ -1089,3 +1158,5 @@ class BaseAnthropicChatTest(ABC):
         )

         print(response)
+
+
@@ -15,7 +15,7 @@ from litellm.llms.vertex_ai.context_caching.transformation import (

 class TestGoogleAIStudioGemini(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        return {"model": "gemini/gemini-1.5-flash-002"}
+        return {"model": "gemini/gemini-2.0-flash"}

     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
@@ -84,12 +84,14 @@ async def test_audio_output_from_model(stream):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("stream", [True, False])
-async def test_audio_input_to_model(stream):
+@pytest.mark.parametrize("model", ["gpt-4o-audio-preview"])  # "gpt-4o-audio-preview",
+async def test_audio_input_to_model(stream, model):
     # Fetch the audio file and convert it to a base64 encoded string
     audio_format = "pcm16"
     if stream is False:
         audio_format = "wav"
+    litellm._turn_on_debug()
+    litellm.drop_params = True
     url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
     response = requests.get(url)
     response.raise_for_status()
@@ -97,7 +99,7 @@
     encoded_string = base64.b64encode(wav_data).decode("utf-8")
     try:
         completion = await litellm.acompletion(
-            model="gpt-4o-audio-preview",
+            model=model,
             modalities=["text", "audio"],
             audio={"voice": "alloy", "format": audio_format},
             stream=stream,
@@ -120,6 +122,7 @@
     except Exception as e:
         if "openai-internal" in str(e):
             pytest.skip("Skipping test due to openai-internal error")
+        raise e
     if stream is True:
         await check_streaming_response(completion)
     else:
@@ -2285,10 +2285,10 @@ def test_update_logs_with_spend_logs_url(prisma_client):
    """
    Unit test for making sure spend logs list is still updated when url passed in
    """
-   from litellm.proxy.proxy_server import _set_spend_logs_payload
+   from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter

    payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-   _set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
+   DBSpendUpdateWriter._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)

    assert len(prisma_client.spend_log_transactions) > 0
@@ -2296,7 +2296,7 @@ def test_update_logs_with_spend_logs_url(prisma_client):

    spend_logs_url = ""
    payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-   _set_spend_logs_payload(
+   DBSpendUpdateWriter._set_spend_logs_payload(
        payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
    )