diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py
index 9923c0e45d..df3d028c99 100644
--- a/litellm/llms/sagemaker/completion/transformation.py
+++ b/litellm/llms/sagemaker/completion/transformation.py
@@ -19,6 +19,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import token_counter
 
 from ..common_utils import SagemakerError
 
@@ -238,9 +239,12 @@ class SagemakerConfig(BaseConfig):
             )
 
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        prompt_tokens = token_counter(
+            text=prompt, count_response_tokens=True
+        )  # doesn't apply any default token count from openai's chat template
+        completion_tokens = token_counter(
+            text=model_response["choices"][0]["message"].get("content", ""),
+            count_response_tokens=True,
         )
 
         model_response.created = int(time.time())
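Reviewer note (annotation, not part of the patch): the switch to `token_counter(..., count_response_tokens=True)` is what keeps SageMaker usage numbers honest, since chat-style counting can add per-message template tokens on top of the raw text. A minimal sketch of the difference, assuming the default tiktoken-backed counter behind `litellm.utils.token_counter`:

```python
from litellm.utils import token_counter

text = "Hello world"

# Counted as a chat message, the counter may add per-message overhead tokens
# from OpenAI's chat template (role markers, separators, etc.).
with_template = token_counter(
    model="gpt-3.5-turbo", messages=[{"role": "user", "content": text}]
)

# count_response_tokens=True counts the raw text only, with no template
# overhead, which is what the SageMaker completion path wants for both
# the prompt and the generated output.
raw = token_counter(text=text, count_response_tokens=True)

# Template overhead, if any, only inflates the chat-style count.
assert raw <= with_template
```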
diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py
index 96b33ee187..8067d51c87 100644
--- a/litellm/llms/vertex_ai/gemini/transformation.py
+++ b/litellm/llms/vertex_ai/gemini/transformation.py
@@ -27,6 +27,8 @@ from litellm.types.files import (
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionAudioObject,
+    ChatCompletionFileObject,
     ChatCompletionImageObject,
     ChatCompletionTextObject,
 )
@@ -103,24 +105,53 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
 
     See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
 
     Supported by Gemini:
-     - PNG (`image/png`)
-     - JPEG (`image/jpeg`)
-     - WebP (`image/webp`)
-    Example:
-        url = https://example.com/image.jpg
-        Returns: image/jpeg
+        application/pdf
+        audio/mpeg
+        audio/mp3
+        audio/wav
+        image/png
+        image/jpeg
+        image/webp
+        text/plain
+        video/mov
+        video/mpeg
+        video/mp4
+        video/mpg
+        video/avi
+        video/wmv
+        video/mpegps
+        video/flv
     """
     url = url.lower()
-    if url.endswith((".jpg", ".jpeg")):
-        return "image/jpeg"
-    elif url.endswith(".png"):
-        return "image/png"
-    elif url.endswith(".webp"):
-        return "image/webp"
-    elif url.endswith(".mp4"):
-        return "video/mp4"
-    elif url.endswith(".pdf"):
-        return "application/pdf"
+
+    # Map file extensions to mime types
+    mime_types = {
+        # Images
+        (".jpg", ".jpeg"): "image/jpeg",
+        (".png",): "image/png",
+        (".webp",): "image/webp",
+        # Videos
+        (".mp4",): "video/mp4",
+        (".mov",): "video/mov",
+        (".mpeg", ".mpg"): "video/mpeg",
+        (".avi",): "video/avi",
+        (".wmv",): "video/wmv",
+        (".mpegps",): "video/mpegps",
+        (".flv",): "video/flv",
+        # Audio
+        (".mp3",): "audio/mp3",
+        (".wav",): "audio/wav",
+        (".mpeg",): "audio/mpeg",
+        # Documents
+        (".pdf",): "application/pdf",
+        (".txt",): "text/plain",
+    }
+
+    # Check each extension group against the URL
+    for extensions, mime_type in mime_types.items():
+        if any(url.endswith(ext) for ext in extensions):
+            return mime_type
+
     return None
@@ -152,7 +183,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             _message_content = messages[msg_i].get("content")
             if _message_content is not None and isinstance(_message_content, list):
                 _parts: List[PartType] = []
-                for element in _message_content:
+                for element_idx, element in enumerate(_message_content):
                     if (
                         element["type"] == "text"
                         and "text" in element
@@ -174,6 +205,41 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                             image_url=image_url, format=format
                         )
                         _parts.append(_part)
+                    elif element["type"] == "input_audio":
+                        audio_element = cast(ChatCompletionAudioObject, element)
+                        if audio_element["input_audio"].get("data") is not None:
+                            _part = PartType(
+                                inline_data=BlobType(
+                                    data=audio_element["input_audio"]["data"],
+                                    mime_type="audio/{}".format(
+                                        audio_element["input_audio"]["format"]
+                                    ),
+                                )
+                            )
+                            _parts.append(_part)
+                    elif element["type"] == "file":
+                        file_element = cast(ChatCompletionFileObject, element)
+                        file_id = file_element["file"].get("file_id")
+                        format = file_element["file"].get("format")
+
+                        if not file_id:
+                            continue
+                        mime_type = format or _get_image_mime_type_from_url(file_id)
+
+                        if mime_type is not None:
+                            _part = PartType(
+                                file_data=FileDataType(
+                                    file_uri=file_id,
+                                    mime_type=mime_type,
+                                )
+                            )
+                            _parts.append(_part)
+                        else:
+                            raise Exception(
+                                "Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
+                                    file_id, msg_i, element_idx
+                                )
+                            )
                 user_content.extend(_parts)
             elif (
                 _message_content is not None
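Annotation: the rewritten `_get_image_mime_type_from_url` is a straight table lookup, but `.mpeg` is registered twice, once under video and once under audio. Dicts preserve insertion order, so the video entry always matches first and the `(".mpeg",): "audio/mpeg"` row is unreachable; worth confirming that is intended. A standalone sketch of the lookup behavior (a reproduction, not the litellm code itself):

```python
from typing import Optional

mime_types = {
    (".jpg", ".jpeg"): "image/jpeg",
    (".mpeg", ".mpg"): "video/mpeg",
    (".mpeg",): "audio/mpeg",  # shadowed: the video entry above matches first
    (".pdf",): "application/pdf",
}

def mime_from_url(url: str) -> Optional[str]:
    # Same strategy as the patched helper: lowercase, then first
    # extension-group match wins.
    url = url.lower()
    for extensions, mime_type in mime_types.items():
        if any(url.endswith(ext) for ext in extensions):
            return mime_type
    return None

assert mime_from_url("gs://bucket/clip.MPEG") == "video/mpeg"  # never "audio/mpeg"
assert mime_from_url("https://example.com/doc.pdf") == "application/pdf"
assert mime_from_url("no-extension") is None
```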
date,team)", - ), - view_by: Literal["team", "organization", "user"] = fastapi.Query( - default="user", - description="View spend at team/org/user level", - ), - team_id: Optional[str] = fastapi.Query( - default=None, - description="Filter by specific team", - ), - org_id: Optional[str] = fastapi.Query( - default=None, - description="Filter by specific organization", - ), model: Optional[str] = fastapi.Query( default=None, description="Filter by specific model", @@ -1408,13 +1392,13 @@ async def get_user_daily_activity( Meant to optimize querying spend data for analytics for a user. Returns: - (by date/team/org/user/model/api_key/model_group/provider) + (by date) - spend - prompt_tokens - completion_tokens - total_tokens - api_requests - - breakdown by team, organization, user, model, api_key, model_group, provider + - breakdown by model, api_key, provider """ from litellm.proxy.proxy_server import prisma_client @@ -1439,10 +1423,6 @@ async def get_user_daily_activity( } } - if team_id: - where_conditions["team_id"] = team_id - if org_id: - where_conditions["organization_id"] = org_id if model: where_conditions["model"] = model if api_key: diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 3ba5a3a4e0..6378d02888 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -505,10 +505,11 @@ class ChatCompletionDocumentObject(TypedDict): citations: Optional[CitationsObject] -class ChatCompletionFileObjectFile(TypedDict): - file_data: Optional[str] - file_id: Optional[str] - filename: Optional[str] +class ChatCompletionFileObjectFile(TypedDict, total=False): + file_data: str + file_id: str + filename: str + format: str class ChatCompletionFileObject(TypedDict): diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 4ce592d7b0..64525d660b 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -4696,6 +4696,8 @@ "supports_vision": true, "supports_response_schema": true, "supports_audio_output": true, + "supports_audio_input": true, + "supported_modalities": ["text", "image", "audio", "video"], "supports_tool_choice": true, "source": "https://ai.google.dev/pricing#2_0flash" }, diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index eb58512197..432fd848b5 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch import os import uuid import time +import base64 sys.path.insert( 0, os.path.abspath("../..") @@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC): assert cost > 0 + @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"]) + def test_supports_audio_input(self, input_type): + from litellm.utils import return_raw_request, supports_audio_input + from litellm.types.utils import CallTypes + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + + litellm.drop_params = True + base_completion_call_args = self.get_base_completion_call_args() + if not supports_audio_input(base_completion_call_args["model"], None): + print("Model does not support audio input") + pytest.skip("Model does not support audio input") + + url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav" + response = httpx.get(url) + response.raise_for_status() + wav_data = response.content + audio_format = "wav" + encoded_string = 
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index eb58512197..432fd848b5 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch
 import os
 import uuid
 import time
+import base64
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -889,6 +890,74 @@ class BaseLLMChatTest(ABC):
 
         assert cost > 0
 
+    @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
+    def test_supports_audio_input(self, input_type):
+        from litellm.utils import return_raw_request, supports_audio_input
+        from litellm.types.utils import CallTypes
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+
+        litellm.drop_params = True
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_audio_input(base_completion_call_args["model"], None):
+            print("Model does not support audio input")
+            pytest.skip("Model does not support audio input")
+
+        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
+        response = httpx.get(url)
+        response.raise_for_status()
+        wav_data = response.content
+        audio_format = "wav"
+        encoded_string = base64.b64encode(wav_data).decode("utf-8")
+
+        audio_content = [
+            {
+                "type": "text",
+                "text": "What is in this recording?"
+            }
+        ]
+
+        test_file_id = "gs://bucket/file.wav"
+
+        if input_type == "input_audio":
+            audio_content.append({
+                "type": "input_audio",
+                "input_audio": {"data": encoded_string, "format": audio_format},
+            })
+        elif input_type == "audio_url":
+            audio_content.append(
+                {
+                    "type": "file",
+                    "file": {
+                        "file_id": test_file_id,
+                        "filename": "my-sample-audio-file",
+                    }
+                }
+            )
+
+
+
+        raw_request = return_raw_request(
+            endpoint=CallTypes.completion,
+            kwargs={
+                **base_completion_call_args,
+                "modalities": ["text", "audio"],
+                "audio": {"voice": "alloy", "format": audio_format},
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": audio_content,
+                    },
+                ]
+            }
+        )
+        print("raw_request: ", raw_request)
+
+        if input_type == "input_audio":
+            assert encoded_string in json.dumps(raw_request), "Audio data not sent to gemini"
+        elif input_type == "audio_url":
+            assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"
+
 
 class BaseOSeriesModelsTest(ABC):  # test across azure/openai
     @abstractmethod
@@ -1089,3 +1158,5 @@ class BaseAnthropicChatTest(ABC):
         )
 
         print(response)
+
+        
\ No newline at end of file
diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py
index f592337593..7c7c10daee 100644
--- a/tests/llm_translation/test_gemini.py
+++ b/tests/llm_translation/test_gemini.py
@@ -15,7 +15,7 @@ from litellm.llms.vertex_ai.context_caching.transformation import (
 
 class TestGoogleAIStudioGemini(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        return {"model": "gemini/gemini-1.5-flash-002"}
+        return {"model": "gemini/gemini-2.0-flash"}
 
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
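Annotation: for readers of the new `test_supports_audio_input`, these are the two OpenAI-style content shapes it exercises, with placeholder values. The first is inlined base64 audio; the second is a file reference that the Gemini transformation maps to `file_data.file_uri`, inferring the MIME type from the `.wav` extension unless `format` pins it explicitly:

```python
input_audio_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this recording?"},
        {
            "type": "input_audio",
            # base64-encoded bytes; Gemini receives this as inline_data
            # with mime_type "audio/wav"
            "input_audio": {"data": "<base64-wav>", "format": "wav"},
        },
    ],
}

file_reference_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this recording?"},
        {
            "type": "file",
            # file_id becomes file_data.file_uri; "format" is optional and,
            # when set, must be a full MIME type such as "audio/wav"
            "file": {"file_id": "gs://bucket/file.wav", "format": "audio/wav"},
        },
    ],
}
```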
diff --git a/tests/llm_translation/test_gpt4o_audio.py b/tests/llm_translation/test_gpt4o_audio.py
index 822cfb0356..f41dabb666 100644
--- a/tests/llm_translation/test_gpt4o_audio.py
+++ b/tests/llm_translation/test_gpt4o_audio.py
@@ -84,12 +84,14 @@ async def test_audio_output_from_model(stream):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("stream", [True, False])
-async def test_audio_input_to_model(stream):
+@pytest.mark.parametrize("model", ["gpt-4o-audio-preview"])  # "gpt-4o-audio-preview",
+async def test_audio_input_to_model(stream, model):
     # Fetch the audio file and convert it to a base64 encoded string
     audio_format = "pcm16"
     if stream is False:
         audio_format = "wav"
     litellm._turn_on_debug()
+    litellm.drop_params = True
     url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
     response = requests.get(url)
     response.raise_for_status()
@@ -97,7 +99,7 @@ async def test_audio_input_to_model(stream):
     encoded_string = base64.b64encode(wav_data).decode("utf-8")
     try:
         completion = await litellm.acompletion(
-            model="gpt-4o-audio-preview",
+            model=model,
             modalities=["text", "audio"],
             audio={"voice": "alloy", "format": audio_format},
             stream=stream,
@@ -120,6 +122,7 @@ async def test_audio_input_to_model(stream):
     except Exception as e:
         if "openai-internal" in str(e):
             pytest.skip("Skipping test due to openai-internal error")
+        raise e
     if stream is True:
         await check_streaming_response(completion)
     else:
diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py
index 99fc415c1b..ea837b717b 100644
--- a/tests/proxy_unit_tests/test_key_generate_prisma.py
+++ b/tests/proxy_unit_tests/test_key_generate_prisma.py
@@ -2285,10 +2285,10 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     """
     Unit test for making sure spend logs list is still updated when url passed in
     """
-    from litellm.proxy.proxy_server import _set_spend_logs_payload
+    from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
 
     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
-    _set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
+    DBSpendUpdateWriter._set_spend_logs_payload(payload=payload, prisma_client=prisma_client)
 
     assert len(prisma_client.spend_log_transactions) > 0
 
@@ -2296,7 +2296,7 @@ def test_update_logs_with_spend_logs_url(prisma_client):
     spend_logs_url = ""
     payload = {"startTime": datetime.now(), "endTime": datetime.now()}
 
-    _set_spend_logs_payload(
+    DBSpendUpdateWriter._set_spend_logs_payload(
         payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client
     )
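Annotation: the test update reflects `_set_spend_logs_payload` moving from a module-level function in `proxy_server` onto `DBSpendUpdateWriter`. Calling it directly on the class suggests a staticmethod; a minimal sketch of the assumed shape (the body here is illustrative only, and the real method also handles `spend_logs_url` routing):

```python
class DBSpendUpdateWriter:
    @staticmethod
    def _set_spend_logs_payload(payload, prisma_client, spend_logs_url=None):
        # Assumption: appends to the client's in-memory transaction list,
        # which is what the test's len(...) > 0 assertion checks.
        prisma_client.spend_log_transactions.append(payload)
        return prisma_client
```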