diff --git a/.circleci/config.yml b/.circleci/config.yml
index b1e08084fa..14a22a5995 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -610,6 +610,8 @@ jobs:
          name: Install Dependencies
          command: |
            python -m pip install --upgrade pip
+           pip install wheel
+           pip install --upgrade pip wheel setuptools
            python -m pip install -r requirements.txt
            pip install "pytest==7.3.1"
            pip install "respx==0.21.1"
diff --git a/deploy/charts/litellm-helm/templates/service.yaml b/deploy/charts/litellm-helm/templates/service.yaml
index 40e7f27f16..d8d81e78c8 100644
--- a/deploy/charts/litellm-helm/templates/service.yaml
+++ b/deploy/charts/litellm-helm/templates/service.yaml
@@ -2,6 +2,10 @@
 apiVersion: v1
 kind: Service
 metadata:
   name: {{ include "litellm.fullname" . }}
+  {{- with .Values.service.annotations }}
+  annotations:
+    {{- toYaml . | nindent 4 }}
+  {{- end }}
   labels:
     {{- include "litellm.labels" . | nindent 4 }}
 spec:
diff --git a/docs/my-website/docs/providers/gemini.md b/docs/my-website/docs/providers/gemini.md
index 42783286f1..db63d33d8d 100644
--- a/docs/my-website/docs/providers/gemini.md
+++ b/docs/my-website/docs/providers/gemini.md
@@ -438,6 +438,179 @@ assert isinstance(
 ```
 
+### Google Search Tool
+
+
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+
+
+1. Setup config.yaml
+```yaml
+model_list:
+  - model_name: gemini-2.0-flash
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "gemini-2.0-flash",
+  "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+  "tools": [{"googleSearch": {}}]
+}
+'
+```
+
+
+
+### Google Search Retrieval
+
+
+
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH RETRIEVAL
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+
+
+1. Setup config.yaml
+```yaml
+model_list:
+  - model_name: gemini-2.0-flash
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "gemini-2.0-flash",
+  "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+  "tools": [{"googleSearchRetrieval": {}}]
+}
+'
+```
+
+
+
+
+### Code Execution Tool
+
+
+
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"codeExecution": {}}] # 👈 ADD CODE EXECUTION
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+
+
+1. 
Setup config.yaml +```yaml +model_list: + - model_name: gemini-2.0-flash + litellm_params: + model: gemini/gemini-2.0-flash + api_key: os.environ/GEMINI_API_KEY +``` + +2. Start Proxy +```bash +$ litellm --config /path/to/config.yaml +``` + +3. Make Request! +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "gemini-2.0-flash", + "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}], + "tools": [{"codeExecution": {}}] +} +' +``` + + + + + + + + + ## JSON Mode diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index ab13a51137..cdd3fce6c6 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \ +You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise). + #### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)** diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index 4f2f43f0de..f40f1ae4c7 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -110,5 +110,8 @@ def get_litellm_params( "azure_password": kwargs.get("azure_password"), "max_retries": max_retries, "timeout": kwargs.get("timeout"), + "bucket_name": kwargs.get("bucket_name"), + "vertex_credentials": kwargs.get("vertex_credentials"), + "vertex_project": kwargs.get("vertex_project"), } return litellm_params diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 0f2d0da388..44b680d487 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -2,7 +2,10 @@ Common utility functions used for translating messages across providers """ -from typing import Dict, List, Literal, Optional, Union, cast +import io +import mimetypes +from os import PathLike +from typing import Dict, List, Literal, Mapping, Optional, Union, cast from litellm.types.llms.openai import ( AllMessageValues, @@ -10,7 +13,13 @@ from litellm.types.llms.openai import ( ChatCompletionFileObject, ChatCompletionUserMessage, ) -from litellm.types.utils import Choices, ModelResponse, StreamingChoices +from litellm.types.utils import ( + Choices, + ExtractedFileData, + FileTypes, + ModelResponse, + StreamingChoices, +) DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage( content="Please continue.", role="user" @@ -350,6 +359,68 @@ def update_messages_with_model_file_ids( return messages +def extract_file_data(file_data: FileTypes) -> ExtractedFileData: + """ + Extracts and processes file data from various input formats. 
+ + Args: + file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content + + Returns: + ExtractedFileData containing: + - filename: Name of the file if provided + - content: The file content in bytes + - content_type: MIME type of the file + - headers: Any additional headers + """ + # Parse the file_data based on its type + filename = None + file_content = None + content_type = None + file_headers: Mapping[str, str] = {} + + if isinstance(file_data, tuple): + if len(file_data) == 2: + filename, file_content = file_data + elif len(file_data) == 3: + filename, file_content, content_type = file_data + elif len(file_data) == 4: + filename, file_content, content_type, file_headers = file_data + else: + file_content = file_data + # Convert content to bytes + if isinstance(file_content, (str, PathLike)): + # If it's a path, open and read the file + with open(file_content, "rb") as f: + content = f.read() + elif isinstance(file_content, io.IOBase): + # If it's a file-like object + content = file_content.read() + + if isinstance(content, str): + content = content.encode("utf-8") + # Reset file pointer to beginning + file_content.seek(0) + elif isinstance(file_content, bytes): + content = file_content + else: + raise ValueError(f"Unsupported file content type: {type(file_content)}") + + # Use provided content type or guess based on filename + if not content_type: + content_type = ( + mimetypes.guess_type(filename)[0] + if filename + else "application/octet-stream" + ) + + return ExtractedFileData( + filename=filename, + content=content, + content_type=content_type, + headers=file_headers, + ) + def unpack_defs(schema, defs): properties = schema.get("properties", None) if properties is None: @@ -381,3 +452,4 @@ def unpack_defs(schema, defs): unpack_defs(ref, defs) value["items"] = ref continue + diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py index af073fe8e3..c2d4e5adcd 100644 --- a/litellm/llms/aiohttp_openai/chat/transformation.py +++ b/litellm/llms/aiohttp_openai/chat/transformation.py @@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index f2a5542dcd..44567facf9 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM): model=model, messages=messages, optional_params={**optional_params, "is_vertex_request": is_vertex_request}, + litellm_params=litellm_params, ) config = ProviderConfigManager.get_provider_chat_config( diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 8a2048f95a..9b66249630 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index e4e04df4d6..9e3287aa8a 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ 
b/litellm/llms/anthropic/completion/transformation.py @@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py index e30d68f97d..ea61ef2c9a 100644 --- a/litellm/llms/azure/chat/gpt_transformation.py +++ b/litellm/llms/azure/chat/gpt_transformation.py @@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 007a4303c8..839f875f75 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 5279a44201..fa278c805e 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -262,6 +262,7 @@ class BaseConfig(ABC): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/base_llm/files/transformation.py b/litellm/llms/base_llm/files/transformation.py index 0f1f46352f..9925004c89 100644 --- a/litellm/llms/base_llm/files/transformation.py +++ b/litellm/llms/base_llm/files/transformation.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import httpx @@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig): ) -> List[OpenAICreateFileRequestOptionalParams]: pass - def get_complete_url( + def get_complete_file_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, - stream: Optional[bool] = None, - ) -> str: - """ - OPTIONAL - - Get the complete url for the request - - Some providers need `model` in `api_base` - """ - return api_base or "" + data: CreateFileRequest, + ): + return self.get_complete_url( + api_base=api_base, + api_key=api_key, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + ) @abstractmethod def transform_create_file_request( @@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig): create_file_data: CreateFileRequest, optional_params: dict, litellm_params: dict, - ) -> dict: + ) -> Union[dict, str, bytes]: pass @abstractmethod diff --git a/litellm/llms/base_llm/image_variations/transformation.py b/litellm/llms/base_llm/image_variations/transformation.py index 3ed446a84e..60444d0fb7 100644 --- a/litellm/llms/base_llm/image_variations/transformation.py +++ b/litellm/llms/base_llm/image_variations/transformation.py @@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git 
a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 8ce2c4818b..fbe2dc4937 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index cb12f779cc..67194e83e7 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py index 916da73883..73be89fc6e 100644 --- a/litellm/llms/clarifai/chat/transformation.py +++ b/litellm/llms/clarifai/chat/transformation.py @@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py index 1874bb5115..9e59782bf7 100644 --- a/litellm/llms/cloudflare/chat/transformation.py +++ b/litellm/llms/cloudflare/chat/transformation.py @@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py index 70677214a7..5dd44aca80 100644 --- a/litellm/llms/cohere/chat/transformation.py +++ b/litellm/llms/cohere/chat/transformation.py @@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py index bdfcda020e..f96ef89d3c 100644 --- a/litellm/llms/cohere/completion/transformation.py +++ b/litellm/llms/cohere/completion/transformation.py @@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py index 72092cf261..13141fc19a 100644 --- a/litellm/llms/custom_httpx/aiohttp_handler.py +++ b/litellm/llms/custom_httpx/aiohttp_handler.py @@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler: model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, api_base=api_base, ) @@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler: model=model, messages=[{"role": "user", "content": "test"}], optional_params=optional_params, + 
litellm_params=litellm_params, api_base=api_base, ) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 23d7fe4b4d..f1aa5627dc 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -192,7 +192,7 @@ class AsyncHTTPHandler: async def post( self, url: str, - data: Optional[Union[dict, str]] = None, # type: ignore + data: Optional[Union[dict, str, bytes]] = None, # type: ignore json: Optional[dict] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -427,7 +427,7 @@ class AsyncHTTPHandler: self, url: str, client: httpx.AsyncClient, - data: Optional[Union[dict, str]] = None, # type: ignore + data: Optional[Union[dict, str, bytes]] = None, # type: ignore json: Optional[dict] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -527,7 +527,7 @@ class HTTPHandler: def post( self, url: str, - data: Optional[Union[dict, str]] = None, + data: Optional[Union[dict, str, bytes]] = None, json: Optional[Union[dict, str, List]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, @@ -573,7 +573,6 @@ class HTTPHandler: setattr(e, "text", error_text) setattr(e, "status_code", e.response.status_code) - raise e except Exception as e: raise e diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 5778f0228f..b7c72e89ef 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -247,6 +247,7 @@ class BaseLLMHTTPHandler: messages=messages, optional_params=optional_params, api_base=api_base, + litellm_params=litellm_params, ) api_base = provider_config.get_complete_url( @@ -625,6 +626,7 @@ class BaseLLMHTTPHandler: model=model, messages=[], optional_params=optional_params, + litellm_params=litellm_params, ) api_base = provider_config.get_complete_url( @@ -896,6 +898,7 @@ class BaseLLMHTTPHandler: model=model, messages=[], optional_params=optional_params, + litellm_params=litellm_params, ) if client is None or not isinstance(client, HTTPHandler): @@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler: model="", messages=[], optional_params={}, + litellm_params=litellm_params, ) - api_base = provider_config.get_complete_url( + api_base = provider_config.get_complete_file_url( api_base=api_base, api_key=api_key, model="", optional_params={}, litellm_params=litellm_params, + data=create_file_data, ) + if api_base is None: + raise ValueError("api_base is required for create_file") # Get the transformed request data for both steps transformed_request = provider_config.transform_create_file_request( @@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler: else: sync_httpx_client = client - try: - # Step 1: Initial request to get upload URL - initial_response = sync_httpx_client.post( - url=api_base, - headers={ - **headers, - **transformed_request["initial_request"]["headers"], - }, - data=json.dumps(transformed_request["initial_request"]["data"]), - timeout=timeout, - ) - - # Extract upload URL from response headers - upload_url = initial_response.headers.get("X-Goog-Upload-URL") - - if not upload_url: - raise ValueError("Failed to get upload URL from initial request") - - # Step 2: Upload the actual file + if isinstance(transformed_request, str) or isinstance( + transformed_request, bytes + ): upload_response = sync_httpx_client.post( - url=upload_url, - headers=transformed_request["upload_request"]["headers"], - 
data=transformed_request["upload_request"]["data"], + url=api_base, + headers=headers, + data=transformed_request, timeout=timeout, ) + else: + try: + # Step 1: Initial request to get upload URL + initial_response = sync_httpx_client.post( + url=api_base, + headers={ + **headers, + **transformed_request["initial_request"]["headers"], + }, + data=json.dumps(transformed_request["initial_request"]["data"]), + timeout=timeout, + ) - return provider_config.transform_create_file_response( - model=None, - raw_response=upload_response, - logging_obj=logging_obj, - litellm_params=litellm_params, - ) + # Extract upload URL from response headers + upload_url = initial_response.headers.get("X-Goog-Upload-URL") - except Exception as e: - raise self._handle_error( - e=e, - provider_config=provider_config, - ) + if not upload_url: + raise ValueError("Failed to get upload URL from initial request") + + # Step 2: Upload the actual file + upload_response = sync_httpx_client.post( + url=upload_url, + headers=transformed_request["upload_request"]["headers"], + data=transformed_request["upload_request"]["data"], + timeout=timeout, + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_create_file_response( + model=None, + raw_response=upload_response, + logging_obj=logging_obj, + litellm_params=litellm_params, + ) async def async_create_file( self, - transformed_request: dict, + transformed_request: Union[bytes, str, dict], litellm_params: dict, provider_config: BaseFilesConfig, headers: dict, @@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler: else: async_httpx_client = client - try: - # Step 1: Initial request to get upload URL - initial_response = await async_httpx_client.post( - url=api_base, - headers={ - **headers, - **transformed_request["initial_request"]["headers"], - }, - data=json.dumps(transformed_request["initial_request"]["data"]), - timeout=timeout, - ) - - # Extract upload URL from response headers - upload_url = initial_response.headers.get("X-Goog-Upload-URL") - - if not upload_url: - raise ValueError("Failed to get upload URL from initial request") - - # Step 2: Upload the actual file + if isinstance(transformed_request, str) or isinstance( + transformed_request, bytes + ): upload_response = await async_httpx_client.post( - url=upload_url, - headers=transformed_request["upload_request"]["headers"], - data=transformed_request["upload_request"]["data"], + url=api_base, + headers=headers, + data=transformed_request, timeout=timeout, ) + else: + try: + # Step 1: Initial request to get upload URL + initial_response = await async_httpx_client.post( + url=api_base, + headers={ + **headers, + **transformed_request["initial_request"]["headers"], + }, + data=json.dumps(transformed_request["initial_request"]["data"]), + timeout=timeout, + ) - return provider_config.transform_create_file_response( - model=None, - raw_response=upload_response, - logging_obj=logging_obj, - litellm_params=litellm_params, - ) + # Extract upload URL from response headers + upload_url = initial_response.headers.get("X-Goog-Upload-URL") - except Exception as e: - verbose_logger.exception(f"Error creating file: {e}") - raise self._handle_error( - e=e, - provider_config=provider_config, - ) + if not upload_url: + raise ValueError("Failed to get upload URL from initial request") + + # Step 2: Upload the actual file + upload_response = await async_httpx_client.post( + url=upload_url, + headers=transformed_request["upload_request"]["headers"], + 
data=transformed_request["upload_request"]["data"], + timeout=timeout, + ) + except Exception as e: + verbose_logger.exception(f"Error creating file: {e}") + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_create_file_response( + model=None, + raw_response=upload_response, + logging_obj=logging_obj, + litellm_params=litellm_params, + ) def list_files(self): """ diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py index 1940f09608..6f5738fb4b 100644 --- a/litellm/llms/databricks/chat/transformation.py +++ b/litellm/llms/databricks/chat/transformation.py @@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py index b4803576e0..f1b18808f7 100644 --- a/litellm/llms/deepgram/audio_transcription/transformation.py +++ b/litellm/llms/deepgram/audio_transcription/transformation.py @@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/fireworks_ai/common_utils.py b/litellm/llms/fireworks_ai/common_utils.py index 293403b133..17aa67b525 100644 --- a/litellm/llms/fireworks_ai/common_utils.py +++ b/litellm/llms/fireworks_ai/common_utils.py @@ -41,6 +41,7 @@ class FireworksAIMixin: model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py index ace24e982f..fef41f7d58 100644 --- a/litellm/llms/gemini/common_utils.py +++ b/litellm/llms/gemini/common_utils.py @@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/gemini/files/transformation.py b/litellm/llms/gemini/files/transformation.py index a1f99c6903..e98e76dabc 100644 --- a/litellm/llms/gemini/files/transformation.py +++ b/litellm/llms/gemini/files/transformation.py @@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API. For vertex ai, check out the vertex_ai/files/handler.py file. 
""" import time -from typing import List, Mapping, Optional +from typing import List, Optional import httpx from litellm._logging import verbose_logger +from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data from litellm.llms.base_llm.files.transformation import ( BaseFilesConfig, LiteLLMLoggingObj, @@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig): if file_data is None: raise ValueError("File data is required") - # Parse the file_data based on its type - filename = None - file_content = None - content_type = None - file_headers: Mapping[str, str] = {} - - if isinstance(file_data, tuple): - if len(file_data) == 2: - filename, file_content = file_data - elif len(file_data) == 3: - filename, file_content, content_type = file_data - elif len(file_data) == 4: - filename, file_content, content_type, file_headers = file_data - else: - file_content = file_data - - # Handle the file content based on its type - import io - from os import PathLike - - # Convert content to bytes - if isinstance(file_content, (str, PathLike)): - # If it's a path, open and read the file - with open(file_content, "rb") as f: - content = f.read() - elif isinstance(file_content, io.IOBase): - # If it's a file-like object - content = file_content.read() - if isinstance(content, str): - content = content.encode("utf-8") - elif isinstance(file_content, bytes): - content = file_content - else: - raise ValueError(f"Unsupported file content type: {type(file_content)}") + # Use the common utility function to extract file data + extracted_data = extract_file_data(file_data) # Get file size - file_size = len(content) - - # Use provided content type or guess based on filename - if not content_type: - import mimetypes - - content_type = ( - mimetypes.guess_type(filename)[0] - if filename - else "application/octet-stream" - ) + file_size = len(extracted_data["content"]) # Step 1: Initial resumable upload request headers = { "X-Goog-Upload-Protocol": "resumable", "X-Goog-Upload-Command": "start", "X-Goog-Upload-Header-Content-Length": str(file_size), - "X-Goog-Upload-Header-Content-Type": content_type, + "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"], "Content-Type": "application/json", } - headers.update(file_headers) # Add any custom headers + headers.update(extracted_data["headers"]) # Add any custom headers # Initial metadata request body - initial_data = {"file": {"display_name": filename or str(int(time.time()))}} + initial_data = { + "file": { + "display_name": extracted_data["filename"] or str(int(time.time())) + } + } # Step 2: Actual file upload data upload_headers = { @@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig): return { "initial_request": {"headers": headers, "data": initial_data}, - "upload_request": {"headers": upload_headers, "data": content}, + "upload_request": { + "headers": upload_headers, + "data": extracted_data["content"], + }, } def transform_create_file_response( diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py index c84f03ab93..0ad93be763 100644 --- a/litellm/llms/huggingface/chat/transformation.py +++ b/litellm/llms/huggingface/chat/transformation.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx @@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import 
BaseLLMException from ...openai.chat.gpt_transformation import OpenAIGPTConfig from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping - logger = logging.getLogger(__name__) BASE_URL = "https://router.huggingface.co" @@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): headers: dict, model: str, messages: List[AllMessageValues], - optional_params: dict, + optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: @@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: - return HuggingFaceError(status_code=status_code, message=error_message, headers=headers) + return HuggingFaceError( + status_code=status_code, message=error_message, headers=headers + ) def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]: """ @@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): if api_base is not None: complete_url = api_base elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"): - complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE")) + complete_url = str(os.getenv("HF_API_BASE")) or str( + os.getenv("HUGGINGFACE_API_BASE") + ) elif model.startswith(("http://", "https://")): complete_url = model # 4. Default construction with provider @@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig): ) mapped_model = provider_mapping["providerId"] messages = self._transform_messages(messages=messages, model=mapped_model) - return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params)) + return dict( + ChatCompletionRequest( + model=mapped_model, messages=messages, **optional_params + ) + ) diff --git a/litellm/llms/huggingface/embedding/handler.py b/litellm/llms/huggingface/embedding/handler.py index 7277fbd0e3..bfd73c1346 100644 --- a/litellm/llms/huggingface/embedding/handler.py +++ b/litellm/llms/huggingface/embedding/handler.py @@ -1,15 +1,6 @@ import json import os -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Union, - get_args, -) +from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args import httpx @@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug ] - -def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]: +def get_hf_task_embedding_for_model( + model: str, task_type: Optional[str], api_base: str +) -> Optional[str]: if task_type is not None: if task_type in get_args(hf_tasks_embeddings): return task_type @@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba return pipeline_tag -async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]: +async def async_get_hf_task_embedding_for_model( + model: str, task_type: Optional[str], api_base: str +) -> Optional[str]: if task_type is not None: if task_type in get_args(hf_tasks_embeddings): return task_type @@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM): input: List, optional_params: dict, ) -> dict: - hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + hf_task = await async_get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) data = 
self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task) @@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM): task_type = optional_params.pop("input_type", None) if call_type == "sync": - hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + hf_task = get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) elif call_type == "async": return self._async_transform_input( model=model, task_type=task_type, embed_url=embed_url, input=input @@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM): input: list, model_response: EmbeddingResponse, optional_params: dict, + litellm_params: dict, logging_obj: LiteLLMLoggingObj, encoding: Callable, api_key: Optional[str] = None, @@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM): model=model, optional_params=optional_params, messages=[], + litellm_params=litellm_params, ) task_type = optional_params.pop("input_type", None) - task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL) + task = get_hf_task_embedding_for_model( + model=model, task_type=task_type, api_base=HF_HUB_URL + ) # print_verbose(f"{model}, {task}") embed_url = "" if "https" in model: @@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM): elif "HUGGINGFACE_API_BASE" in os.environ: embed_url = os.getenv("HUGGINGFACE_API_BASE", "") else: - embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}" + embed_url = ( + f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}" + ) ## ROUTING ## if aembedding is True: diff --git a/litellm/llms/huggingface/embedding/transformation.py b/litellm/llms/huggingface/embedding/transformation.py index f803157768..60bd5dcd61 100644 --- a/litellm/llms/huggingface/embedding/transformation.py +++ b/litellm/llms/huggingface/embedding/transformation.py @@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/nlp_cloud/chat/handler.py b/litellm/llms/nlp_cloud/chat/handler.py index b0abdda587..b0563d8b55 100644 --- a/litellm/llms/nlp_cloud/chat/handler.py +++ b/litellm/llms/nlp_cloud/chat/handler.py @@ -36,6 +36,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) ## Load Config diff --git a/litellm/llms/nlp_cloud/chat/transformation.py b/litellm/llms/nlp_cloud/chat/transformation.py index b7967249ab..8037a45832 100644 --- a/litellm/llms/nlp_cloud/chat/transformation.py +++ b/litellm/llms/nlp_cloud/chat/transformation.py @@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index 64544bd269..789b728337 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/oobabooga/chat/oobabooga.py b/litellm/llms/oobabooga/chat/oobabooga.py index 
8829d2233e..5eb68a03d4 100644 --- a/litellm/llms/oobabooga/chat/oobabooga.py +++ b/litellm/llms/oobabooga/chat/oobabooga.py @@ -32,6 +32,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) if "https" in model: completion_url = model @@ -123,6 +124,7 @@ def embedding( model=model, messages=[], optional_params=optional_params, + litellm_params={}, ) response = litellm.module_level_client.post( embeddings_url, headers=headers, json=data diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py index 6fd56f934e..e87b70130c 100644 --- a/litellm/llms/oobabooga/chat/transformation.py +++ b/litellm/llms/oobabooga/chat/transformation.py @@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index fcab43901a..434214639e 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 3b6be1a034..13412ef96a 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py index 2d3d611dac..c0ccc71579 100644 --- a/litellm/llms/openai/transcriptions/whisper_transformation.py +++ b/litellm/llms/openai/transcriptions/whisper_transformation.py @@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index a9e37d27fc..24910cba8f 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py index cd80fa53e4..79936764ac 100644 --- a/litellm/llms/predibase/chat/handler.py +++ b/litellm/llms/predibase/chat/handler.py @@ -228,10 +228,10 @@ class PredibaseChatCompletion: api_key: str, logging_obj, optional_params: dict, + litellm_params: dict, tenant_id: str, timeout: Union[float, httpx.Timeout], acompletion=None, - litellm_params=None, logger_fn=None, headers: dict = {}, ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -241,6 +241,7 @@ class PredibaseChatCompletion: messages=messages, optional_params=optional_params, 
model=model, + litellm_params=litellm_params, ) completion_url = "" input_text = "" diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index 8ef0eea173..9fbb9d6c9e 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index d954416381..e4bb64fed7 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -141,6 +141,7 @@ def completion( model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) # Start a prediction and get the prediction URL version_id = replicate_config.model_to_version_id(model) diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py index 604e6eefe6..4c61086801 100644 --- a/litellm/llms/replicate/chat/transformation.py +++ b/litellm/llms/replicate/chat/transformation.py @@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py index 296689c31c..ebd96ac5b1 100644 --- a/litellm/llms/sagemaker/completion/handler.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM): model: str, data: dict, messages: List[AllMessageValues], + litellm_params: dict, optional_params: dict, aws_region_name: str, extra_headers: Optional[dict] = None, @@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) request = AWSRequest( method="POST", url=api_base, data=encoded_data, headers=headers @@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM): data=data, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, credentials=credentials, aws_region_name=aws_region_name, ) @@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": _data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, @@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, @@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM): "model": model, "data": data, "optional_params": optional_params, + "litellm_params": litellm_params, "credentials": credentials, "aws_region_name": aws_region_name, "messages": messages, diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index df3d028c99..bfc0b6e5f6 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = 
None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 574c4704cd..2b92911b05 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py index 4d14f1ad24..afbd89b9bc 100644 --- a/litellm/llms/topaz/image_variations/transformation.py +++ b/litellm/llms/topaz/image_variations/transformation.py @@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 49126917f2..21fcf2eefb 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -48,6 +48,7 @@ class TritonConfig(BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/llms/triton/embedding/transformation.py b/litellm/llms/triton/embedding/transformation.py index 4744ec0834..8ab0277e36 100644 --- a/litellm/llms/triton/embedding/transformation.py +++ b/litellm/llms/triton/embedding/transformation.py @@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py index 87c1cb8320..a666a2c37f 100644 --- a/litellm/llms/vertex_ai/files/handler.py +++ b/litellm/llms/vertex_ai/files/handler.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Coroutine, Optional, Union import httpx @@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES -from .transformation import VertexAIFilesTransformation +from .transformation import VertexAIJsonlFilesTransformation -vertex_ai_files_transformation = VertexAIFilesTransformation() +vertex_ai_files_transformation = VertexAIJsonlFilesTransformation() class VertexAIFilesHandler(GCSBucketBase): @@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase): timeout=timeout, max_retries=max_retries, ) - - return None # type: ignore + else: + return asyncio.run( + self.async_create_file( + create_file_data=create_file_data, + api_base=api_base, + vertex_credentials=vertex_credentials, + vertex_project=vertex_project, + vertex_location=vertex_location, + timeout=timeout, + max_retries=max_retries, + ) + ) diff --git a/litellm/llms/vertex_ai/files/transformation.py b/litellm/llms/vertex_ai/files/transformation.py index 89c6ff9deb..c795367e48 100644 --- a/litellm/llms/vertex_ai/files/transformation.py +++ b/litellm/llms/vertex_ai/files/transformation.py @@ -1,7 +1,17 @@ import json 
+import os +import time import uuid from typing import Any, Dict, List, Optional, Tuple, Union +from httpx import Headers, Response + +from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.files.transformation import ( + BaseFilesConfig, + LiteLLMLoggingObj, +) from litellm.llms.vertex_ai.common_utils import ( _convert_vertex_datetime_to_openai_datetime, ) @@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, ) from litellm.types.llms.openai import ( + AllMessageValues, CreateFileRequest, FileTypes, + OpenAICreateFileRequestOptionalParams, OpenAIFileObject, PathLike, ) +from litellm.types.llms.vertex_ai import GcsBucketResponse +from litellm.types.utils import ExtractedFileData, LlmProviders + +from ..common_utils import VertexAIError +from ..vertex_llm_base import VertexBase -class VertexAIFilesTransformation(VertexGeminiConfig): +class VertexAIFilesConfig(VertexBase, BaseFilesConfig): + """ + Config for VertexAI Files + """ + + def __init__(self): + self.jsonl_transformation = VertexAIJsonlFilesTransformation() + super().__init__() + + @property + def custom_llm_provider(self) -> LlmProviders: + return LlmProviders.VERTEX_AI + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if not api_key: + api_key, _ = self.get_access_token( + credentials=litellm_params.get("vertex_credentials"), + project_id=litellm_params.get("vertex_project"), + ) + if not api_key: + raise ValueError("api_key is required") + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str: + """ + Helper to extract content from various OpenAI file types and return as string. 
+ + Handles: + - Direct content (str, bytes, IO[bytes]) + - Tuple formats: (filename, content, [content_type], [headers]) + - PathLike objects + """ + content: Union[str, bytes] = b"" + # Extract file content from tuple if necessary + if isinstance(openai_file_content, tuple): + # Take the second element which is always the file content + file_content = openai_file_content[1] + else: + file_content = openai_file_content + + # Handle different file content types + if isinstance(file_content, str): + # String content can be used directly + content = file_content + elif isinstance(file_content, bytes): + # Bytes content can be decoded + content = file_content + elif isinstance(file_content, PathLike): # PathLike + with open(str(file_content), "rb") as f: + content = f.read() + elif hasattr(file_content, "read"): # IO[bytes] + # File-like objects need to be read + content = file_content.read() + + # Ensure content is string + if isinstance(content, bytes): + content = content.decode("utf-8") + + return content + + def _get_gcs_object_name_from_batch_jsonl( + self, + openai_jsonl_content: List[Dict[str, Any]], + ) -> str: + """ + Gets a unique GCS object name for the VertexAI batch prediction job + + named as: litellm-vertex-{model}-{uuid} + """ + _model = openai_jsonl_content[0].get("body", {}).get("model", "") + if "publishers/google/models" not in _model: + _model = f"publishers/google/models/{_model}" + object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}" + return object_name + + def get_object_name( + self, extracted_file_data: ExtractedFileData, purpose: str + ) -> str: + """ + Get the object name for the request + """ + extracted_file_data_content = extracted_file_data.get("content") + + if extracted_file_data_content is None: + raise ValueError("file content is required") + + if purpose == "batch": + ## 1. If jsonl, check if there's a model name + file_content = self._get_content_from_openai_file( + extracted_file_data_content + ) + + # Split into lines and parse each line as JSON + openai_jsonl_content = [ + json.loads(line) for line in file_content.splitlines() if line.strip() + ] + if len(openai_jsonl_content) > 0: + return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content) + + ## 2. If not jsonl, return the filename + filename = extracted_file_data.get("filename") + if filename: + return filename + ## 3. 
If no file name, return timestamp + return str(int(time.time())) + + def get_complete_file_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: Dict, + litellm_params: Dict, + data: CreateFileRequest, + ) -> str: + """ + Get the complete url for the request + """ + bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME") + if not bucket_name: + raise ValueError("GCS bucket_name is required") + file_data = data.get("file") + purpose = data.get("purpose") + if file_data is None: + raise ValueError("file is required") + if purpose is None: + raise ValueError("purpose is required") + extracted_file_data = extract_file_data(file_data) + object_name = self.get_object_name(extracted_file_data, purpose) + endpoint = ( + f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}" + ) + api_base = api_base or "https://storage.googleapis.com" + if not api_base: + raise ValueError("api_base is required") + + return f"{api_base}/{endpoint}" + + def get_supported_openai_params( + self, model: str + ) -> List[OpenAICreateFileRequestOptionalParams]: + return [] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + return optional_params + + def _map_openai_to_vertex_params( + self, + openai_request_body: Dict[str, Any], + ) -> Dict[str, Any]: + """ + wrapper to call VertexGeminiConfig.map_openai_params + """ + from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + VertexGeminiConfig, + ) + + config = VertexGeminiConfig() + _model = openai_request_body.get("model", "") + vertex_params = config.map_openai_params( + model=_model, + non_default_params=openai_request_body, + optional_params={}, + drop_params=False, + ) + return vertex_params + + def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content( + self, openai_jsonl_content: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Transforms OpenAI JSONL content to VertexAI JSONL content + + jsonl body for vertex is {"request": } + Example Vertex jsonl + {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}} + {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}} + """ + + vertex_jsonl_content = [] + for _openai_jsonl_content in openai_jsonl_content: + openai_request_body = _openai_jsonl_content.get("body") or {} + vertex_request_body = _transform_request_body( + messages=openai_request_body.get("messages", []), + model=openai_request_body.get("model", ""), + optional_params=self._map_openai_to_vertex_params(openai_request_body), + custom_llm_provider="vertex_ai", + litellm_params={}, + cached_content=None, + ) + vertex_jsonl_content.append({"request": vertex_request_body}) + return vertex_jsonl_content + + def transform_create_file_request( + self, + model: str, + create_file_data: CreateFileRequest, + optional_params: dict, + litellm_params: dict, + ) -> Union[bytes, str, dict]: + """ + 2 Cases: + 1. Handle basic file upload + 2. 
Handle batch file upload (.jsonl) + """ + file_data = create_file_data.get("file") + if file_data is None: + raise ValueError("file is required") + extracted_file_data = extract_file_data(file_data) + extracted_file_data_content = extracted_file_data.get("content") + if ( + create_file_data.get("purpose") == "batch" + and extracted_file_data.get("content_type") == "application/jsonl" + and extracted_file_data_content is not None + ): + ## 1. If jsonl, check if there's a model name + file_content = self._get_content_from_openai_file( + extracted_file_data_content + ) + + # Split into lines and parse each line as JSON + openai_jsonl_content = [ + json.loads(line) for line in file_content.splitlines() if line.strip() + ] + vertex_jsonl_content = ( + self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content( + openai_jsonl_content + ) + ) + return json.dumps(vertex_jsonl_content) + elif isinstance(extracted_file_data_content, bytes): + return extracted_file_data_content + else: + raise ValueError("Unsupported file content type") + + def transform_create_file_response( + self, + model: Optional[str], + raw_response: Response, + logging_obj: LiteLLMLoggingObj, + litellm_params: dict, + ) -> OpenAIFileObject: + """ + Transform VertexAI File upload response into OpenAI-style FileObject + """ + response_json = raw_response.json() + + try: + response_object = GcsBucketResponse(**response_json) # type: ignore + except Exception as e: + raise VertexAIError( + status_code=raw_response.status_code, + message=f"Error reading GCS bucket response: {e}", + headers=raw_response.headers, + ) + + gcs_id = response_object.get("id", "") + # Remove the last numeric ID from the path + gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else "" + + return OpenAIFileObject( + purpose=response_object.get("purpose", "batch"), + id=f"gs://{gcs_id}", + filename=response_object.get("name", ""), + created_at=_convert_vertex_datetime_to_openai_datetime( + vertex_datetime=response_object.get("timeCreated", "") + ), + status="uploaded", + bytes=int(response_object.get("size", 0)), + object="file", + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, Headers] + ) -> BaseLLMException: + return VertexAIError( + status_code=status_code, message=error_message, headers=headers + ) + + +class VertexAIJsonlFilesTransformation(VertexGeminiConfig): """ Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests """ diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 749b6d9428..b3b7857ea1 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): gtool_func_declarations = [] googleSearch: Optional[dict] = None googleSearchRetrieval: Optional[dict] = None + enterpriseWebSearch: Optional[dict] = None code_execution: Optional[dict] = None # remove 'additionalProperties' from tools value = _remove_additional_properties(value) @@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): googleSearch = tool["googleSearch"] elif tool.get("googleSearchRetrieval", None) is not None: googleSearchRetrieval = tool["googleSearchRetrieval"] + elif tool.get("enterpriseWebSearch", None) is not None: + enterpriseWebSearch = tool["enterpriseWebSearch"] elif tool.get("code_execution", None) is not None: 
code_execution = tool["code_execution"] elif openai_function_object is not None: @@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): _tools["googleSearch"] = googleSearch if googleSearchRetrieval is not None: _tools["googleSearchRetrieval"] = googleSearchRetrieval + if enterpriseWebSearch is not None: + _tools["enterpriseWebSearch"] = enterpriseWebSearch if code_execution is not None: _tools["code_execution"] = code_execution return [_tools] @@ -900,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: Dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: @@ -1017,7 +1023,7 @@ class VertexLLM(VertexBase): logging_obj, stream, optional_params: dict, - litellm_params=None, + litellm_params: dict, logger_fn=None, api_base: Optional[str] = None, client: Optional[AsyncHTTPHandler] = None, @@ -1058,6 +1064,7 @@ class VertexLLM(VertexBase): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) ## LOGGING @@ -1144,6 +1151,7 @@ class VertexLLM(VertexBase): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) request_body = await async_transform_request_body(**data) # type: ignore @@ -1317,6 +1325,7 @@ class VertexLLM(VertexBase): model=model, messages=messages, optional_params=optional_params, + litellm_params=litellm_params, ) ## TRANSFORMATION ## diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py index 88d7339449..8aebd83cc4 100644 --- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py @@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM): optional_params=optional_params, api_key=auth_header, api_base=api_base, + litellm_params=litellm_params, ) ## LOGGING diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py index afa58c7e5c..5bf02ad765 100644 --- a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py +++ b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py @@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py index 994e46b50b..8f3037c791 100644 --- a/litellm/llms/vertex_ai/vertex_llm_base.py +++ b/litellm/llms/vertex_ai/vertex_llm_base.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple from litellm._logging import verbose_logger from litellm.litellm_core_utils.asyncify import asyncify -from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES @@ -22,7 +21,7 @@ else: GoogleCredentialsObject = Any -class VertexBase(BaseLLM): +class VertexBase: def __init__(self) -> None: super().__init__() self.access_token: Optional[str] = None diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py index df6ef91a41..91811e0392 100644 --- 
a/litellm/llms/voyage/embedding/transformation.py +++ b/litellm/llms/voyage/embedding/transformation.py @@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig): model: str, messages: List[AllMessageValues], optional_params: dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: diff --git a/litellm/llms/watsonx/chat/handler.py b/litellm/llms/watsonx/chat/handler.py index aeb0167595..45378c5529 100644 --- a/litellm/llms/watsonx/chat/handler.py +++ b/litellm/llms/watsonx/chat/handler.py @@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler): messages=messages, optional_params=optional_params, api_key=api_key, + litellm_params=litellm_params, ) ## UPDATE PAYLOAD (optional params) diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index 4916cd1c75..e13e015add 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -165,6 +165,7 @@ class IBMWatsonXMixin: model: str, messages: List[AllMessageValues], optional_params: Dict, + litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Dict: diff --git a/litellm/main.py b/litellm/main.py index cd7d255e21..3f1d9a1e76 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915 optional_params=optional_params, client=client, aembedding=aembedding, + litellm_params=litellm_params_dict, ) elif custom_llm_provider == "bedrock": if isinstance(input, str): diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 79aa57f466..7e5be4dc6b 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -2409,25 +2409,26 @@ "max_tokens": 4096, "max_input_tokens": 131072, "max_output_tokens": 4096, - "input_cost_per_token": 0, - "output_cost_per_token": 0, + "input_cost_per_token": 0.000000075, + "output_cost_per_token": 0.0000003, "litellm_provider": "azure_ai", "mode": "chat", "supports_function_calling": true, - "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" + "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112" }, "azure_ai/Phi-4-multimodal-instruct": { "max_tokens": 4096, "max_input_tokens": 131072, "max_output_tokens": 4096, - "input_cost_per_token": 0, - "output_cost_per_token": 0, + "input_cost_per_token": 0.00000008, + "input_cost_per_audio_token": 0.000004, + "output_cost_per_token": 0.00032, "litellm_provider": "azure_ai", "mode": "chat", "supports_audio_input": true, "supports_function_calling": true, "supports_vision": true, - "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" + "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112" }, "azure_ai/Phi-4": { "max_tokens": 16384, @@ -3467,7 +3468,7 @@ "input_cost_per_token": 0.0000008, "output_cost_per_token": 0.000004, "cache_creation_input_token_cost": 0.000001, - "cache_read_input_token_cost": 0.0000008, + "cache_read_input_token_cost": 0.00000008, "litellm_provider": "anthropic", "mode": "chat", "supports_function_calling": true, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ae4bdc7b8c..b64ff6c827 100644 --- 
a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1625,6 +1625,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase): model_max_budget: Optional[Dict] = {} model_spend: Optional[Dict] = {} user_email: Optional[str] = None + user_alias: Optional[str] = None models: list = [] tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index a6b1b3e614..563d0cb543 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -4,16 +4,26 @@ import json import uuid from base64 import b64encode from datetime import datetime -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from urllib.parse import parse_qs, urlencode, urlparse import httpx -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + Response, + UploadFile, + status, +) from fastapi.responses import StreamingResponse +from starlette.datastructures import UploadFile as StarletteUploadFile import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._types import ( ConfigFieldInfo, @@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers: ) return response + @staticmethod + async def non_streaming_http_request_handler( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + _parsed_body: Optional[dict] = None, + ) -> httpx.Response: + """ + Handle non-streaming HTTP requests + + Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests + """ + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + elif HttpPassThroughEndpointHelpers.is_multipart(request) is True: + return await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=url, + headers=headers, + requested_query_params=requested_query_params, + ) + else: + # Generic httpx method + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=_parsed_body, + ) + return response + + @staticmethod + def is_multipart(request: Request) -> bool: + """Check if the request is a multipart/form-data request""" + return "multipart/form-data" in request.headers.get("content-type", "") + + @staticmethod + async def _build_request_files_from_upload_file( + upload_file: Union[UploadFile, StarletteUploadFile], + ) -> Tuple[Optional[str], bytes, Optional[str]]: + """Build a request files dict from an UploadFile object""" + file_content = await upload_file.read() + return (upload_file.filename, file_content, upload_file.content_type) + + @staticmethod + async def make_multipart_http_request( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + ) -> httpx.Response: + """Process multipart/form-data requests, handling both files and form 
fields""" + form_data = await request.form() + files = {} + form_data_dict = {} + + for field_name, field_value in form_data.items(): + if isinstance(field_value, (StarletteUploadFile, UploadFile)): + files[field_name] = ( + await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value + ) + ) + else: + form_data_dict[field_name] = field_value + + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + files=files, + data=form_data_dict, + ) + return response + async def pass_through_request( # noqa: PLR0915 request: Request, @@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915 start_time = datetime.now() logging_obj = Logging( model="unknown", - messages=[{"role": "user", "content": json.dumps(_parsed_body)}], + messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], stream=False, call_type="pass_through_endpoint", start_time=start_time, @@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915 logging_obj.model_call_details["litellm_call_id"] = litellm_call_id # combine url with query params for logging - requested_query_params: Optional[dict] = ( query_params or request.query_params.__dict__ ) @@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915 logging_url = str(url) + "?" + requested_query_params_str logging_obj.pre_call( - input=[{"role": "user", "content": json.dumps(_parsed_body)}], + input=[{"role": "user", "content": safe_dumps(_parsed_body)}], api_key="", additional_args={ "complete_input_dict": _parsed_body, @@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915 ) verbose_proxy_logger.debug("request body: {}".format(_parsed_body)) - if request.method == "GET": - response = await async_client.request( - method=request.method, + response = ( + await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler( + request=request, + async_client=async_client, url=url, headers=headers, - params=requested_query_params, + requested_query_params=requested_query_params, + _parsed_body=_parsed_body, ) - else: - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - json=_parsed_body, - ) - + ) verbose_proxy_logger.debug("response.headers= %s", response.headers) if _is_streaming_response(response) is True: diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 7fa167938f..55273371fc 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -187,6 +187,7 @@ class Tools(TypedDict, total=False): function_declarations: List[FunctionDeclaration] googleSearch: dict googleSearchRetrieval: dict + enterpriseWebSearch: dict code_execution: dict retrieval: Retrieval @@ -497,6 +498,51 @@ class OutputConfig(TypedDict, total=False): gcsDestination: GcsDestination +class GcsBucketResponse(TypedDict): + """ + TypedDict for GCS bucket upload response + + Attributes: + kind: The kind of item this is. 
For objects, this is always storage#object + id: The ID of the object + selfLink: The link to this object + mediaLink: The link to download the object + name: The name of the object + bucket: The name of the bucket containing this object + generation: The content generation of this object + metageneration: The metadata generation of this object + contentType: The content type of the object + storageClass: The storage class of the object + size: The size of the object in bytes + md5Hash: The MD5 hash of the object + crc32c: The CRC32c checksum of the object + etag: The ETag of the object + timeCreated: The creation time of the object + updated: The last update time of the object + timeStorageClassUpdated: The time the storage class was last updated + timeFinalized: The time the object was finalized + """ + + kind: Literal["storage#object"] + id: str + selfLink: str + mediaLink: str + name: str + bucket: str + generation: str + metageneration: str + contentType: str + storageClass: str + size: str + md5Hash: str + crc32c: str + etag: str + timeCreated: str + updated: str + timeStorageClassUpdated: str + timeFinalized: str + + class VertexAIBatchPredictionJob(TypedDict): displayName: str model: str diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 8439037758..6f0c26d301 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -2,7 +2,7 @@ import json import time import uuid from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union from aiohttp import FormData from openai._models import BaseModel as OpenAIObject @@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase): if not values.get("credential_values") and not values.get("model_id"): raise ValueError("Either credential_values or model_id must be set") return values + + +class ExtractedFileData(TypedDict): + """ + TypedDict for storing processed file data + + Attributes: + filename: Name of the file if provided + content: The file content in bytes + content_type: MIME type of the file + headers: Any additional headers for the file + """ + + filename: Optional[str] + content: bytes + content_type: Optional[str] + headers: Mapping[str, str] diff --git a/litellm/utils.py b/litellm/utils.py index f807990f60..f809d8a77b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6517,6 +6517,10 @@ class ProviderConfigManager: ) return GoogleAIStudioFilesHandler() + elif LlmProviders.VERTEX_AI == provider: + from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig + + return VertexAIFilesConfig() return None diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 79aa57f466..7e5be4dc6b 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -2409,25 +2409,26 @@ "max_tokens": 4096, "max_input_tokens": 131072, "max_output_tokens": 4096, - "input_cost_per_token": 0, - "output_cost_per_token": 0, + "input_cost_per_token": 0.000000075, + "output_cost_per_token": 0.0000003, "litellm_provider": "azure_ai", "mode": "chat", "supports_function_calling": true, - "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" + "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112" }, "azure_ai/Phi-4-multimodal-instruct": { "max_tokens": 4096, "max_input_tokens": 131072, 
"max_output_tokens": 4096, - "input_cost_per_token": 0, - "output_cost_per_token": 0, + "input_cost_per_token": 0.00000008, + "input_cost_per_audio_token": 0.000004, + "output_cost_per_token": 0.00032, "litellm_provider": "azure_ai", "mode": "chat", "supports_audio_input": true, "supports_function_calling": true, "supports_vision": true, - "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" + "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112" }, "azure_ai/Phi-4": { "max_tokens": 16384, @@ -3467,7 +3468,7 @@ "input_cost_per_token": 0.0000008, "output_cost_per_token": 0.000004, "cache_creation_input_token_cost": 0.000001, - "cache_read_input_token_cost": 0.0000008, + "cache_read_input_token_cost": 0.00000008, "litellm_provider": "anthropic", "mode": "chat", "supports_function_calling": true, diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py index 4669a2def6..b2826419e8 100644 --- a/tests/batches_tests/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -423,25 +423,35 @@ mock_vertex_batch_response = { @pytest.mark.asyncio -async def test_avertex_batch_prediction(): - with patch( +async def test_avertex_batch_prediction(monkeypatch): + monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local") + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler + + client = AsyncHTTPHandler() + + async def mock_side_effect(*args, **kwargs): + print("args", args, "kwargs", kwargs) + url = kwargs.get("url", "") + if "files" in url: + mock_response.json.return_value = mock_file_response + elif "batch" in url: + mock_response.json.return_value = mock_vertex_batch_response + mock_response.status_code = 200 + return mock_response + + with patch.object( + client, "post", side_effect=mock_side_effect + ) as mock_post, patch( "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post" - ) as mock_post: + ) as mock_global_post: # Configure mock responses mock_response = MagicMock() mock_response.raise_for_status.return_value = None # Set up different responses for different API calls - async def mock_side_effect(*args, **kwargs): - url = kwargs.get("url", "") - if "files" in url: - mock_response.json.return_value = mock_file_response - elif "batch" in url: - mock_response.json.return_value = mock_vertex_batch_response - mock_response.status_code = 200 - return mock_response - + mock_post.side_effect = mock_side_effect + mock_global_post.side_effect = mock_side_effect # load_vertex_ai_credentials() litellm.set_verbose = True @@ -455,6 +465,7 @@ async def test_avertex_batch_prediction(): file=open(file_path, "rb"), purpose="batch", custom_llm_provider="vertex_ai", + client=client ) print("Response from creating file=", file_obj) diff --git a/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py new file mode 100644 index 0000000000..43d4dd9cd8 --- /dev/null +++ b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py @@ -0,0 +1,116 @@ +import json +import os +import sys +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import Request, UploadFile +from fastapi.testclient import TestClient +from starlette.datastructures import Headers +from 
starlette.datastructures import UploadFile as StarletteUploadFile + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + HttpPassThroughEndpointHelpers, +) + + +# Test is_multipart +def test_is_multipart(): + # Test with multipart content type + request = MagicMock(spec=Request) + request.headers = Headers({"content-type": "multipart/form-data; boundary=123"}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is True + + # Test with non-multipart content type + request.headers = Headers({"content-type": "application/json"}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is False + + # Test with no content type + request.headers = Headers({}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is False + + +# Test _build_request_files_from_upload_file +@pytest.mark.asyncio +async def test_build_request_files_from_upload_file(): + # Test with FastAPI UploadFile + file_content = b"test content" + file = BytesIO(file_content) + # Create SpooledTemporaryFile with content type headers + headers = {"content-type": "text/plain"} + upload_file = UploadFile(file=file, filename="test.txt", headers=headers) + upload_file.read = AsyncMock(return_value=file_content) + + result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file + ) + assert result == ("test.txt", file_content, "text/plain") + + # Test with Starlette UploadFile + file2 = BytesIO(file_content) + starlette_file = StarletteUploadFile( + file=file2, + filename="test2.txt", + headers=Headers({"content-type": "text/plain"}), + ) + starlette_file.read = AsyncMock(return_value=file_content) + + result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + starlette_file + ) + assert result == ("test2.txt", file_content, "text/plain") + + +# Test make_multipart_http_request +@pytest.mark.asyncio +async def test_make_multipart_http_request(): + # Mock request with file and form field + request = MagicMock(spec=Request) + request.method = "POST" + + # Mock form data + file_content = b"test file content" + file = BytesIO(file_content) + # Create SpooledTemporaryFile with content type headers + headers = {"content-type": "text/plain"} + upload_file = UploadFile(file=file, filename="test.txt", headers=headers) + upload_file.read = AsyncMock(return_value=file_content) + + form_data = {"file": upload_file, "text_field": "test value"} + request.form = AsyncMock(return_value=form_data) + + # Mock httpx client + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} + + async_client = MagicMock() + async_client.request = AsyncMock(return_value=mock_response) + + # Test the function + response = await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=httpx.URL("http://test.com"), + headers={}, + requested_query_params=None, + ) + + # Verify the response + assert response == mock_response + + # Verify the client call + async_client.request.assert_called_once() + call_args = async_client.request.call_args[1] + + assert call_args["method"] == "POST" + assert str(call_args["url"]) == "http://test.com" + assert isinstance(call_args["files"], dict) + assert isinstance(call_args["data"], dict) + assert call_args["data"]["text_field"] == "test value" diff --git a/tests/llm_translation/test_huggingface_chat_completion.py 
b/tests/llm_translation/test_huggingface_chat_completion.py index 9f1e89aeb1..7d498b96df 100644 --- a/tests/llm_translation/test_huggingface_chat_completion.py +++ b/tests/llm_translation/test_huggingface_chat_completion.py @@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest): model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct", messages=[{"role": "user", "content": "Hello"}], optional_params={}, - api_key="test_api_key" + api_key="test_api_key", + litellm_params={} ) assert headers["Authorization"] == "Bearer test_api_key" diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index d821fb415e..9118d94a6f 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -141,6 +141,7 @@ def test_build_vertex_schema(): [ ([{"googleSearch": {}}], "googleSearch"), ([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"), + ([{"enterpriseWebSearch": {}}], "enterpriseWebSearch"), ([{"code_execution": {}}], "code_execution"), ], ) diff --git a/tests/local_testing/example.jsonl b/tests/local_testing/example.jsonl new file mode 100644 index 0000000000..fc3ca40808 --- /dev/null +++ b/tests/local_testing/example.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}} diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py index 1a8deed8a8..b64475c227 100644 --- a/tests/local_testing/test_gcs_bucket.py +++ b/tests/local_testing/test_gcs_bucket.py @@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import ( StandardLoggingPayload, ) from litellm.types.utils import StandardCallbackDynamicParams - +from unittest.mock import patch verbose_logger.setLevel(logging.DEBUG) @@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name(): # clean up if old_bucket_name is not None: os.environ["GCS_BUCKET_NAME"] = old_bucket_name + +@pytest.mark.skip(reason="This test is flaky on ci/cd") +def test_create_file_e2e(): + """ + Asserts 'create_file' is called with the correct arguments + """ + load_vertex_ai_credentials() + test_file_content = b"test audio content" + test_file = ("test.wav", test_file_content, "audio/wav") + + from litellm import create_file + response = create_file( + file=test_file, + purpose="user_data", + custom_llm_provider="vertex_ai", + ) + print("response", response) + assert response is not None + +@pytest.mark.skip(reason="This test is flaky on ci/cd") +def test_create_file_e2e_jsonl(): + """ + Asserts 'create_file' is called with the correct arguments + """ + load_vertex_ai_credentials() + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + + example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": 
"gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}] + + # Create and write to the file + file_path = "example.jsonl" + with open(file_path, "w") as f: + for item in example_jsonl: + f.write(json.dumps(item) + "\n") + + # Verify file content + with open(file_path, "r") as f: + content = f.read() + print("File content:", content) + assert len(content) > 0, "File is empty" + + from litellm import create_file + with patch.object(client, "post") as mock_create_file: + try: + response = create_file( + file=open(file_path, "rb"), + purpose="user_data", + custom_llm_provider="vertex_ai", + client=client, + ) + except Exception as e: + print("error", e) + + mock_create_file.assert_called_once() + + print(f"kwargs: {mock_create_file.call_args.kwargs}") + + assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0 \ No newline at end of file diff --git a/tests/pass_through_tests/test_openai_assistants_passthrough.py b/tests/pass_through_tests/test_openai_assistants_passthrough.py index 694d3c090e..40361ab39f 100644 --- a/tests/pass_through_tests/test_openai_assistants_passthrough.py +++ b/tests/pass_through_tests/test_openai_assistants_passthrough.py @@ -2,14 +2,31 @@ import pytest import openai import aiohttp import asyncio +import tempfile from typing_extensions import override from openai import AssistantEventHandler + client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234") +def test_pass_through_file_operations(): + # Create a temporary file + with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file: + temp_file.write("This is a test file for the OpenAI Assistants API.") + temp_file.flush() + + # create a file + file = client.files.create( + file=open(temp_file.name, "rb"), + purpose="assistants", + ) + print("file created", file) + + # delete the file + delete_file = client.files.delete(file.id) + print("file deleted", delete_file) def test_openai_assistants_e2e_operations(): - assistant = client.beta.assistants.create( name="Math Tutor", instructions="You are a personal math tutor. 
Write and run code to answer math questions.", diff --git a/ui/litellm-dashboard/src/components/all_keys_table.tsx b/ui/litellm-dashboard/src/components/all_keys_table.tsx index b0313c241f..3a2bf61c5f 100644 --- a/ui/litellm-dashboard/src/components/all_keys_table.tsx +++ b/ui/litellm-dashboard/src/components/all_keys_table.tsx @@ -13,9 +13,12 @@ import { Organization, userListCall } from "./networking"; import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn"; import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn"; import { useFilterLogic } from "./key_team_helpers/filter_logic"; +import { Setter } from "@/types"; +import { updateExistingKeys } from "@/utils/dataUtils"; interface AllKeysTableProps { keys: KeyResponse[]; + setKeys: Setter; isLoading?: boolean; pagination: { currentPage: number; @@ -87,6 +90,7 @@ const TeamFilter = ({ */ export function AllKeysTable({ keys, + setKeys, isLoading = false, pagination, onPageChange, @@ -364,6 +368,23 @@ export function AllKeysTable({ keyId={selectedKeyId} onClose={() => setSelectedKeyId(null)} keyData={keys.find(k => k.token === selectedKeyId)} + onKeyDataUpdate={(updatedKeyData) => { + setKeys(keys => keys.map(key => { + if (key.token === updatedKeyData.token) { + // The shape of key is different from that of + // updatedKeyData(received from keyUpdateCall in networking.tsx). + // Hence, we can't replace key with updatedKeys since it might lead + // to unintended bugs/behaviors. + // So instead, we only update fields that are present in both. + return updateExistingKeys(key, updatedKeyData) + } + + return key + })) + }} + onDelete={() => { + setKeys(keys => keys.filter(key => key.token !== selectedKeyId)) + }} accessToken={accessToken} userID={userID} userRole={userRole} diff --git a/ui/litellm-dashboard/src/components/key_info_view.tsx b/ui/litellm-dashboard/src/components/key_info_view.tsx index 9d50be6cf7..b7ebdc651a 100644 --- a/ui/litellm-dashboard/src/components/key_info_view.tsx +++ b/ui/litellm-dashboard/src/components/key_info_view.tsx @@ -27,13 +27,15 @@ interface KeyInfoViewProps { keyId: string; onClose: () => void; keyData: KeyResponse | undefined; + onKeyDataUpdate?: (data: Partial) => void; + onDelete?: () => void; accessToken: string | null; userID: string | null; userRole: string | null; teams: any[] | null; } -export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams }: KeyInfoViewProps) { +export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams, onKeyDataUpdate, onDelete }: KeyInfoViewProps) { const [isEditing, setIsEditing] = useState(false); const [form] = Form.useForm(); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); @@ -93,6 +95,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user } const newKeyValues = await keyUpdateCall(accessToken, formValues); + if (onKeyDataUpdate) { + onKeyDataUpdate(newKeyValues) + } message.success("Key updated successfully"); setIsEditing(false); // Refresh key data here if needed @@ -107,6 +112,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user if (!accessToken) return; await keyDeleteCall(accessToken as string, keyData.token); message.success("Key deleted successfully"); + if (onDelete) { + onDelete() + } onClose(); } catch (error) { console.error("Error deleting the key:", error); diff --git a/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx 
b/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx index 4c2a18d2b5..4ca0ea5720 100644 --- a/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx +++ b/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx @@ -1,5 +1,6 @@ import { useState, useEffect } from 'react'; import { keyListCall, Organization } from '../networking'; +import { Setter } from '@/types'; export interface Team { team_id: string; @@ -94,13 +95,14 @@ totalPages: number; totalCount: number; } + interface UseKeyListReturn { keys: KeyResponse[]; isLoading: boolean; error: Error | null; pagination: PaginationData; refresh: (params?: Record) => Promise; -setKeys: (newKeysOrUpdater: KeyResponse[] | ((prevKeys: KeyResponse[]) => KeyResponse[])) => void; +setKeys: Setter; } const useKeyList = ({ diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index ac79237fb8..025f0c72c4 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -4,6 +4,7 @@ import { all_admin_roles } from "@/utils/roles"; import { message } from "antd"; import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types"; +import { Team } from "./key_team_helpers/key_list"; const isLocal = process.env.NODE_ENV === "development"; export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null; @@ -2983,7 +2984,7 @@ export const teamUpdateCall = async ( console.error("Error response from the server:", errorData); throw new Error("Network response was not ok"); } - const data = await response.json(); + const data = await response.json() as { data: Team, team_id: string }; console.log("Update Team Response:", data); return data; // Handle success - you might want to update some state or UI based on the created key diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx index e04680b53a..20e9d23ccf 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.tsx @@ -30,6 +30,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline"; import MemberModal from "./edit_membership"; import UserSearchModal from "@/components/common_components/user_search_modal"; import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; +import { Team } from "../key_team_helpers/key_list"; interface TeamData { @@ -69,6 +70,7 @@ interface TeamInfoProps { is_proxy_admin: boolean; userModels: string[]; editTeam: boolean; + onUpdate?: (team: Team) => void } const TeamInfoView: React.FC = ({ @@ -78,7 +80,8 @@ const TeamInfoView: React.FC = ({ is_team_admin, is_proxy_admin, userModels, - editTeam + editTeam, + onUpdate }) => { const [teamData, setTeamData] = useState(null); const [loading, setLoading] = useState(true); @@ -199,7 +202,10 @@ const TeamInfoView: React.FC = ({ }; const response = await teamUpdateCall(accessToken, updateData); - + if (onUpdate) { + onUpdate(response.data) + } + message.success("Team settings updated successfully"); setIsEditing(false); fetchTeamInfo(); diff --git a/ui/litellm-dashboard/src/components/teams.tsx b/ui/litellm-dashboard/src/components/teams.tsx index 6f516f06e2..7e3b607267 100644 --- a/ui/litellm-dashboard/src/components/teams.tsx +++ b/ui/litellm-dashboard/src/components/teams.tsx @@ -84,6 +84,7 @@ import { 
modelAvailableCall, teamListCall } from "./networking"; +import { updateExistingKeys } from "@/utils/dataUtils"; const getOrganizationModels = (organization: Organization | null, userModels: string[]) => { let tempModelsToPick = []; @@ -321,6 +322,22 @@ const Teams: React.FC = ({ {selectedTeamId ? ( { + setTeams(teams => { + if (teams == null) { + return teams; + } + + return teams.map(team => { + if (data.team_id === team.team_id) { + return updateExistingKeys(team, data) + } + + return team + }) + }) + + }} onClose={() => { setSelectedTeamId(null); setEditTeam(false); diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index f3661c8c64..57467efa18 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -418,6 +418,7 @@ const ViewKeyTable: React.FC = ({
= (newValueOrUpdater: T | ((previousValue: T) => T)) => void
\ No newline at end of file
diff --git a/ui/litellm-dashboard/src/utils/dataUtils.ts b/ui/litellm-dashboard/src/utils/dataUtils.ts
new file mode 100644
index 0000000000..f51940f2ef
--- /dev/null
+++ b/ui/litellm-dashboard/src/utils/dataUtils.ts
@@ -0,0 +1,14 @@
+export function updateExistingKeys<Source extends Object>(
+  target: Source,
+  source: Object
+): Source {
+  const clonedTarget = structuredClone(target);
+
+  for (const [key, value] of Object.entries(source)) {
+    if (key in clonedTarget) {
+      (clonedTarget as any)[key] = value;
+    }
+  }
+
+  return clonedTarget;
+}
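
The `updateExistingKeys` helper above is paired with the functional-updater `Setter` alias from `@/types` in `all_keys_table.tsx` and `teams.tsx`: when a key or team update response comes back, only the fields that already exist on the cached row are copied over, because the response shape differs from the table's row shape. Below is a minimal TypeScript sketch of that pattern; the `KeyRow` interface is a hypothetical stand-in for the real `KeyResponse` type, and the `Setter<T>` definition is assumed from the (partially truncated) `types.ts` hunk above.

```typescript
import { updateExistingKeys } from "@/utils/dataUtils";

// Assumed shape of the Setter<T> alias from ui/litellm-dashboard/src/types.ts.
type Setter<T> = (newValueOrUpdater: T | ((previousValue: T) => T)) => void;

// Hypothetical stand-in for the real KeyResponse row type.
interface KeyRow {
  token: string;
  key_alias?: string;
  spend?: number;
}

// Merge an update payload into the cached rows without replacing whole objects,
// mirroring the onKeyDataUpdate handler added in all_keys_table.tsx.
function applyKeyUpdate(
  setKeys: Setter<KeyRow[]>,
  updatedKeyData: Partial<KeyRow> & { token: string }
) {
  setKeys(keys =>
    keys.map(key =>
      key.token === updatedKeyData.token
        ? updateExistingKeys(key, updatedKeyData) // copies only fields present on `key`
        : key
    )
  );
}
```

Copying only pre-existing fields keeps columns that the update endpoint does not return (for example spend counters) intact in the UI state.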