diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index 4f2f43f0de..f40f1ae4c7 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -110,5 +110,8 @@ def get_litellm_params( "azure_password": kwargs.get("azure_password"), "max_retries": max_retries, "timeout": kwargs.get("timeout"), + "bucket_name": kwargs.get("bucket_name"), + "vertex_credentials": kwargs.get("vertex_credentials"), + "vertex_project": kwargs.get("vertex_project"), } return litellm_params diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 0f2d0da388..44b680d487 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -2,7 +2,10 @@ Common utility functions used for translating messages across providers """ -from typing import Dict, List, Literal, Optional, Union, cast +import io +import mimetypes +from os import PathLike +from typing import Dict, List, Literal, Mapping, Optional, Union, cast from litellm.types.llms.openai import ( AllMessageValues, @@ -10,7 +13,13 @@ from litellm.types.llms.openai import ( ChatCompletionFileObject, ChatCompletionUserMessage, ) -from litellm.types.utils import Choices, ModelResponse, StreamingChoices +from litellm.types.utils import ( + Choices, + ExtractedFileData, + FileTypes, + ModelResponse, + StreamingChoices, +) DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage( content="Please continue.", role="user" @@ -350,6 +359,68 @@ def update_messages_with_model_file_ids( return messages +def extract_file_data(file_data: FileTypes) -> ExtractedFileData: + """ + Extracts and processes file data from various input formats. 
+ + Args: + file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content + + Returns: + ExtractedFileData containing: + - filename: Name of the file if provided + - content: The file content in bytes + - content_type: MIME type of the file + - headers: Any additional headers + """ + # Parse the file_data based on its type + filename = None + file_content = None + content_type = None + file_headers: Mapping[str, str] = {} + + if isinstance(file_data, tuple): + if len(file_data) == 2: + filename, file_content = file_data + elif len(file_data) == 3: + filename, file_content, content_type = file_data + elif len(file_data) == 4: + filename, file_content, content_type, file_headers = file_data + else: + file_content = file_data + # Convert content to bytes + if isinstance(file_content, (str, PathLike)): + # If it's a path, open and read the file + with open(file_content, "rb") as f: + content = f.read() + elif isinstance(file_content, io.IOBase): + # If it's a file-like object + content = file_content.read() + + if isinstance(content, str): + content = content.encode("utf-8") + # Reset file pointer to beginning + file_content.seek(0) + elif isinstance(file_content, bytes): + content = file_content + else: + raise ValueError(f"Unsupported file content type: {type(file_content)}") + + # Use provided content type or guess based on filename + if not content_type: + content_type = ( + mimetypes.guess_type(filename)[0] + if filename + else "application/octet-stream" + ) + + return ExtractedFileData( + filename=filename, + content=content, + content_type=content_type, + headers=file_headers, + ) + def unpack_defs(schema, defs): properties = schema.get("properties", None) if properties is None: @@ -381,3 +452,4 @@ def unpack_defs(schema, defs): unpack_defs(ref, defs) value["items"] = ref continue + diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py index af073fe8e3..c2d4e5adcd 100644 --- a/litellm/llms/aiohttp_openai/chat/transformation.py +++ b/litellm/llms/aiohttp_openai/chat/transformation.py @@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index f2a5542dcd..44567facf9 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM): model=model, messages=messages, optional_params={**optional_params, "is_vertex_request": is_vertex_request}, + litellm_params=litellm_params, ) config = ProviderConfigManager.get_provider_chat_config( diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 8a2048f95a..9b66249630 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index e4e04df4d6..9e3287aa8a 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ 
b/litellm/llms/anthropic/completion/transformation.py @@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py index e30d68f97d..ea61ef2c9a 100644 --- a/litellm/llms/azure/chat/gpt_transformation.py +++ b/litellm/llms/azure/chat/gpt_transformation.py @@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 007a4303c8..839f875f75 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 5279a44201..fa278c805e 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -262,6 +262,7 @@ class BaseConfig(ABC): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/base_llm/files/transformation.py b/litellm/llms/base_llm/files/transformation.py index 0f1f46352f..9925004c89 100644 --- a/litellm/llms/base_llm/files/transformation.py +++ b/litellm/llms/base_llm/files/transformation.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import httpx @@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig): ) -> List[OpenAICreateFileRequestOptionalParams]: pass - def get_complete_url( + def get_complete_file_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, - stream: Optional[bool] = None, - ) -> str: - """ - OPTIONAL - - Get the complete url for the request - - Some providers need `model` in `api_base` - """ - return api_base or "" + data: CreateFileRequest, + ): + return self.get_complete_url( + api_base=api_base, + api_key=api_key, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + ) @abstractmethod def transform_create_file_request( @@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig): create_file_data: CreateFileRequest, optional_params: dict, litellm_params: dict, - ) -> dict: + ) -> Union[dict, str, bytes]: pass @abstractmethod diff --git a/litellm/llms/base_llm/image_variations/transformation.py b/litellm/llms/base_llm/image_variations/transformation.py index 3ed446a84e..60444d0fb7 100644 --- a/litellm/llms/base_llm/image_variations/transformation.py +++ b/litellm/llms/base_llm/image_variations/transformation.py @@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git 
a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 8ce2c4818b..fbe2dc4937 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index cb12f779cc..67194e83e7 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py index 916da73883..73be89fc6e 100644 --- a/litellm/llms/clarifai/chat/transformation.py +++ b/litellm/llms/clarifai/chat/transformation.py @@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py index 1874bb5115..9e59782bf7 100644 --- a/litellm/llms/cloudflare/chat/transformation.py +++ b/litellm/llms/cloudflare/chat/transformation.py @@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py index 70677214a7..5dd44aca80 100644 --- a/litellm/llms/cohere/chat/transformation.py +++ b/litellm/llms/cohere/chat/transformation.py @@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py index bdfcda020e..f96ef89d3c 100644 --- a/litellm/llms/cohere/completion/transformation.py +++ b/litellm/llms/cohere/completion/transformation.py @@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index 72092cf261..13141fc19a 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler: model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, api_base=api_base, ) @@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler: model=model, messages=[{"role": "user", "content": "test"}], optional_params=optional_params, + 
litellm_params=litellm_params, api_base=api_base, ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 23d7fe4b4d..f1aa5627dc 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -192,7 +192,7 @@ class AsyncHTTPHandler: async def post( self, url: str, - data: Optional[Union[dict, str]] = None, # type: ignore + data: Optional[Union[dict, str, bytes]] = None, # type: ignore json: Optional[dict] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -427,7 +427,7 @@ class AsyncHTTPHandler: self, url: str, client: httpx.AsyncClient, - data: Optional[Union[dict, str]] = None, # type: ignore + data: Optional[Union[dict, str, bytes]] = None, # type: ignore json: Optional[dict] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -527,7 +527,7 @@ class HTTPHandler: def post( self, url: str, - data: Optional[Union[dict, str]] = None, + data: Optional[Union[dict, str, bytes]] = None, json: Optional[Union[dict, str, List]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -573,7 +573,6 @@ class HTTPHandler: setattr(e, "text", error_text) setattr(e, "status_code", e.response.status_code) - raise e except Exception as e: raise e diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 5778f0228f..b7c72e89ef 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -247,6 +247,7 @@ class BaseLLMHTTPHandler: messages=messages, optional_params=optional_params, api_base=api_base, + litellm_params=litellm_params, ) api_base = provider_config.get_complete_url( @@ -625,6 +626,7 @@ class BaseLLMHTTPHandler: model=model, messages=[], optional_params=optional_params, + litellm_params=litellm_params, ) api_base = provider_config.get_complete_url( @@ -896,6 +898,7 @@ class BaseLLMHTTPHandler: model=model, messages=[], optional_params=optional_params, + litellm_params=litellm_params, ) if client is None or not isinstance(client, HTTPHandler): @@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler: model="", messages=[], optional_params={}, + litellm_params=litellm_params, ) - api_base = provider_config.get_complete_url( + api_base = provider_config.get_complete_file_url( api_base=api_base, api_key=api_key, model="", optional_params={}, litellm_params=litellm_params, + data=create_file_data, ) + if api_base is None: + raise ValueError("api_base is required for create_file") # Get the transformed request data for both steps transformed_request = provider_config.transform_create_file_request( @@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler: else: sync_httpx_client = client - try: - # Step 1: Initial request to get upload URL - initial_response = sync_httpx_client.post( - url=api_base, - headers={ - **headers, - **transformed_request["initial_request"]["headers"], - }, - data=json.dumps(transformed_request["initial_request"]["data"]), - timeout=timeout, - ) - - # Extract upload URL from response headers - upload_url = initial_response.headers.get("X-Goog-Upload-URL") - - if not upload_url: - raise ValueError("Failed to get upload URL from initial request") - - # Step 2: Upload the actual file + if isinstance(transformed_request, str) or isinstance( + transformed_request, bytes + ): upload_response = sync_httpx_client.post( - url=upload_url, - headers=transformed_request["upload_request"]["headers"], - 
data=transformed_request["upload_request"]["data"], + url=api_base, + headers=headers, + data=transformed_request, timeout=timeout, ) + else: + try: + # Step 1: Initial request to get upload URL + initial_response = sync_httpx_client.post( + url=api_base, + headers={ + **headers, + **transformed_request["initial_request"]["headers"], + }, + data=json.dumps(transformed_request["initial_request"]["data"]), + timeout=timeout, + ) - return provider_config.transform_create_file_response( - model=None, - raw_response=upload_response, - logging_obj=logging_obj, - litellm_params=litellm_params, - ) + # Extract upload URL from response headers + upload_url = initial_response.headers.get("X-Goog-Upload-URL") - except Exception as e: - raise self._handle_error( - e=e, - provider_config=provider_config, - ) + if not upload_url: + raise ValueError("Failed to get upload URL from initial request") + + # Step 2: Upload the actual file + upload_response = sync_httpx_client.post( + url=upload_url, + headers=transformed_request["upload_request"]["headers"], + data=transformed_request["upload_request"]["data"], + timeout=timeout, + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_create_file_response( + model=None, + raw_response=upload_response, + logging_obj=logging_obj, + litellm_params=litellm_params, + ) async def async_create_file( self, - transformed_request: dict, + transformed_request: Union[bytes, str, dict], litellm_params: dict, provider_config: BaseFilesConfig, headers: dict, @@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler: else: async_httpx_client = client - try: - # Step 1: Initial request to get upload URL - initial_response = await async_httpx_client.post( - url=api_base, - headers={ - **headers, - **transformed_request["initial_request"]["headers"], - }, - data=json.dumps(transformed_request["initial_request"]["data"]), - timeout=timeout, - ) - - # Extract upload URL from response headers - upload_url = initial_response.headers.get("X-Goog-Upload-URL") - - if not upload_url: - raise ValueError("Failed to get upload URL from initial request") - - # Step 2: Upload the actual file + if isinstance(transformed_request, str) or isinstance( + transformed_request, bytes + ): upload_response = await async_httpx_client.post( - url=upload_url, - headers=transformed_request["upload_request"]["headers"], - data=transformed_request["upload_request"]["data"], + url=api_base, + headers=headers, + data=transformed_request, timeout=timeout, ) + else: + try: + # Step 1: Initial request to get upload URL + initial_response = await async_httpx_client.post( + url=api_base, + headers={ + **headers, + **transformed_request["initial_request"]["headers"], + }, + data=json.dumps(transformed_request["initial_request"]["data"]), + timeout=timeout, + ) - return provider_config.transform_create_file_response( - model=None, - raw_response=upload_response, - logging_obj=logging_obj, - litellm_params=litellm_params, - ) + # Extract upload URL from response headers + upload_url = initial_response.headers.get("X-Goog-Upload-URL") - except Exception as e: - verbose_logger.exception(f"Error creating file: {e}") - raise self._handle_error( - e=e, - provider_config=provider_config, - ) + if not upload_url: + raise ValueError("Failed to get upload URL from initial request") + + # Step 2: Upload the actual file + upload_response = await async_httpx_client.post( + url=upload_url, + headers=transformed_request["upload_request"]["headers"], + 
data=transformed_request["upload_request"]["data"], + timeout=timeout, + ) + except Exception as e: + verbose_logger.exception(f"Error creating file: {e}") + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_create_file_response( + model=None, + raw_response=upload_response, + logging_obj=logging_obj, + litellm_params=litellm_params, + ) def list_files(self): """ diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py index 1940f09608..6f5738fb4b 100644 --- a/litellm/llms/databricks/chat/transformation.py +++ b/litellm/llms/databricks/chat/transformation.py @@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py index b4803576e0..f1b18808f7 100644 --- a/litellm/llms/deepgram/audio_transcription/transformation.py +++ b/litellm/llms/deepgram/audio_transcription/transformation.py @@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/fireworks_ai/common_utils.py b/litellm/llms/fireworks_ai/common_utils.py index 293403b133..17aa67b525 100644 --- a/litellm/llms/fireworks_ai/common_utils.py +++ b/litellm/llms/fireworks_ai/common_utils.py @@ -41,6 +41,7 @@ class FireworksAIMixin: model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py index ace24e982f..fef41f7d58 100644 --- a/litellm/llms/gemini/common_utils.py +++ b/litellm/llms/gemini/common_utils.py @@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/gemini/files/transformation.py b/litellm/llms/gemini/files/transformation.py index a1f99c6903..e98e76dabc 100644 --- a/litellm/llms/gemini/files/transformation.py +++ b/litellm/llms/gemini/files/transformation.py @@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API. For vertex ai, check out the vertex_ai/files/handler.py file. 
""" import time -from typing import List, Mapping, Optional +from typing import List, Optional import httpx from litellm._logging import verbose_logger +from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data from litellm.llms.base_llm.files.transformation import ( BaseFilesConfig, LiteLLMLoggingObj, @@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig): if file_data is None: raise ValueError("File data is required") - # Parse the file_data based on its type - filename = None - file_content = None - content_type = None - file_headers: Mapping[str, str] = {} - - if isinstance(file_data, tuple): - if len(file_data) == 2: - filename, file_content = file_data - elif len(file_data) == 3: - filename, file_content, content_type = file_data - elif len(file_data) == 4: - filename, file_content, content_type, file_headers = file_data - else: - file_content = file_data - - # Handle the file content based on its type - import io - from os import PathLike - - # Convert content to bytes - if isinstance(file_content, (str, PathLike)): - # If it's a path, open and read the file - with open(file_content, "rb") as f: - content = f.read() - elif isinstance(file_content, io.IOBase): - # If it's a file-like object - content = file_content.read() - if isinstance(content, str): - content = content.encode("utf-8") - elif isinstance(file_content, bytes): - content = file_content - else: - raise ValueError(f"Unsupported file content type: {type(file_content)}") + # Use the common utility function to extract file data + extracted_data = extract_file_data(file_data) # Get file size - file_size = len(content) - - # Use provided content type or guess based on filename - if not content_type: - import mimetypes - - content_type = ( - mimetypes.guess_type(filename)[0] - if filename - else "application/octet-stream" - ) + file_size = len(extracted_data["content"]) # Step 1: Initial resumable upload request headers = { "X-Goog-Upload-Protocol": "resumable", "X-Goog-Upload-Command": "start", "X-Goog-Upload-Header-Content-Length": str(file_size), - "X-Goog-Upload-Header-Content-Type": content_type, + "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"], "Content-Type": "application/json", } - headers.update(file_headers) # Add any custom headers + headers.update(extracted_data["headers"]) # Add any custom headers # Initial metadata request body - initial_data = {"file": {"display_name": filename or str(int(time.time()))}} + initial_data = { + "file": { + "display_name": extracted_data["filename"] or str(int(time.time())) + } + } # Step 2: Actual file upload data upload_headers = { @@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig): return { "initial_request": {"headers": headers, "data": initial_data}, - "upload_request": {"headers": upload_headers, "data": content}, + "upload_request": { + "headers": upload_headers, + "data": extracted_data["content"], + }, } def transform_create_file_response( diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py index c84f03ab93..0ad93be763 100644 --- a/litellm/llms/huggingface/chat/transformation.py +++ b/litellm/llms/huggingface/chat/transformation.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx @@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import 
BaseLLMException from ...openai.chat.gpt_transformation import OpenAIGPTConfig from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping - logger = logging.getLogger(__name__) BASE_URL = "https://router.huggingface.co" @@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): headers: dict, model: str, messages: List[AllMessageValues], - optional_params: dict, + optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: @@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: - return HuggingFaceError(status_code=status_code, message=error_message, headers=headers) + return HuggingFaceError( + status_code=status_code, message=error_message, headers=headers + ) def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]: """ @@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): if api_base is not None: complete_url = api_base elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"): - complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE")) + complete_url = str(os.getenv("HF_API_BASE")) or str( + os.getenv("HUGGINGFACE_API_BASE") + ) elif model.startswith(("http://", "https://")): complete_url = model # 4. Default construction with provider @@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): ) mapped_model = provider_mapping["providerId"] messages = self._transform_messages(messages=messages, model=mapped_model) - return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params)) + return dict( + ChatCompletionRequest( + model=mapped_model, messages=messages, **optional_params + ) + ) diff --git a/litellm/llms/huggingface/embedding/handler.py b/litellm/llms/huggingface/embedding/handler.py index 7277fbd0e3..bfd73c1346 100644 --- a/litellm/llms/huggingface/embedding/handler.py +++ b/litellm/llms/huggingface/embedding/handler.py @@ -1,15 +1,6 @@ import json import os -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Union, - get_args, -) +from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args import httpx @@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug ] - -def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]: +def get_hf_task_embedding_for_model( + model: str, task_type: Optional[str], api_base: str +) -> Optional[str]: if task_type is not None: if task_type in get_args(hf_tasks_embeddings): return task_type @@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba return pipeline_tag -async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]: +async def async_get_hf_task_embedding_for_model( + model: str, task_type: Optional[str], api_base: str +) -> Optional[str]: if task_type is not None: if task_type in get_args(hf_tasks_embeddings): return task_type @@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM): input: List, optional_params: dict, ) -> dict: - hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + hf_task = await async_get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) data = 
self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task) @@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM): task_type = optional_params.pop("input_type", None) if call_type == "sync": - hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + hf_task = get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) elif call_type == "async": return self._async_transform_input( model=model, task_type=task_type, embed_url=embed_url, input=input @@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM): input: list, model_response: EmbeddingResponse, optional_params: dict, + litellm_params: dict, logging_obj: LiteLLMLoggingObj, encoding: Callable, api_key: Optional[str] = None, @@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM): model=model, optional_params=optional_params, messages=[], + litellm_params=litellm_params, ) task_type = optional_params.pop("input_type", None) - task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + task = get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) # print_verbose(f"{model}, {task}") embed_url = "" if "https" in model: @@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM): elif "HUGGINGFACE_API_BASE" in os.environ: embed_url = os.getenv("HUGGINGFACE_API_BASE", "") else: - embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}" + embed_url = ( + f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}" + ) ## ROUTING ## if aembedding is True: diff --git a/litellm/llms/huggingface/embedding/transformation.py b/litellm/llms/huggingface/embedding/transformation.py index f803157768..60bd5dcd61 100644 --- a/litellm/llms/huggingface/embedding/transformation.py +++ b/litellm/llms/huggingface/embedding/transformation.py @@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/nlp_cloud/chat/handler.py b/litellm/llms/nlp_cloud/chat/handler.py index b0abdda587..b0563d8b55 100644 --- a/litellm/llms/nlp_cloud/chat/handler.py +++ b/litellm/llms/nlp_cloud/chat/handler.py @@ -36,6 +36,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) ## Load Config diff --git a/litellm/llms/nlp_cloud/chat/transformation.py b/litellm/llms/nlp_cloud/chat/transformation.py index b7967249ab..8037a45832 100644 --- a/litellm/llms/nlp_cloud/chat/transformation.py +++ b/litellm/llms/nlp_cloud/chat/transformation.py @@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index 64544bd269..789b728337 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/oobabooga/chat/oobabooga.py b/litellm/llms/oobabooga/chat/oobabooga.py index 
8829d2233e..5eb68a03d4 100644 --- a/litellm/llms/oobabooga/chat/oobabooga.py +++ b/litellm/llms/oobabooga/chat/oobabooga.py @@ -32,6 +32,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) if "https" in model: completion_url = model @@ -123,6 +124,7 @@ def embedding( model=model, messages=[], optional_params=optional_params, + litellm_params={}, ) response = litellm.module_level_client.post( embeddings_url, headers=headers, json=data diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py index 6fd56f934e..e87b70130c 100644 --- a/litellm/llms/oobabooga/chat/transformation.py +++ b/litellm/llms/oobabooga/chat/transformation.py @@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index fcab43901a..434214639e 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 3b6be1a034..13412ef96a 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py index 2d3d611dac..c0ccc71579 100644 --- a/litellm/llms/openai/transcriptions/whisper_transformation.py +++ b/litellm/llms/openai/transcriptions/whisper_transformation.py @@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index a9e37d27fc..24910cba8f 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py index cd80fa53e4..79936764ac 100644 --- a/litellm/llms/predibase/chat/handler.py +++ b/litellm/llms/predibase/chat/handler.py @@ -228,10 +228,10 @@ class PredibaseChatCompletion: api_key: str, logging_obj, optional_params: dict, + litellm_params: dict, tenant_id: str, timeout: Union[float, httpx.Timeout], acompletion=None, - litellm_params=None, logger_fn=None, headers: dict = {}, ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -241,6 +241,7 @@ class PredibaseChatCompletion: messages=messages, optional_params=optional_params, 
model=model, + litellm_params=litellm_params, ) completion_url = "" input_text = "" diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index 8ef0eea173..9fbb9d6c9e 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index d954416381..e4bb64fed7 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -141,6 +141,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) # Start a prediction and get the prediction URL version_id = replicate_config.model_to_version_id(model) diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py index 604e6eefe6..4c61086801 100644 --- a/litellm/llms/replicate/chat/transformation.py +++ b/litellm/llms/replicate/chat/transformation.py @@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py index 296689c31c..ebd96ac5b1 100644 --- a/litellm/llms/sagemaker/completion/handler.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM): model: str, data: dict, messages: List[AllMessageValues], + litellm_params: dict, optional_params: dict, aws_region_name: str, extra_headers: Optional[dict] = None, @@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) request = AWSRequest( method="POST", url=api_base, data=encoded_data, headers=headers @@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM): data=data, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, credentials=credentials, aws_region_name=aws_region_name, ) @@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": _data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, @@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, @@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index df3d028c99..bfc0b6e5f6 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = 
None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 574c4704cd..2b92911b05 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py index 4d14f1ad24..afbd89b9bc 100644 --- a/litellm/llms/topaz/image_variations/transformation.py +++ b/litellm/llms/topaz/image_variations/transformation.py @@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 49126917f2..21fcf2eefb 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -48,6 +48,7 @@ class TritonConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/triton/embedding/transformation.py b/litellm/llms/triton/embedding/transformation.py index 4744ec0834..8ab0277e36 100644 --- a/litellm/llms/triton/embedding/transformation.py +++ b/litellm/llms/triton/embedding/transformation.py @@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py index 87c1cb8320..a666a2c37f 100644 --- a/litellm/llms/vertex_ai/files/handler.py +++ b/litellm/llms/vertex_ai/files/handler.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Coroutine, Optional, Union import httpx @@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES -from .transformation import VertexAIFilesTransformation +from .transformation import VertexAIJsonlFilesTransformation -vertex_ai_files_transformation = VertexAIFilesTransformation() +vertex_ai_files_transformation = VertexAIJsonlFilesTransformation() class VertexAIFilesHandler(GCSBucketBase): @@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase): timeout=timeout, max_retries=max_retries, ) - - return None # type: ignore + else: + return asyncio.run( + self.async_create_file( + create_file_data=create_file_data, + api_base=api_base, + vertex_credentials=vertex_credentials, + vertex_project=vertex_project, + vertex_location=vertex_location, + timeout=timeout, + max_retries=max_retries, + ) + ) diff --git a/litellm/llms/vertex_ai/files/transformation.py b/litellm/llms/vertex_ai/files/transformation.py index 89c6ff9deb..c795367e48 100644 --- a/litellm/llms/vertex_ai/files/transformation.py +++ b/litellm/llms/vertex_ai/files/transformation.py @@ -1,7 +1,17 @@ import json 
+import os +import time import uuid from typing import Any, Dict, List, Optional, Tuple, Union +from httpx import Headers, Response + +from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.files.transformation import ( + BaseFilesConfig, + LiteLLMLoggingObj, +) from litellm.llms.vertex_ai.common_utils import ( _convert_vertex_datetime_to_openai_datetime, ) @@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, ) from litellm.types.llms.openai import ( + AllMessageValues, CreateFileRequest, FileTypes, + OpenAICreateFileRequestOptionalParams, OpenAIFileObject, PathLike, ) +from litellm.types.llms.vertex_ai import GcsBucketResponse +from litellm.types.utils import ExtractedFileData, LlmProviders + +from ..common_utils import VertexAIError +from ..vertex_llm_base import VertexBase -class VertexAIFilesTransformation(VertexGeminiConfig): +class VertexAIFilesConfig(VertexBase, BaseFilesConfig): + """ + Config for VertexAI Files + """ + + def __init__(self): + self.jsonl_transformation = VertexAIJsonlFilesTransformation() + super().__init__() + + @property + def custom_llm_provider(self) -> LlmProviders: + return LlmProviders.VERTEX_AI + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if not api_key: + api_key, _ = self.get_access_token( + credentials=litellm_params.get("vertex_credentials"), + project_id=litellm_params.get("vertex_project"), + ) + if not api_key: + raise ValueError("api_key is required") + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str: + """ + Helper to extract content from various OpenAI file types and return as string. 
+ + Handles: + - Direct content (str, bytes, IO[bytes]) + - Tuple formats: (filename, content, [content_type], [headers]) + - PathLike objects + """ + content: Union[str, bytes] = b"" + # Extract file content from tuple if necessary + if isinstance(openai_file_content, tuple): + # Take the second element which is always the file content + file_content = openai_file_content[1] + else: + file_content = openai_file_content + + # Handle different file content types + if isinstance(file_content, str): + # String content can be used directly + content = file_content + elif isinstance(file_content, bytes): + # Bytes content can be decoded + content = file_content + elif isinstance(file_content, PathLike): # PathLike + with open(str(file_content), "rb") as f: + content = f.read() + elif hasattr(file_content, "read"): # IO[bytes] + # File-like objects need to be read + content = file_content.read() + + # Ensure content is string + if isinstance(content, bytes): + content = content.decode("utf-8") + + return content + + def _get_gcs_object_name_from_batch_jsonl( + self, + openai_jsonl_content: List[Dict[str, Any]], + ) -> str: + """ + Gets a unique GCS object name for the VertexAI batch prediction job + + named as: litellm-vertex-{model}-{uuid} + """ + _model = openai_jsonl_content[0].get("body", {}).get("model", "") + if "publishers/google/models" not in _model: + _model = f"publishers/google/models/{_model}" + object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}" + return object_name + + def get_object_name( + self, extracted_file_data: ExtractedFileData, purpose: str + ) -> str: + """ + Get the object name for the request + """ + extracted_file_data_content = extracted_file_data.get("content") + + if extracted_file_data_content is None: + raise ValueError("file content is required") + + if purpose == "batch": + ## 1. If jsonl, check if there's a model name + file_content = self._get_content_from_openai_file( + extracted_file_data_content + ) + + # Split into lines and parse each line as JSON + openai_jsonl_content = [ + json.loads(line) for line in file_content.splitlines() if line.strip() + ] + if len(openai_jsonl_content) > 0: + return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content) + + ## 2. If not jsonl, return the filename + filename = extracted_file_data.get("filename") + if filename: + return filename + ## 3. 
If no file name, return timestamp + return str(int(time.time())) + + def get_complete_file_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: Dict, + litellm_params: Dict, + data: CreateFileRequest, + ) -> str: + """ + Get the complete url for the request + """ + bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME") + if not bucket_name: + raise ValueError("GCS bucket_name is required") + file_data = data.get("file") + purpose = data.get("purpose") + if file_data is None: + raise ValueError("file is required") + if purpose is None: + raise ValueError("purpose is required") + extracted_file_data = extract_file_data(file_data) + object_name = self.get_object_name(extracted_file_data, purpose) + endpoint = ( + f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}" + ) + api_base = api_base or "https://storage.googleapis.com" + if not api_base: + raise ValueError("api_base is required") + + return f"{api_base}/{endpoint}" + + def get_supported_openai_params( + self, model: str + ) -> List[OpenAICreateFileRequestOptionalParams]: + return [] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + return optional_params + + def _map_openai_to_vertex_params( + self, + openai_request_body: Dict[str, Any], + ) -> Dict[str, Any]: + """ + wrapper to call VertexGeminiConfig.map_openai_params + """ + from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + VertexGeminiConfig, + ) + + config = VertexGeminiConfig() + _model = openai_request_body.get("model", "") + vertex_params = config.map_openai_params( + model=_model, + non_default_params=openai_request_body, + optional_params={}, + drop_params=False, + ) + return vertex_params + + def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content( + self, openai_jsonl_content: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Transforms OpenAI JSONL content to VertexAI JSONL content + + jsonl body for vertex is {"request": } + Example Vertex jsonl + {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}} + {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}} + """ + + vertex_jsonl_content = [] + for _openai_jsonl_content in openai_jsonl_content: + openai_request_body = _openai_jsonl_content.get("body") or {} + vertex_request_body = _transform_request_body( + messages=openai_request_body.get("messages", []), + model=openai_request_body.get("model", ""), + optional_params=self._map_openai_to_vertex_params(openai_request_body), + custom_llm_provider="vertex_ai", + litellm_params={}, + cached_content=None, + ) + vertex_jsonl_content.append({"request": vertex_request_body}) + return vertex_jsonl_content + + def transform_create_file_request( + self, + model: str, + create_file_data: CreateFileRequest, + optional_params: dict, + litellm_params: dict, + ) -> Union[bytes, str, dict]: + """ + 2 Cases: + 1. Handle basic file upload + 2. 
Handle batch file upload (.jsonl) + """ + file_data = create_file_data.get("file") + if file_data is None: + raise ValueError("file is required") + extracted_file_data = extract_file_data(file_data) + extracted_file_data_content = extracted_file_data.get("content") + if ( + create_file_data.get("purpose") == "batch" + and extracted_file_data.get("content_type") == "application/jsonl" + and extracted_file_data_content is not None + ): + ## 1. If jsonl, check if there's a model name + file_content = self._get_content_from_openai_file( + extracted_file_data_content + ) + + # Split into lines and parse each line as JSON + openai_jsonl_content = [ + json.loads(line) for line in file_content.splitlines() if line.strip() + ] + vertex_jsonl_content = ( + self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content( + openai_jsonl_content + ) + ) + return json.dumps(vertex_jsonl_content) + elif isinstance(extracted_file_data_content, bytes): + return extracted_file_data_content + else: + raise ValueError("Unsupported file content type") + + def transform_create_file_response( + self, + model: Optional[str], + raw_response: Response, + logging_obj: LiteLLMLoggingObj, + litellm_params: dict, + ) -> OpenAIFileObject: + """ + Transform VertexAI File upload response into OpenAI-style FileObject + """ + response_json = raw_response.json() + + try: + response_object = GcsBucketResponse(**response_json) # type: ignore + except Exception as e: + raise VertexAIError( + status_code=raw_response.status_code, + message=f"Error reading GCS bucket response: {e}", + headers=raw_response.headers, + ) + + gcs_id = response_object.get("id", "") + # Remove the last numeric ID from the path + gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else "" + + return OpenAIFileObject( + purpose=response_object.get("purpose", "batch"), + id=f"gs://{gcs_id}", + filename=response_object.get("name", ""), + created_at=_convert_vertex_datetime_to_openai_datetime( + vertex_datetime=response_object.get("timeCreated", "") + ), + status="uploaded", + bytes=int(response_object.get("size", 0)), + object="file", + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, Headers] + ) -> BaseLLMException: + return VertexAIError( + status_code=status_code, message=error_message, headers=headers + ) + + +class VertexAIJsonlFilesTransformation(VertexGeminiConfig): """ Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests """ diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index e7d3d2b060..b3b7857ea1 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -905,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: Dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: @@ -1022,7 +1023,7 @@ class VertexLLM(VertexBase): logging_obj, stream, optional_params: dict, - litellm_params=None, + litellm_params: dict, logger_fn=None, api_base: Optional[str] = None, client: Optional[AsyncHTTPHandler] = None, @@ -1063,6 +1064,7 @@ class VertexLLM(VertexBase): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) ## LOGGING @@ -1149,6 +1151,7 @@ class VertexLLM(VertexBase): model=model, messages=messages, 
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )

         request_body = await async_transform_request_body(**data)  # type: ignore
@@ -1322,6 +1325,7 @@ class VertexLLM(VertexBase):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )

         ## TRANSFORMATION ##
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
index 88d7339449..8aebd83cc4 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
             optional_params=optional_params,
             api_key=auth_header,
             api_base=api_base,
+            litellm_params=litellm_params,
         )

         ## LOGGING
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
index afa58c7e5c..5bf02ad765 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
@@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py
index 994e46b50b..8f3037c791 100644
--- a/litellm/llms/vertex_ai/vertex_llm_base.py
+++ b/litellm/llms/vertex_ai/vertex_llm_base.py
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple

 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
-from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES

@@ -22,7 +21,7 @@ else:
     GoogleCredentialsObject = Any


-class VertexBase(BaseLLM):
+class VertexBase:
     def __init__(self) -> None:
         super().__init__()
         self.access_token: Optional[str] = None
diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py
index df6ef91a41..91811e0392 100644
--- a/litellm/llms/voyage/embedding/transformation.py
+++ b/litellm/llms/voyage/embedding/transformation.py
@@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
diff --git a/litellm/llms/watsonx/chat/handler.py b/litellm/llms/watsonx/chat/handler.py
index aeb0167595..45378c5529 100644
--- a/litellm/llms/watsonx/chat/handler.py
+++ b/litellm/llms/watsonx/chat/handler.py
@@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
             messages=messages,
             optional_params=optional_params,
             api_key=api_key,
+            litellm_params=litellm_params,
         )

         ## UPDATE PAYLOAD (optional params)
diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index 4916cd1c75..e13e015add 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -165,6 +165,7 @@ class IBMWatsonXMixin:
         model: str,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:
diff --git a/litellm/main.py b/litellm/main.py
index cd7d255e21..3f1d9a1e76 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3616,6 +3616,7 @@ def embedding(  # noqa: PLR0915
                 optional_params=optional_params,
                 client=client,
                 aembedding=aembedding,
+                litellm_params=litellm_params_dict,
             )
     elif custom_llm_provider == "bedrock":
         if isinstance(input, str):
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 2e25f259b0..55273371fc 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -498,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
     gcsDestination: GcsDestination


+class GcsBucketResponse(TypedDict):
+    """
+    TypedDict for GCS bucket upload response
+
+    Attributes:
+        kind: The kind of item this is. For objects, this is always storage#object
+        id: The ID of the object
+        selfLink: The link to this object
+        mediaLink: The link to download the object
+        name: The name of the object
+        bucket: The name of the bucket containing this object
+        generation: The content generation of this object
+        metageneration: The metadata generation of this object
+        contentType: The content type of the object
+        storageClass: The storage class of the object
+        size: The size of the object in bytes
+        md5Hash: The MD5 hash of the object
+        crc32c: The CRC32c checksum of the object
+        etag: The ETag of the object
+        timeCreated: The creation time of the object
+        updated: The last update time of the object
+        timeStorageClassUpdated: The time the storage class was last updated
+        timeFinalized: The time the object was finalized
+    """
+
+    kind: Literal["storage#object"]
+    id: str
+    selfLink: str
+    mediaLink: str
+    name: str
+    bucket: str
+    generation: str
+    metageneration: str
+    contentType: str
+    storageClass: str
+    size: str
+    md5Hash: str
+    crc32c: str
+    etag: str
+    timeCreated: str
+    updated: str
+    timeStorageClassUpdated: str
+    timeFinalized: str
+
+
 class VertexAIBatchPredictionJob(TypedDict):
     displayName: str
     model: str
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 8439037758..6f0c26d301 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -2,7 +2,7 @@ import json
 import time
 import uuid
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union

 from aiohttp import FormData
 from openai._models import BaseModel as OpenAIObject
@@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
         if not values.get("credential_values") and not values.get("model_id"):
             raise ValueError("Either credential_values or model_id must be set")
         return values
+
+
+class ExtractedFileData(TypedDict):
+    """
+    TypedDict for storing processed file data
+
+    Attributes:
+        filename: Name of the file if provided
+        content: The file content in bytes
+        content_type: MIME type of the file
+        headers: Any additional headers for the file
+    """
+
+    filename: Optional[str]
+    content: bytes
+    content_type: Optional[str]
+    headers: Mapping[str, str]
diff --git a/litellm/utils.py b/litellm/utils.py
index f807990f60..f809d8a77b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6517,6 +6517,10 @@ class ProviderConfigManager:
             )

             return GoogleAIStudioFilesHandler()
+        elif LlmProviders.VERTEX_AI == provider:
+            from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
+
+            return VertexAIFilesConfig()

         return None
diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py
index 4669a2def6..b2826419e8 100644
--- a/tests/batches_tests/test_openai_batches_and_files.py
+++ b/tests/batches_tests/test_openai_batches_and_files.py
@@ -423,25 +423,35 @@ mock_vertex_batch_response = {


 @pytest.mark.asyncio
-async def test_avertex_batch_prediction():
-    with patch(
+async def test_avertex_batch_prediction(monkeypatch):
+    monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    client = AsyncHTTPHandler()
+
+    async def mock_side_effect(*args, **kwargs):
+        print("args", args, "kwargs", kwargs)
+        url = kwargs.get("url", "")
+        if "files" in url:
+            mock_response.json.return_value = mock_file_response
+        elif "batch" in url:
+            mock_response.json.return_value = mock_vertex_batch_response
+        mock_response.status_code = 200
+        return mock_response
+
+    with patch.object(
+        client, "post", side_effect=mock_side_effect
+    ) as mock_post, patch(
         "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
-    ) as mock_post:
+    ) as mock_global_post:
         # Configure mock responses
         mock_response = MagicMock()
         mock_response.raise_for_status.return_value = None

         # Set up different responses for different API calls
-        async def mock_side_effect(*args, **kwargs):
-            url = kwargs.get("url", "")
-            if "files" in url:
-                mock_response.json.return_value = mock_file_response
-            elif "batch" in url:
-                mock_response.json.return_value = mock_vertex_batch_response
-            mock_response.status_code = 200
-            return mock_response
-
+        mock_post.side_effect = mock_side_effect
+        mock_global_post.side_effect = mock_side_effect
         # load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
             file=open(file_path, "rb"),
             purpose="batch",
             custom_llm_provider="vertex_ai",
+            client=client
         )
         print("Response from creating file=", file_obj)
diff --git a/tests/llm_translation/test_huggingface_chat_completion.py b/tests/llm_translation/test_huggingface_chat_completion.py
index 9f1e89aeb1..7d498b96df 100644
--- a/tests/llm_translation/test_huggingface_chat_completion.py
+++ b/tests/llm_translation/test_huggingface_chat_completion.py
@@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
             model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
             messages=[{"role": "user", "content": "Hello"}],
             optional_params={},
-            api_key="test_api_key"
+            api_key="test_api_key",
+            litellm_params={}
         )

         assert headers["Authorization"] == "Bearer test_api_key"
diff --git a/tests/local_testing/example.jsonl b/tests/local_testing/example.jsonl
new file mode 100644
index 0000000000..fc3ca40808
--- /dev/null
+++ b/tests/local_testing/example.jsonl
@@ -0,0 +1,2 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py
index 1a8deed8a8..b64475c227 100644
--- a/tests/local_testing/test_gcs_bucket.py
+++ b/tests/local_testing/test_gcs_bucket.py
@@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
     StandardLoggingPayload,
 )
 from litellm.types.utils import StandardCallbackDynamicParams
-
+from unittest.mock import patch

 verbose_logger.setLevel(logging.DEBUG)
@@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
     # clean up
     if old_bucket_name is not None:
         os.environ["GCS_BUCKET_NAME"] = old_bucket_name
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e():
+    """
+    Asserts 'create_file' is called with the correct arguments
+    """
+    load_vertex_ai_credentials()
+    test_file_content = b"test audio content"
+    test_file = ("test.wav", test_file_content, "audio/wav")
+
+    from litellm import create_file
+    response = create_file(
+        file=test_file,
+        purpose="user_data",
+        custom_llm_provider="vertex_ai",
+    )
+    print("response", response)
+    assert response is not None
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e_jsonl():
+    """
+    Asserts 'create_file' is called with the correct arguments
+    """
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]
+
+    # Create and write to the file
+    file_path = "example.jsonl"
+    with open(file_path, "w") as f:
+        for item in example_jsonl:
+            f.write(json.dumps(item) + "\n")
+
+    # Verify file content
+    with open(file_path, "r") as f:
+        content = f.read()
+        print("File content:", content)
+        assert len(content) > 0, "File is empty"
+
+    from litellm import create_file
+    with patch.object(client, "post") as mock_create_file:
+        try:
+            response = create_file(
+                file=open(file_path, "rb"),
+                purpose="user_data",
+                custom_llm_provider="vertex_ai",
+                client=client,
+            )
+        except Exception as e:
+            print("error", e)
+
+        mock_create_file.assert_called_once()
+
+        print(f"kwargs: {mock_create_file.call_args.kwargs}")
+
+        assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0
\ No newline at end of file