diff --git a/.circleci/config.yml b/.circleci/config.yml
index b1e08084fa..14a22a5995 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -610,6 +610,8 @@ jobs:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
+ pip install wheel
+ pip install --upgrade pip wheel setuptools
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "respx==0.21.1"
diff --git a/deploy/charts/litellm-helm/templates/service.yaml b/deploy/charts/litellm-helm/templates/service.yaml
index 40e7f27f16..d8d81e78c8 100644
--- a/deploy/charts/litellm-helm/templates/service.yaml
+++ b/deploy/charts/litellm-helm/templates/service.yaml
@@ -2,6 +2,10 @@ apiVersion: v1
kind: Service
metadata:
name: {{ include "litellm.fullname" . }}
+ {{- with .Values.service.annotations }}
+ annotations:
+ {{- toYaml . | nindent 4 }}
+ {{- end }}
labels:
{{- include "litellm.labels" . | nindent 4 }}
spec:
diff --git a/docs/my-website/docs/providers/gemini.md b/docs/my-website/docs/providers/gemini.md
index 42783286f1..db63d33d8d 100644
--- a/docs/my-website/docs/providers/gemini.md
+++ b/docs/my-website/docs/providers/gemini.md
@@ -438,6 +438,179 @@ assert isinstance(
```
+### Google Search Tool
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
+
+response = completion(
+ model="gemini/gemini-2.0-flash",
+ messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+ tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Set up config.yaml
+```yaml
+model_list:
+ - model_name: gemini-2.0-flash
+ litellm_params:
+ model: gemini/gemini-2.0-flash
+ api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "gemini-2.0-flash",
+ "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+ "tools": [{"googleSearch": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+### Google Search Retrieval
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH
+
+response = completion(
+ model="gemini/gemini-2.0-flash",
+ messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+ tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Set up config.yaml
+```yaml
+model_list:
+ - model_name: gemini-2.0-flash
+ litellm_params:
+ model: gemini/gemini-2.0-flash
+ api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "gemini-2.0-flash",
+ "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+ "tools": [{"googleSearchRetrieval": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+
+### Code Execution Tool
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"codeExecution": {}}] # 👈 ADD GOOGLE SEARCH
+
+response = completion(
+ model="gemini/gemini-2.0-flash",
+ messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+ tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Set up config.yaml
+```yaml
+model_list:
+ - model_name: gemini-2.0-flash
+ litellm_params:
+ model: gemini/gemini-2.0-flash
+ api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "gemini-2.0-flash",
+ "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+ "tools": [{"codeExecution": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
## JSON Mode
diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index ab13a51137..cdd3fce6c6 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
+You can also use the `enterpriseWebSearch` tool for an [enterprise-compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
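+
+A minimal SDK sketch (the project, location, and model below are placeholders; swap in your own values. The call shape mirrors the grounding examples above):
+
+```python
+from litellm import completion
+import litellm
+
+## SETUP: use your own Vertex AI project / location
+litellm.vertex_project = "hardy-device-38811"
+litellm.vertex_location = "us-central1"
+
+tools = [{"enterpriseWebSearch": {}}] # 👈 ADD ENTERPRISE WEB SEARCH
+
+response = completion(
+    model="vertex_ai/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "Who won the last World Cup?"}],
+    tools=tools,
+)
+print(response)
+```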
+
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py
index 4f2f43f0de..f40f1ae4c7 100644
--- a/litellm/litellm_core_utils/get_litellm_params.py
+++ b/litellm/litellm_core_utils/get_litellm_params.py
@@ -110,5 +110,8 @@ def get_litellm_params(
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
+ "bucket_name": kwargs.get("bucket_name"),
+ "vertex_credentials": kwargs.get("vertex_credentials"),
+ "vertex_project": kwargs.get("vertex_project"),
}
return litellm_params
diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py
index 0f2d0da388..44b680d487 100644
--- a/litellm/litellm_core_utils/prompt_templates/common_utils.py
+++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py
@@ -2,7 +2,10 @@
Common utility functions used for translating messages across providers
"""
-from typing import Dict, List, Literal, Optional, Union, cast
+import io
+import mimetypes
+from os import PathLike
+from typing import Dict, List, Literal, Mapping, Optional, Union, cast
from litellm.types.llms.openai import (
AllMessageValues,
@@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
ChatCompletionFileObject,
ChatCompletionUserMessage,
)
-from litellm.types.utils import Choices, ModelResponse, StreamingChoices
+from litellm.types.utils import (
+ Choices,
+ ExtractedFileData,
+ FileTypes,
+ ModelResponse,
+ StreamingChoices,
+)
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
content="Please continue.", role="user"
@@ -350,6 +359,68 @@ def update_messages_with_model_file_ids(
return messages
+def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
+ """
+ Extracts and processes file data from various input formats.
+
+ Args:
+ file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
+
+ Returns:
+ ExtractedFileData containing:
+ - filename: Name of the file if provided
+ - content: The file content in bytes
+ - content_type: MIME type of the file
+ - headers: Any additional headers
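+
+    Illustrative example (ExtractedFileData is dict-like, so the result reads as a
+    plain dict; the content type here is guessed from the filename):
+
+        extract_file_data(("notes.txt", b"hello"))
+        # -> {"filename": "notes.txt", "content": b"hello",
+        #     "content_type": "text/plain", "headers": {}}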
+ """
+ # Parse the file_data based on its type
+ filename = None
+ file_content = None
+ content_type = None
+ file_headers: Mapping[str, str] = {}
+
+ if isinstance(file_data, tuple):
+ if len(file_data) == 2:
+ filename, file_content = file_data
+ elif len(file_data) == 3:
+ filename, file_content, content_type = file_data
+ elif len(file_data) == 4:
+ filename, file_content, content_type, file_headers = file_data
+ else:
+ file_content = file_data
+ # Convert content to bytes
+ if isinstance(file_content, (str, PathLike)):
+ # If it's a path, open and read the file
+ with open(file_content, "rb") as f:
+ content = f.read()
+ elif isinstance(file_content, io.IOBase):
+ # If it's a file-like object
+ content = file_content.read()
+
+ if isinstance(content, str):
+ content = content.encode("utf-8")
+ # Reset file pointer to beginning
+ file_content.seek(0)
+ elif isinstance(file_content, bytes):
+ content = file_content
+ else:
+ raise ValueError(f"Unsupported file content type: {type(file_content)}")
+
+ # Use provided content type or guess based on filename
+ if not content_type:
+ content_type = (
+ mimetypes.guess_type(filename)[0]
+ if filename
+ else "application/octet-stream"
+ )
+
+ return ExtractedFileData(
+ filename=filename,
+ content=content,
+ content_type=content_type,
+ headers=file_headers,
+ )
+
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
@@ -381,3 +452,4 @@ def unpack_defs(schema, defs):
unpack_defs(ref, defs)
value["items"] = ref
continue
+
diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py
index af073fe8e3..c2d4e5adcd 100644
--- a/litellm/llms/aiohttp_openai/chat/transformation.py
+++ b/litellm/llms/aiohttp_openai/chat/transformation.py
@@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index f2a5542dcd..44567facf9 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
+ litellm_params=litellm_params,
)
config = ProviderConfigManager.get_provider_chat_config(
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index 8a2048f95a..9b66249630 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py
index e4e04df4d6..9e3287aa8a 100644
--- a/litellm/llms/anthropic/completion/transformation.py
+++ b/litellm/llms/anthropic/completion/transformation.py
@@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py
index e30d68f97d..ea61ef2c9a 100644
--- a/litellm/llms/azure/chat/gpt_transformation.py
+++ b/litellm/llms/azure/chat/gpt_transformation.py
@@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py
index 007a4303c8..839f875f75 100644
--- a/litellm/llms/azure_ai/chat/transformation.py
+++ b/litellm/llms/azure_ai/chat/transformation.py
@@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py
index 5279a44201..fa278c805e 100644
--- a/litellm/llms/base_llm/chat/transformation.py
+++ b/litellm/llms/base_llm/chat/transformation.py
@@ -262,6 +262,7 @@ class BaseConfig(ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/base_llm/files/transformation.py b/litellm/llms/base_llm/files/transformation.py
index 0f1f46352f..9925004c89 100644
--- a/litellm/llms/base_llm/files/transformation.py
+++ b/litellm/llms/base_llm/files/transformation.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
) -> List[OpenAICreateFileRequestOptionalParams]:
pass
- def get_complete_url(
+ def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
- stream: Optional[bool] = None,
- ) -> str:
- """
- OPTIONAL
-
- Get the complete url for the request
-
- Some providers need `model` in `api_base`
- """
- return api_base or ""
+ data: CreateFileRequest,
+ ):
+ return self.get_complete_url(
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
@abstractmethod
def transform_create_file_request(
@@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
- ) -> dict:
+ ) -> Union[dict, str, bytes]:
pass
@abstractmethod
diff --git a/litellm/llms/base_llm/image_variations/transformation.py b/litellm/llms/base_llm/image_variations/transformation.py
index 3ed446a84e..60444d0fb7 100644
--- a/litellm/llms/base_llm/image_variations/transformation.py
+++ b/litellm/llms/base_llm/image_variations/transformation.py
@@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index 8ce2c4818b..fbe2dc4937 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
index cb12f779cc..67194e83e7 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
@@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py
index 916da73883..73be89fc6e 100644
--- a/litellm/llms/clarifai/chat/transformation.py
+++ b/litellm/llms/clarifai/chat/transformation.py
@@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py
index 1874bb5115..9e59782bf7 100644
--- a/litellm/llms/cloudflare/chat/transformation.py
+++ b/litellm/llms/cloudflare/chat/transformation.py
@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py
index 70677214a7..5dd44aca80 100644
--- a/litellm/llms/cohere/chat/transformation.py
+++ b/litellm/llms/cohere/chat/transformation.py
@@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py
index bdfcda020e..f96ef89d3c 100644
--- a/litellm/llms/cohere/completion/transformation.py
+++ b/litellm/llms/cohere/completion/transformation.py
@@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py
index 72092cf261..13141fc19a 100644
--- a/litellm/llms/custom_httpx/aiohttp_handler.py
+++ b/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
api_base=api_base,
)
@@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=[{"role": "user", "content": "test"}],
optional_params=optional_params,
+ litellm_params=litellm_params,
api_base=api_base,
)
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 23d7fe4b4d..f1aa5627dc 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -192,7 +192,7 @@ class AsyncHTTPHandler:
async def post(
self,
url: str,
- data: Optional[Union[dict, str]] = None, # type: ignore
+ data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -427,7 +427,7 @@ class AsyncHTTPHandler:
self,
url: str,
client: httpx.AsyncClient,
- data: Optional[Union[dict, str]] = None, # type: ignore
+ data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -527,7 +527,7 @@ class HTTPHandler:
def post(
self,
url: str,
- data: Optional[Union[dict, str]] = None,
+ data: Optional[Union[dict, str, bytes]] = None,
json: Optional[Union[dict, str, List]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -573,7 +573,6 @@ class HTTPHandler:
setattr(e, "text", error_text)
setattr(e, "status_code", e.response.status_code)
-
raise e
except Exception as e:
raise e
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index 5778f0228f..b7c72e89ef 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
messages=messages,
optional_params=optional_params,
api_base=api_base,
+ litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@@ -625,6 +626,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@@ -896,6 +898,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params=litellm_params,
)
if client is None or not isinstance(client, HTTPHandler):
@@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler:
model="",
messages=[],
optional_params={},
+ litellm_params=litellm_params,
)
- api_base = provider_config.get_complete_url(
+ api_base = provider_config.get_complete_file_url(
api_base=api_base,
api_key=api_key,
model="",
optional_params={},
litellm_params=litellm_params,
+ data=create_file_data,
)
+ if api_base is None:
+ raise ValueError("api_base is required for create_file")
# Get the transformed request data for both steps
transformed_request = provider_config.transform_create_file_request(
@@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler:
else:
sync_httpx_client = client
- try:
- # Step 1: Initial request to get upload URL
- initial_response = sync_httpx_client.post(
- url=api_base,
- headers={
- **headers,
- **transformed_request["initial_request"]["headers"],
- },
- data=json.dumps(transformed_request["initial_request"]["data"]),
- timeout=timeout,
- )
-
- # Extract upload URL from response headers
- upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
- if not upload_url:
- raise ValueError("Failed to get upload URL from initial request")
-
- # Step 2: Upload the actual file
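+        # Direct upload path: the provider returned a fully-formed request body
+        # (str/bytes), e.g. the Vertex AI GCS media upload, so POST it as-is.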
+ if isinstance(transformed_request, str) or isinstance(
+ transformed_request, bytes
+ ):
upload_response = sync_httpx_client.post(
- url=upload_url,
- headers=transformed_request["upload_request"]["headers"],
- data=transformed_request["upload_request"]["data"],
+ url=api_base,
+ headers=headers,
+ data=transformed_request,
timeout=timeout,
)
+ else:
+ try:
+ # Step 1: Initial request to get upload URL
+ initial_response = sync_httpx_client.post(
+ url=api_base,
+ headers={
+ **headers,
+ **transformed_request["initial_request"]["headers"],
+ },
+ data=json.dumps(transformed_request["initial_request"]["data"]),
+ timeout=timeout,
+ )
- return provider_config.transform_create_file_response(
- model=None,
- raw_response=upload_response,
- logging_obj=logging_obj,
- litellm_params=litellm_params,
- )
+ # Extract upload URL from response headers
+ upload_url = initial_response.headers.get("X-Goog-Upload-URL")
- except Exception as e:
- raise self._handle_error(
- e=e,
- provider_config=provider_config,
- )
+ if not upload_url:
+ raise ValueError("Failed to get upload URL from initial request")
+
+ # Step 2: Upload the actual file
+ upload_response = sync_httpx_client.post(
+ url=upload_url,
+ headers=transformed_request["upload_request"]["headers"],
+ data=transformed_request["upload_request"]["data"],
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(
+ e=e,
+ provider_config=provider_config,
+ )
+
+ return provider_config.transform_create_file_response(
+ model=None,
+ raw_response=upload_response,
+ logging_obj=logging_obj,
+ litellm_params=litellm_params,
+ )
async def async_create_file(
self,
- transformed_request: dict,
+ transformed_request: Union[bytes, str, dict],
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
@@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
- try:
- # Step 1: Initial request to get upload URL
- initial_response = await async_httpx_client.post(
- url=api_base,
- headers={
- **headers,
- **transformed_request["initial_request"]["headers"],
- },
- data=json.dumps(transformed_request["initial_request"]["data"]),
- timeout=timeout,
- )
-
- # Extract upload URL from response headers
- upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
- if not upload_url:
- raise ValueError("Failed to get upload URL from initial request")
-
- # Step 2: Upload the actual file
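+        # Direct upload path: the provider returned a fully-formed request body
+        # (str/bytes), e.g. the Vertex AI GCS media upload, so POST it as-is.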
+ if isinstance(transformed_request, str) or isinstance(
+ transformed_request, bytes
+ ):
upload_response = await async_httpx_client.post(
- url=upload_url,
- headers=transformed_request["upload_request"]["headers"],
- data=transformed_request["upload_request"]["data"],
+ url=api_base,
+ headers=headers,
+ data=transformed_request,
timeout=timeout,
)
+ else:
+ try:
+ # Step 1: Initial request to get upload URL
+ initial_response = await async_httpx_client.post(
+ url=api_base,
+ headers={
+ **headers,
+ **transformed_request["initial_request"]["headers"],
+ },
+ data=json.dumps(transformed_request["initial_request"]["data"]),
+ timeout=timeout,
+ )
- return provider_config.transform_create_file_response(
- model=None,
- raw_response=upload_response,
- logging_obj=logging_obj,
- litellm_params=litellm_params,
- )
+ # Extract upload URL from response headers
+ upload_url = initial_response.headers.get("X-Goog-Upload-URL")
- except Exception as e:
- verbose_logger.exception(f"Error creating file: {e}")
- raise self._handle_error(
- e=e,
- provider_config=provider_config,
- )
+ if not upload_url:
+ raise ValueError("Failed to get upload URL from initial request")
+
+ # Step 2: Upload the actual file
+ upload_response = await async_httpx_client.post(
+ url=upload_url,
+ headers=transformed_request["upload_request"]["headers"],
+ data=transformed_request["upload_request"]["data"],
+ timeout=timeout,
+ )
+ except Exception as e:
+ verbose_logger.exception(f"Error creating file: {e}")
+ raise self._handle_error(
+ e=e,
+ provider_config=provider_config,
+ )
+
+ return provider_config.transform_create_file_response(
+ model=None,
+ raw_response=upload_response,
+ logging_obj=logging_obj,
+ litellm_params=litellm_params,
+ )
def list_files(self):
"""
diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py
index 1940f09608..6f5738fb4b 100644
--- a/litellm/llms/databricks/chat/transformation.py
+++ b/litellm/llms/databricks/chat/transformation.py
@@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py
index b4803576e0..f1b18808f7 100644
--- a/litellm/llms/deepgram/audio_transcription/transformation.py
+++ b/litellm/llms/deepgram/audio_transcription/transformation.py
@@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/fireworks_ai/common_utils.py b/litellm/llms/fireworks_ai/common_utils.py
index 293403b133..17aa67b525 100644
--- a/litellm/llms/fireworks_ai/common_utils.py
+++ b/litellm/llms/fireworks_ai/common_utils.py
@@ -41,6 +41,7 @@ class FireworksAIMixin:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py
index ace24e982f..fef41f7d58 100644
--- a/litellm/llms/gemini/common_utils.py
+++ b/litellm/llms/gemini/common_utils.py
@@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/gemini/files/transformation.py b/litellm/llms/gemini/files/transformation.py
index a1f99c6903..e98e76dabc 100644
--- a/litellm/llms/gemini/files/transformation.py
+++ b/litellm/llms/gemini/files/transformation.py
@@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
-from typing import List, Mapping, Optional
+from typing import List, Optional
import httpx
from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
@@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
if file_data is None:
raise ValueError("File data is required")
- # Parse the file_data based on its type
- filename = None
- file_content = None
- content_type = None
- file_headers: Mapping[str, str] = {}
-
- if isinstance(file_data, tuple):
- if len(file_data) == 2:
- filename, file_content = file_data
- elif len(file_data) == 3:
- filename, file_content, content_type = file_data
- elif len(file_data) == 4:
- filename, file_content, content_type, file_headers = file_data
- else:
- file_content = file_data
-
- # Handle the file content based on its type
- import io
- from os import PathLike
-
- # Convert content to bytes
- if isinstance(file_content, (str, PathLike)):
- # If it's a path, open and read the file
- with open(file_content, "rb") as f:
- content = f.read()
- elif isinstance(file_content, io.IOBase):
- # If it's a file-like object
- content = file_content.read()
- if isinstance(content, str):
- content = content.encode("utf-8")
- elif isinstance(file_content, bytes):
- content = file_content
- else:
- raise ValueError(f"Unsupported file content type: {type(file_content)}")
+ # Use the common utility function to extract file data
+ extracted_data = extract_file_data(file_data)
# Get file size
- file_size = len(content)
-
- # Use provided content type or guess based on filename
- if not content_type:
- import mimetypes
-
- content_type = (
- mimetypes.guess_type(filename)[0]
- if filename
- else "application/octet-stream"
- )
+ file_size = len(extracted_data["content"])
# Step 1: Initial resumable upload request
headers = {
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(file_size),
- "X-Goog-Upload-Header-Content-Type": content_type,
+ "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
"Content-Type": "application/json",
}
- headers.update(file_headers) # Add any custom headers
+ headers.update(extracted_data["headers"]) # Add any custom headers
# Initial metadata request body
- initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
+ initial_data = {
+ "file": {
+ "display_name": extracted_data["filename"] or str(int(time.time()))
+ }
+ }
# Step 2: Actual file upload data
upload_headers = {
@@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
return {
"initial_request": {"headers": headers, "data": initial_data},
- "upload_request": {"headers": upload_headers, "data": content},
+ "upload_request": {
+ "headers": upload_headers,
+ "data": extracted_data["content"],
+ },
}
def transform_create_file_response(
diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py
index c84f03ab93..0ad93be763 100644
--- a/litellm/llms/huggingface/chat/transformation.py
+++ b/litellm/llms/huggingface/chat/transformation.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
@@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
-
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
@@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
headers: dict,
model: str,
messages: List[AllMessageValues],
- optional_params: dict,
+ optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
@@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
- return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
+ return HuggingFaceError(
+ status_code=status_code, message=error_message, headers=headers
+ )
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
@@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
if api_base is not None:
complete_url = api_base
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
- complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
+ complete_url = str(os.getenv("HF_API_BASE")) or str(
+ os.getenv("HUGGINGFACE_API_BASE")
+ )
elif model.startswith(("http://", "https://")):
complete_url = model
# 4. Default construction with provider
@@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
- return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
+ return dict(
+ ChatCompletionRequest(
+ model=mapped_model, messages=messages, **optional_params
+ )
+ )
diff --git a/litellm/llms/huggingface/embedding/handler.py b/litellm/llms/huggingface/embedding/handler.py
index 7277fbd0e3..bfd73c1346 100644
--- a/litellm/llms/huggingface/embedding/handler.py
+++ b/litellm/llms/huggingface/embedding/handler.py
@@ -1,15 +1,6 @@
import json
import os
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Literal,
- Optional,
- Union,
- get_args,
-)
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
import httpx
@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
]
-
-def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+def get_hf_task_embedding_for_model(
+ model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
return pipeline_tag
-async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+async def async_get_hf_task_embedding_for_model(
+ model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
input: List,
optional_params: dict,
) -> dict:
- hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ hf_task = await async_get_hf_task_embedding_for_model(
+ model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
- hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ hf_task = get_hf_task_embedding_for_model(
+ model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
+ litellm_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
model=model,
optional_params=optional_params,
messages=[],
+ litellm_params=litellm_params,
)
task_type = optional_params.pop("input_type", None)
- task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ task = get_hf_task_embedding_for_model(
+ model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
- embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+ embed_url = (
+ f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+ )
## ROUTING ##
if aembedding is True:
diff --git a/litellm/llms/huggingface/embedding/transformation.py b/litellm/llms/huggingface/embedding/transformation.py
index f803157768..60bd5dcd61 100644
--- a/litellm/llms/huggingface/embedding/transformation.py
+++ b/litellm/llms/huggingface/embedding/transformation.py
@@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
diff --git a/litellm/llms/nlp_cloud/chat/handler.py b/litellm/llms/nlp_cloud/chat/handler.py
index b0abdda587..b0563d8b55 100644
--- a/litellm/llms/nlp_cloud/chat/handler.py
+++ b/litellm/llms/nlp_cloud/chat/handler.py
@@ -36,6 +36,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
## Load Config
diff --git a/litellm/llms/nlp_cloud/chat/transformation.py b/litellm/llms/nlp_cloud/chat/transformation.py
index b7967249ab..8037a45832 100644
--- a/litellm/llms/nlp_cloud/chat/transformation.py
+++ b/litellm/llms/nlp_cloud/chat/transformation.py
@@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index 64544bd269..789b728337 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/oobabooga/chat/oobabooga.py b/litellm/llms/oobabooga/chat/oobabooga.py
index 8829d2233e..5eb68a03d4 100644
--- a/litellm/llms/oobabooga/chat/oobabooga.py
+++ b/litellm/llms/oobabooga/chat/oobabooga.py
@@ -32,6 +32,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
if "https" in model:
completion_url = model
@@ -123,6 +124,7 @@ def embedding(
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params={},
)
response = litellm.module_level_client.post(
embeddings_url, headers=headers, json=data
diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py
index 6fd56f934e..e87b70130c 100644
--- a/litellm/llms/oobabooga/chat/transformation.py
+++ b/litellm/llms/oobabooga/chat/transformation.py
@@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py
index fcab43901a..434214639e 100644
--- a/litellm/llms/openai/chat/gpt_transformation.py
+++ b/litellm/llms/openai/chat/gpt_transformation.py
@@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index 3b6be1a034..13412ef96a 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py
index 2d3d611dac..c0ccc71579 100644
--- a/litellm/llms/openai/transcriptions/whisper_transformation.py
+++ b/litellm/llms/openai/transcriptions/whisper_transformation.py
@@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py
index a9e37d27fc..24910cba8f 100644
--- a/litellm/llms/petals/completion/transformation.py
+++ b/litellm/llms/petals/completion/transformation.py
@@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py
index cd80fa53e4..79936764ac 100644
--- a/litellm/llms/predibase/chat/handler.py
+++ b/litellm/llms/predibase/chat/handler.py
@@ -228,10 +228,10 @@ class PredibaseChatCompletion:
api_key: str,
logging_obj,
optional_params: dict,
+ litellm_params: dict,
tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None,
- litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -241,6 +241,7 @@ class PredibaseChatCompletion:
messages=messages,
optional_params=optional_params,
model=model,
+ litellm_params=litellm_params,
)
completion_url = ""
input_text = ""
diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py
index 8ef0eea173..9fbb9d6c9e 100644
--- a/litellm/llms/predibase/chat/transformation.py
+++ b/litellm/llms/predibase/chat/transformation.py
@@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py
index d954416381..e4bb64fed7 100644
--- a/litellm/llms/replicate/chat/handler.py
+++ b/litellm/llms/replicate/chat/handler.py
@@ -141,6 +141,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
# Start a prediction and get the prediction URL
version_id = replicate_config.model_to_version_id(model)
diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py
index 604e6eefe6..4c61086801 100644
--- a/litellm/llms/replicate/chat/transformation.py
+++ b/litellm/llms/replicate/chat/transformation.py
@@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py
index 296689c31c..ebd96ac5b1 100644
--- a/litellm/llms/sagemaker/completion/handler.py
+++ b/litellm/llms/sagemaker/completion/handler.py
@@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
model: str,
data: dict,
messages: List[AllMessageValues],
+ litellm_params: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
@@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
@@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": _data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py
index df3d028c99..bfc0b6e5f6 100644
--- a/litellm/llms/sagemaker/completion/transformation.py
+++ b/litellm/llms/sagemaker/completion/transformation.py
@@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py
index 574c4704cd..2b92911b05 100644
--- a/litellm/llms/snowflake/chat/transformation.py
+++ b/litellm/llms/snowflake/chat/transformation.py
@@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py
index 4d14f1ad24..afbd89b9bc 100644
--- a/litellm/llms/topaz/image_variations/transformation.py
+++ b/litellm/llms/topaz/image_variations/transformation.py
@@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py
index 49126917f2..21fcf2eefb 100644
--- a/litellm/llms/triton/completion/transformation.py
+++ b/litellm/llms/triton/completion/transformation.py
@@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
diff --git a/litellm/llms/triton/embedding/transformation.py b/litellm/llms/triton/embedding/transformation.py
index 4744ec0834..8ab0277e36 100644
--- a/litellm/llms/triton/embedding/transformation.py
+++ b/litellm/llms/triton/embedding/transformation.py
@@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py
index 87c1cb8320..a666a2c37f 100644
--- a/litellm/llms/vertex_ai/files/handler.py
+++ b/litellm/llms/vertex_ai/files/handler.py
@@ -1,3 +1,4 @@
+import asyncio
from typing import Any, Coroutine, Optional, Union
import httpx
@@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
-from .transformation import VertexAIFilesTransformation
+from .transformation import VertexAIJsonlFilesTransformation
-vertex_ai_files_transformation = VertexAIFilesTransformation()
+vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
@@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
timeout=timeout,
max_retries=max_retries,
)
-
- return None # type: ignore
+ else:
+ return asyncio.run(
+ self.async_create_file(
+ create_file_data=create_file_data,
+ api_base=api_base,
+ vertex_credentials=vertex_credentials,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ timeout=timeout,
+ max_retries=max_retries,
+ )
+ )
diff --git a/litellm/llms/vertex_ai/files/transformation.py b/litellm/llms/vertex_ai/files/transformation.py
index 89c6ff9deb..c795367e48 100644
--- a/litellm/llms/vertex_ai/files/transformation.py
+++ b/litellm/llms/vertex_ai/files/transformation.py
@@ -1,7 +1,17 @@
import json
+import os
+import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
+from httpx import Headers, Response
+
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.files.transformation import (
+ BaseFilesConfig,
+ LiteLLMLoggingObj,
+)
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
@@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
+ AllMessageValues,
CreateFileRequest,
FileTypes,
+ OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
+from litellm.types.llms.vertex_ai import GcsBucketResponse
+from litellm.types.utils import ExtractedFileData, LlmProviders
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
-class VertexAIFilesTransformation(VertexGeminiConfig):
+class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
+ """
+ Config for VertexAI Files
+ """
+
+ def __init__(self):
+ self.jsonl_transformation = VertexAIJsonlFilesTransformation()
+ super().__init__()
+
+ @property
+ def custom_llm_provider(self) -> LlmProviders:
+ return LlmProviders.VERTEX_AI
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+ if not api_key:
+ api_key, _ = self.get_access_token(
+ credentials=litellm_params.get("vertex_credentials"),
+ project_id=litellm_params.get("vertex_project"),
+ )
+ if not api_key:
+ raise ValueError("api_key is required")
+ headers["Authorization"] = f"Bearer {api_key}"
+ return headers
+
+ def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
+ """
+ Helper to extract content from various OpenAI file types and return as string.
+
+ Handles:
+ - Direct content (str, bytes, IO[bytes])
+ - Tuple formats: (filename, content, [content_type], [headers])
+ - PathLike objects
+ """
+ content: Union[str, bytes] = b""
+ # Extract file content from tuple if necessary
+ if isinstance(openai_file_content, tuple):
+ # Take the second element which is always the file content
+ file_content = openai_file_content[1]
+ else:
+ file_content = openai_file_content
+
+ # Handle different file content types
+ if isinstance(file_content, str):
+ # String content can be used directly
+ content = file_content
+ elif isinstance(file_content, bytes):
+ # Bytes content can be decoded
+ content = file_content
+ elif isinstance(file_content, PathLike): # PathLike
+ with open(str(file_content), "rb") as f:
+ content = f.read()
+ elif hasattr(file_content, "read"): # IO[bytes]
+ # File-like objects need to be read
+ content = file_content.read()
+
+ # Ensure content is string
+ if isinstance(content, bytes):
+ content = content.decode("utf-8")
+
+ return content
+
+ def _get_gcs_object_name_from_batch_jsonl(
+ self,
+ openai_jsonl_content: List[Dict[str, Any]],
+ ) -> str:
+ """
+ Gets a unique GCS object name for the VertexAI batch prediction job
+
+ named as: litellm-vertex-{model}-{uuid}
+ """
+ _model = openai_jsonl_content[0].get("body", {}).get("model", "")
+ if "publishers/google/models" not in _model:
+ _model = f"publishers/google/models/{_model}"
+ object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
+ return object_name
+
+ def get_object_name(
+ self, extracted_file_data: ExtractedFileData, purpose: str
+ ) -> str:
+ """
+ Get the object name for the request
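+
+        For `purpose == "batch"` with .jsonl content, the name is derived from the
+        model in the first request line, e.g.
+        `litellm-vertex-files/publishers/google/models/<model>/<uuid>`.
+        Otherwise the original filename is used, falling back to a unix timestamp.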
+ """
+ extracted_file_data_content = extracted_file_data.get("content")
+
+ if extracted_file_data_content is None:
+ raise ValueError("file content is required")
+
+ if purpose == "batch":
+ ## 1. If jsonl, check if there's a model name
+ file_content = self._get_content_from_openai_file(
+ extracted_file_data_content
+ )
+
+ # Split into lines and parse each line as JSON
+ openai_jsonl_content = [
+ json.loads(line) for line in file_content.splitlines() if line.strip()
+ ]
+ if len(openai_jsonl_content) > 0:
+ return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
+
+ ## 2. If not jsonl, return the filename
+ filename = extracted_file_data.get("filename")
+ if filename:
+ return filename
+ ## 3. If no file name, return timestamp
+ return str(int(time.time()))
+
+ def get_complete_file_url(
+ self,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ model: str,
+ optional_params: Dict,
+ litellm_params: Dict,
+ data: CreateFileRequest,
+ ) -> str:
+ """
+ Get the complete url for the request
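+
+        Illustrative result (bucket name is an example): with bucket_name="my-bucket"
+        this returns
+        `https://storage.googleapis.com/upload/storage/v1/b/my-bucket/o?uploadType=media&name=<object_name>`
+        where `<object_name>` comes from get_object_name().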
+ """
+ bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
+ if not bucket_name:
+ raise ValueError("GCS bucket_name is required")
+ file_data = data.get("file")
+ purpose = data.get("purpose")
+ if file_data is None:
+ raise ValueError("file is required")
+ if purpose is None:
+ raise ValueError("purpose is required")
+ extracted_file_data = extract_file_data(file_data)
+ object_name = self.get_object_name(extracted_file_data, purpose)
+ endpoint = (
+ f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
+ )
+ api_base = api_base or "https://storage.googleapis.com"
+ if not api_base:
+ raise ValueError("api_base is required")
+
+ return f"{api_base}/{endpoint}"
+
+ def get_supported_openai_params(
+ self, model: str
+ ) -> List[OpenAICreateFileRequestOptionalParams]:
+ return []
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ return optional_params
+
+ def _map_openai_to_vertex_params(
+ self,
+ openai_request_body: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """
+ wrapper to call VertexGeminiConfig.map_openai_params
+ """
+ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexGeminiConfig,
+ )
+
+ config = VertexGeminiConfig()
+ _model = openai_request_body.get("model", "")
+ vertex_params = config.map_openai_params(
+ model=_model,
+ non_default_params=openai_request_body,
+ optional_params={},
+ drop_params=False,
+ )
+ return vertex_params
+
+ def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+ self, openai_jsonl_content: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ Transforms OpenAI JSONL content to VertexAI JSONL content
+
+        jsonl body for vertex is {"request": <request body>}
+ Example Vertex jsonl
+ {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
+ {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
+ """
+
+ vertex_jsonl_content = []
+ for _openai_jsonl_content in openai_jsonl_content:
+ openai_request_body = _openai_jsonl_content.get("body") or {}
+ vertex_request_body = _transform_request_body(
+ messages=openai_request_body.get("messages", []),
+ model=openai_request_body.get("model", ""),
+ optional_params=self._map_openai_to_vertex_params(openai_request_body),
+ custom_llm_provider="vertex_ai",
+ litellm_params={},
+ cached_content=None,
+ )
+ vertex_jsonl_content.append({"request": vertex_request_body})
+ return vertex_jsonl_content
+
+ def transform_create_file_request(
+ self,
+ model: str,
+ create_file_data: CreateFileRequest,
+ optional_params: dict,
+ litellm_params: dict,
+ ) -> Union[bytes, str, dict]:
+ """
+ 2 Cases:
+ 1. Handle basic file upload
+ 2. Handle batch file upload (.jsonl)
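+
+        Returns the raw bytes unchanged for a basic upload, or (for a batch .jsonl
+        file) a JSON string of the Vertex-formatted requests, i.e.
+        `[{"request": {...}}, ...]`.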
+ """
+ file_data = create_file_data.get("file")
+ if file_data is None:
+ raise ValueError("file is required")
+ extracted_file_data = extract_file_data(file_data)
+ extracted_file_data_content = extracted_file_data.get("content")
+ if (
+ create_file_data.get("purpose") == "batch"
+ and extracted_file_data.get("content_type") == "application/jsonl"
+ and extracted_file_data_content is not None
+ ):
+ ## 1. If jsonl, check if there's a model name
+ file_content = self._get_content_from_openai_file(
+ extracted_file_data_content
+ )
+
+ # Split into lines and parse each line as JSON
+ openai_jsonl_content = [
+ json.loads(line) for line in file_content.splitlines() if line.strip()
+ ]
+ vertex_jsonl_content = (
+ self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+ openai_jsonl_content
+ )
+ )
+ return json.dumps(vertex_jsonl_content)
+ elif isinstance(extracted_file_data_content, bytes):
+ return extracted_file_data_content
+ else:
+ raise ValueError("Unsupported file content type")
+
+ def transform_create_file_response(
+ self,
+ model: Optional[str],
+ raw_response: Response,
+ logging_obj: LiteLLMLoggingObj,
+ litellm_params: dict,
+ ) -> OpenAIFileObject:
+ """
+ Transform VertexAI File upload response into OpenAI-style FileObject
+ """
+ response_json = raw_response.json()
+
+ try:
+ response_object = GcsBucketResponse(**response_json) # type: ignore
+ except Exception as e:
+ raise VertexAIError(
+ status_code=raw_response.status_code,
+ message=f"Error reading GCS bucket response: {e}",
+ headers=raw_response.headers,
+ )
+
+ gcs_id = response_object.get("id", "")
+ # Remove the last numeric ID from the path
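+ # GCS object ids have the form "bucket/object/generation"; e.g.
+ # "my-bucket/my-file.jsonl/1700000000000000" becomes "my-bucket/my-file.jsonl"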
+ gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
+
+ return OpenAIFileObject(
+ purpose=response_object.get("purpose", "batch"),
+ id=f"gs://{gcs_id}",
+ filename=response_object.get("name", ""),
+ created_at=_convert_vertex_datetime_to_openai_datetime(
+ vertex_datetime=response_object.get("timeCreated", "")
+ ),
+ status="uploaded",
+ bytes=int(response_object.get("size", 0)),
+ object="file",
+ )
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+ ) -> BaseLLMException:
+ return VertexAIError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+
+class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
"""
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
"""
diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
index 749b6d9428..b3b7857ea1 100644
--- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
gtool_func_declarations = []
googleSearch: Optional[dict] = None
googleSearchRetrieval: Optional[dict] = None
+ enterpriseWebSearch: Optional[dict] = None
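+ # enterpriseWebSearch: Vertex AI "Enterprise Web Search" grounding tool; passed through unchanged when a request includes tools=[{"enterpriseWebSearch": {}}]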
code_execution: Optional[dict] = None
# remove 'additionalProperties' from tools
value = _remove_additional_properties(value)
@@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
googleSearch = tool["googleSearch"]
elif tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"]
+ elif tool.get("enterpriseWebSearch", None) is not None:
+ enterpriseWebSearch = tool["enterpriseWebSearch"]
elif tool.get("code_execution", None) is not None:
code_execution = tool["code_execution"]
elif openai_function_object is not None:
@@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
_tools["googleSearch"] = googleSearch
if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval
+ if enterpriseWebSearch is not None:
+ _tools["enterpriseWebSearch"] = enterpriseWebSearch
if code_execution is not None:
_tools["code_execution"] = code_execution
return [_tools]
@@ -900,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
@@ -1017,7 +1023,7 @@ class VertexLLM(VertexBase):
logging_obj,
stream,
optional_params: dict,
- litellm_params=None,
+ litellm_params: dict,
logger_fn=None,
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
@@ -1058,6 +1064,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
## LOGGING
@@ -1144,6 +1151,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
request_body = await async_transform_request_body(**data) # type: ignore
@@ -1317,6 +1325,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
## TRANSFORMATION ##
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
index 88d7339449..8aebd83cc4 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
optional_params=optional_params,
api_key=auth_header,
api_base=api_base,
+ litellm_params=litellm_params,
)
## LOGGING
diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
index afa58c7e5c..5bf02ad765 100644
--- a/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai/multimodal_embeddings/transformation.py
@@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py
index 994e46b50b..8f3037c791 100644
--- a/litellm/llms/vertex_ai/vertex_llm_base.py
+++ b/litellm/llms/vertex_ai/vertex_llm_base.py
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
-from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
@@ -22,7 +21,7 @@ else:
GoogleCredentialsObject = Any
-class VertexBase(BaseLLM):
+class VertexBase:
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None
diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py
index df6ef91a41..91811e0392 100644
--- a/litellm/llms/voyage/embedding/transformation.py
+++ b/litellm/llms/voyage/embedding/transformation.py
@@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
diff --git a/litellm/llms/watsonx/chat/handler.py b/litellm/llms/watsonx/chat/handler.py
index aeb0167595..45378c5529 100644
--- a/litellm/llms/watsonx/chat/handler.py
+++ b/litellm/llms/watsonx/chat/handler.py
@@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
messages=messages,
optional_params=optional_params,
api_key=api_key,
+ litellm_params=litellm_params,
)
## UPDATE PAYLOAD (optional params)
diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index 4916cd1c75..e13e015add 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -165,6 +165,7 @@ class IBMWatsonXMixin:
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
diff --git a/litellm/main.py b/litellm/main.py
index cd7d255e21..3f1d9a1e76 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
+ litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "bedrock":
if isinstance(input, str):
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 79aa57f466..7e5be4dc6b 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -2409,25 +2409,26 @@
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
- "input_cost_per_token": 0,
- "output_cost_per_token": 0,
+ "input_cost_per_token": 0.000000075,
+ "output_cost_per_token": 0.0000003,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_function_calling": true,
- "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+ "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4-multimodal-instruct": {
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
- "input_cost_per_token": 0,
- "output_cost_per_token": 0,
+ "input_cost_per_token": 0.00000008,
+ "input_cost_per_audio_token": 0.000004,
+ "output_cost_per_token": 0.00032,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_audio_input": true,
"supports_function_calling": true,
"supports_vision": true,
- "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+ "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4": {
"max_tokens": 16384,
@@ -3467,7 +3468,7 @@
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000004,
"cache_creation_input_token_cost": 0.000001,
- "cache_read_input_token_cost": 0.0000008,
+ "cache_read_input_token_cost": 0.00000008,
"litellm_provider": "anthropic",
"mode": "chat",
"supports_function_calling": true,
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index ae4bdc7b8c..b64ff6c827 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1625,6 +1625,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
model_max_budget: Optional[Dict] = {}
model_spend: Optional[Dict] = {}
user_email: Optional[str] = None
+ user_alias: Optional[str] = None
models: list = []
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index a6b1b3e614..563d0cb543 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -4,16 +4,26 @@ import json
import uuid
from base64 import b64encode
from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi import (
+ APIRouter,
+ Depends,
+ HTTPException,
+ Request,
+ Response,
+ UploadFile,
+ status,
+)
from fastapi.responses import StreamingResponse
+from starlette.datastructures import UploadFile as StarletteUploadFile
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
ConfigFieldInfo,
@@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
)
return response
+ @staticmethod
+ async def non_streaming_http_request_handler(
+ request: Request,
+ async_client: httpx.AsyncClient,
+ url: httpx.URL,
+ headers: dict,
+ requested_query_params: Optional[dict] = None,
+ _parsed_body: Optional[dict] = None,
+ ) -> httpx.Response:
+ """
+ Handle non-streaming HTTP requests
+
+ Handles GET requests, multipart/form-data requests, and all other requests as separate cases
+ """
+ if request.method == "GET":
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ )
+ elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
+ return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+ request=request,
+ async_client=async_client,
+ url=url,
+ headers=headers,
+ requested_query_params=requested_query_params,
+ )
+ else:
+ # Generic httpx method
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ json=_parsed_body,
+ )
+ return response
+
+ @staticmethod
+ def is_multipart(request: Request) -> bool:
+ """Check if the request is a multipart/form-data request"""
+ return "multipart/form-data" in request.headers.get("content-type", "")
+
+ @staticmethod
+ async def _build_request_files_from_upload_file(
+ upload_file: Union[UploadFile, StarletteUploadFile],
+ ) -> Tuple[Optional[str], bytes, Optional[str]]:
+ """Build a request files dict from an UploadFile object"""
+ file_content = await upload_file.read()
+ return (upload_file.filename, file_content, upload_file.content_type)
+
+ @staticmethod
+ async def make_multipart_http_request(
+ request: Request,
+ async_client: httpx.AsyncClient,
+ url: httpx.URL,
+ headers: dict,
+ requested_query_params: Optional[dict] = None,
+ ) -> httpx.Response:
+ """Process multipart/form-data requests, handling both files and form fields"""
+ form_data = await request.form()
+ files = {}
+ form_data_dict = {}
+
+ for field_name, field_value in form_data.items():
+ if isinstance(field_value, (StarletteUploadFile, UploadFile)):
+ files[field_name] = (
+ await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+ upload_file=field_value
+ )
+ )
+ else:
+ form_data_dict[field_name] = field_value
+
+ response = await async_client.request(
+ method=request.method,
+ url=url,
+ headers=headers,
+ params=requested_query_params,
+ files=files,
+ data=form_data_dict,
+ )
+ return response
+
async def pass_through_request( # noqa: PLR0915
request: Request,
@@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
- messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+ messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
@@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
# combine url with query params for logging
-
requested_query_params: Optional[dict] = (
query_params or request.query_params.__dict__
)
@@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
- input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+ input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
@@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
- if request.method == "GET":
- response = await async_client.request(
- method=request.method,
+ response = (
+ await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
+ request=request,
+ async_client=async_client,
url=url,
headers=headers,
- params=requested_query_params,
+ requested_query_params=requested_query_params,
+ _parsed_body=_parsed_body,
)
- else:
- response = await async_client.request(
- method=request.method,
- url=url,
- headers=headers,
- params=requested_query_params,
- json=_parsed_body,
- )
-
+ )
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True:
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index 7fa167938f..55273371fc 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -187,6 +187,7 @@ class Tools(TypedDict, total=False):
function_declarations: List[FunctionDeclaration]
googleSearch: dict
googleSearchRetrieval: dict
+ enterpriseWebSearch: dict
code_execution: dict
retrieval: Retrieval
@@ -497,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
gcsDestination: GcsDestination
+class GcsBucketResponse(TypedDict):
+ """
+ TypedDict for GCS bucket upload response
+
+ Attributes:
+ kind: The kind of item this is. For objects, this is always storage#object
+ id: The ID of the object
+ selfLink: The link to this object
+ mediaLink: The link to download the object
+ name: The name of the object
+ bucket: The name of the bucket containing this object
+ generation: The content generation of this object
+ metageneration: The metadata generation of this object
+ contentType: The content type of the object
+ storageClass: The storage class of the object
+ size: The size of the object in bytes
+ md5Hash: The MD5 hash of the object
+ crc32c: The CRC32c checksum of the object
+ etag: The ETag of the object
+ timeCreated: The creation time of the object
+ updated: The last update time of the object
+ timeStorageClassUpdated: The time the storage class was last updated
+ timeFinalized: The time the object was finalized
+ """
+
+ kind: Literal["storage#object"]
+ id: str
+ selfLink: str
+ mediaLink: str
+ name: str
+ bucket: str
+ generation: str
+ metageneration: str
+ contentType: str
+ storageClass: str
+ size: str
+ md5Hash: str
+ crc32c: str
+ etag: str
+ timeCreated: str
+ updated: str
+ timeStorageClassUpdated: str
+ timeFinalized: str
+
+
class VertexAIBatchPredictionJob(TypedDict):
displayName: str
model: str
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 8439037758..6f0c26d301 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -2,7 +2,7 @@ import json
import time
import uuid
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
from aiohttp import FormData
from openai._models import BaseModel as OpenAIObject
@@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
if not values.get("credential_values") and not values.get("model_id"):
raise ValueError("Either credential_values or model_id must be set")
return values
+
+
+class ExtractedFileData(TypedDict):
+ """
+ TypedDict for storing processed file data
+
+ Attributes:
+ filename: Name of the file if provided
+ content: The file content in bytes
+ content_type: MIME type of the file
+ headers: Any additional headers for the file
+ """
+
+ filename: Optional[str]
+ content: bytes
+ content_type: Optional[str]
+ headers: Mapping[str, str]
diff --git a/litellm/utils.py b/litellm/utils.py
index f807990f60..f809d8a77b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6517,6 +6517,10 @@ class ProviderConfigManager:
)
return GoogleAIStudioFilesHandler()
+ elif LlmProviders.VERTEX_AI == provider:
+ from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
+
+ return VertexAIFilesConfig()
return None
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 79aa57f466..7e5be4dc6b 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -2409,25 +2409,26 @@
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
- "input_cost_per_token": 0,
- "output_cost_per_token": 0,
+ "input_cost_per_token": 0.000000075,
+ "output_cost_per_token": 0.0000003,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_function_calling": true,
- "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+ "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4-multimodal-instruct": {
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
- "input_cost_per_token": 0,
- "output_cost_per_token": 0,
+ "input_cost_per_token": 0.00000008,
+ "input_cost_per_audio_token": 0.000004,
+ "output_cost_per_token": 0.00032,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_audio_input": true,
"supports_function_calling": true,
"supports_vision": true,
- "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+ "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4": {
"max_tokens": 16384,
@@ -3467,7 +3468,7 @@
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000004,
"cache_creation_input_token_cost": 0.000001,
- "cache_read_input_token_cost": 0.0000008,
+ "cache_read_input_token_cost": 0.00000008,
"litellm_provider": "anthropic",
"mode": "chat",
"supports_function_calling": true,
diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py
index 4669a2def6..b2826419e8 100644
--- a/tests/batches_tests/test_openai_batches_and_files.py
+++ b/tests/batches_tests/test_openai_batches_and_files.py
@@ -423,25 +423,35 @@ mock_vertex_batch_response = {
@pytest.mark.asyncio
-async def test_avertex_batch_prediction():
- with patch(
+async def test_avertex_batch_prediction(monkeypatch):
+ monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
+ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+ client = AsyncHTTPHandler()
+
+ async def mock_side_effect(*args, **kwargs):
+ print("args", args, "kwargs", kwargs)
+ url = kwargs.get("url", "")
+ if "files" in url:
+ mock_response.json.return_value = mock_file_response
+ elif "batch" in url:
+ mock_response.json.return_value = mock_vertex_batch_response
+ mock_response.status_code = 200
+ return mock_response
+
+ with patch.object(
+ client, "post", side_effect=mock_side_effect
+ ) as mock_post, patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
- ) as mock_post:
+ ) as mock_global_post:
# Configure mock responses
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Set up different responses for different API calls
- async def mock_side_effect(*args, **kwargs):
- url = kwargs.get("url", "")
- if "files" in url:
- mock_response.json.return_value = mock_file_response
- elif "batch" in url:
- mock_response.json.return_value = mock_vertex_batch_response
- mock_response.status_code = 200
- return mock_response
-
+
mock_post.side_effect = mock_side_effect
+ mock_global_post.side_effect = mock_side_effect
# load_vertex_ai_credentials()
litellm.set_verbose = True
@@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="vertex_ai",
+ client=client
)
print("Response from creating file=", file_obj)
diff --git a/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py
new file mode 100644
index 0000000000..43d4dd9cd8
--- /dev/null
+++ b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py
@@ -0,0 +1,116 @@
+import json
+import os
+import sys
+from io import BytesIO
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from fastapi import Request, UploadFile
+from fastapi.testclient import TestClient
+from starlette.datastructures import Headers
+from starlette.datastructures import UploadFile as StarletteUploadFile
+
+sys.path.insert(
+ 0, os.path.abspath("../../..")
+) # Adds the parent directory to the system path
+
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+ HttpPassThroughEndpointHelpers,
+)
+
+
+# Test is_multipart
+def test_is_multipart():
+ # Test with multipart content type
+ request = MagicMock(spec=Request)
+ request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
+ assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
+
+ # Test with non-multipart content type
+ request.headers = Headers({"content-type": "application/json"})
+ assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+ # Test with no content type
+ request.headers = Headers({})
+ assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+
+# Test _build_request_files_from_upload_file
+@pytest.mark.asyncio
+async def test_build_request_files_from_upload_file():
+ # Test with FastAPI UploadFile
+ file_content = b"test content"
+ file = BytesIO(file_content)
+ # Create an UploadFile with content-type headers
+ headers = {"content-type": "text/plain"}
+ upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+ upload_file.read = AsyncMock(return_value=file_content)
+
+ result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+ upload_file
+ )
+ assert result == ("test.txt", file_content, "text/plain")
+
+ # Test with Starlette UploadFile
+ file2 = BytesIO(file_content)
+ starlette_file = StarletteUploadFile(
+ file=file2,
+ filename="test2.txt",
+ headers=Headers({"content-type": "text/plain"}),
+ )
+ starlette_file.read = AsyncMock(return_value=file_content)
+
+ result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+ starlette_file
+ )
+ assert result == ("test2.txt", file_content, "text/plain")
+
+
+# Test make_multipart_http_request
+@pytest.mark.asyncio
+async def test_make_multipart_http_request():
+ # Mock request with file and form field
+ request = MagicMock(spec=Request)
+ request.method = "POST"
+
+ # Mock form data
+ file_content = b"test file content"
+ file = BytesIO(file_content)
+ # Create an UploadFile with content-type headers
+ headers = {"content-type": "text/plain"}
+ upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+ upload_file.read = AsyncMock(return_value=file_content)
+
+ form_data = {"file": upload_file, "text_field": "test value"}
+ request.form = AsyncMock(return_value=form_data)
+
+ # Mock httpx client
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.headers = {}
+
+ async_client = MagicMock()
+ async_client.request = AsyncMock(return_value=mock_response)
+
+ # Test the function
+ response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+ request=request,
+ async_client=async_client,
+ url=httpx.URL("http://test.com"),
+ headers={},
+ requested_query_params=None,
+ )
+
+ # Verify the response
+ assert response == mock_response
+
+ # Verify the client call
+ async_client.request.assert_called_once()
+ call_args = async_client.request.call_args[1]
+
+ assert call_args["method"] == "POST"
+ assert str(call_args["url"]) == "http://test.com"
+ assert isinstance(call_args["files"], dict)
+ assert isinstance(call_args["data"], dict)
+ assert call_args["data"]["text_field"] == "test value"
diff --git a/tests/llm_translation/test_huggingface_chat_completion.py b/tests/llm_translation/test_huggingface_chat_completion.py
index 9f1e89aeb1..7d498b96df 100644
--- a/tests/llm_translation/test_huggingface_chat_completion.py
+++ b/tests/llm_translation/test_huggingface_chat_completion.py
@@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}],
optional_params={},
- api_key="test_api_key"
+ api_key="test_api_key",
+ litellm_params={}
)
assert headers["Authorization"] == "Bearer test_api_key"
diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py
index d821fb415e..9118d94a6f 100644
--- a/tests/llm_translation/test_vertex.py
+++ b/tests/llm_translation/test_vertex.py
@@ -141,6 +141,7 @@ def test_build_vertex_schema():
[
([{"googleSearch": {}}], "googleSearch"),
([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"),
+ ([{"enterpriseWebSearch": {}}], "enterpriseWebSearch"),
([{"code_execution": {}}], "code_execution"),
],
)
diff --git a/tests/local_testing/example.jsonl b/tests/local_testing/example.jsonl
new file mode 100644
index 0000000000..fc3ca40808
--- /dev/null
+++ b/tests/local_testing/example.jsonl
@@ -0,0 +1,2 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py
index 1a8deed8a8..b64475c227 100644
--- a/tests/local_testing/test_gcs_bucket.py
+++ b/tests/local_testing/test_gcs_bucket.py
@@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
StandardLoggingPayload,
)
from litellm.types.utils import StandardCallbackDynamicParams
-
+from unittest.mock import patch
verbose_logger.setLevel(logging.DEBUG)
@@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
# clean up
if old_bucket_name is not None:
os.environ["GCS_BUCKET_NAME"] = old_bucket_name
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e():
+ """
+ End-to-end check that 'create_file' succeeds against the Vertex AI provider
+ """
+ load_vertex_ai_credentials()
+ test_file_content = b"test audio content"
+ test_file = ("test.wav", test_file_content, "audio/wav")
+
+ from litellm import create_file
+ response = create_file(
+ file=test_file,
+ purpose="user_data",
+ custom_llm_provider="vertex_ai",
+ )
+ print("response", response)
+ assert response is not None
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e_jsonl():
+ """
+ Asserts 'create_file' is called with the correct arguments
+ """
+ load_vertex_ai_credentials()
+ from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+ client = HTTPHandler()
+
+ example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]
+
+ # Create and write to the file
+ file_path = "example.jsonl"
+ with open(file_path, "w") as f:
+ for item in example_jsonl:
+ f.write(json.dumps(item) + "\n")
+
+ # Verify file content
+ with open(file_path, "r") as f:
+ content = f.read()
+ print("File content:", content)
+ assert len(content) > 0, "File is empty"
+
+ from litellm import create_file
+ with patch.object(client, "post") as mock_create_file:
+ try:
+ response = create_file(
+ file=open(file_path, "rb"),
+ purpose="user_data",
+ custom_llm_provider="vertex_ai",
+ client=client,
+ )
+ except Exception as e:
+ print("error", e)
+
+ mock_create_file.assert_called_once()
+
+ print(f"kwargs: {mock_create_file.call_args.kwargs}")
+
+ assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0
\ No newline at end of file
diff --git a/tests/pass_through_tests/test_openai_assistants_passthrough.py b/tests/pass_through_tests/test_openai_assistants_passthrough.py
index 694d3c090e..40361ab39f 100644
--- a/tests/pass_through_tests/test_openai_assistants_passthrough.py
+++ b/tests/pass_through_tests/test_openai_assistants_passthrough.py
@@ -2,14 +2,31 @@ import pytest
import openai
import aiohttp
import asyncio
+import tempfile
from typing_extensions import override
from openai import AssistantEventHandler
+
client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
+def test_pass_through_file_operations():
+ # Create a temporary file
+ with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
+ temp_file.write("This is a test file for the OpenAI Assistants API.")
+ temp_file.flush()
+
+ # create a file
+ file = client.files.create(
+ file=open(temp_file.name, "rb"),
+ purpose="assistants",
+ )
+ print("file created", file)
+
+ # delete the file
+ delete_file = client.files.delete(file.id)
+ print("file deleted", delete_file)
def test_openai_assistants_e2e_operations():
-
assistant = client.beta.assistants.create(
name="Math Tutor",
instructions="You are a personal math tutor. Write and run code to answer math questions.",
diff --git a/ui/litellm-dashboard/src/components/all_keys_table.tsx b/ui/litellm-dashboard/src/components/all_keys_table.tsx
index b0313c241f..3a2bf61c5f 100644
--- a/ui/litellm-dashboard/src/components/all_keys_table.tsx
+++ b/ui/litellm-dashboard/src/components/all_keys_table.tsx
@@ -13,9 +13,12 @@ import { Organization, userListCall } from "./networking";
import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn";
import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn";
import { useFilterLogic } from "./key_team_helpers/filter_logic";
+import { Setter } from "@/types";
+import { updateExistingKeys } from "@/utils/dataUtils";
interface AllKeysTableProps {
keys: KeyResponse[];
+ setKeys: Setter<KeyResponse[]>;
isLoading?: boolean;
pagination: {
currentPage: number;
@@ -87,6 +90,7 @@ const TeamFilter = ({
*/
export function AllKeysTable({
keys,
+ setKeys,
isLoading = false,
pagination,
onPageChange,
@@ -364,6 +368,23 @@ export function AllKeysTable({
keyId={selectedKeyId}
onClose={() => setSelectedKeyId(null)}
keyData={keys.find(k => k.token === selectedKeyId)}
+ onKeyDataUpdate={(updatedKeyData) => {
+ setKeys(keys => keys.map(key => {
+ if (key.token === updatedKeyData.token) {
+ // The shape of key is different from that of
+ // updatedKeyData(received from keyUpdateCall in networking.tsx).
+ // Hence, we can't replace key with updatedKeys since it might lead
+ // to unintended bugs/behaviors.
+ // So instead, we only update fields that are present in both.
+ return updateExistingKeys(key, updatedKeyData)
+ }
+
+ return key
+ }))
+ }}
+ onDelete={() => {
+ setKeys(keys => keys.filter(key => key.token !== selectedKeyId))
+ }}
accessToken={accessToken}
userID={userID}
userRole={userRole}
diff --git a/ui/litellm-dashboard/src/components/key_info_view.tsx b/ui/litellm-dashboard/src/components/key_info_view.tsx
index 9d50be6cf7..b7ebdc651a 100644
--- a/ui/litellm-dashboard/src/components/key_info_view.tsx
+++ b/ui/litellm-dashboard/src/components/key_info_view.tsx
@@ -27,13 +27,15 @@ interface KeyInfoViewProps {
keyId: string;
onClose: () => void;
keyData: KeyResponse | undefined;
+ onKeyDataUpdate?: (data: Partial<KeyResponse>) => void;
+ onDelete?: () => void;
accessToken: string | null;
userID: string | null;
userRole: string | null;
teams: any[] | null;
}
-export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams }: KeyInfoViewProps) {
+export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams, onKeyDataUpdate, onDelete }: KeyInfoViewProps) {
const [isEditing, setIsEditing] = useState(false);
const [form] = Form.useForm();
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
@@ -93,6 +95,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
}
const newKeyValues = await keyUpdateCall(accessToken, formValues);
+ if (onKeyDataUpdate) {
+ onKeyDataUpdate(newKeyValues)
+ }
message.success("Key updated successfully");
setIsEditing(false);
// Refresh key data here if needed
@@ -107,6 +112,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
if (!accessToken) return;
await keyDeleteCall(accessToken as string, keyData.token);
message.success("Key deleted successfully");
+ if (onDelete) {
+ onDelete()
+ }
onClose();
} catch (error) {
console.error("Error deleting the key:", error);
diff --git a/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx b/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx
index 4c2a18d2b5..4ca0ea5720 100644
--- a/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx
+++ b/ui/litellm-dashboard/src/components/key_team_helpers/key_list.tsx
@@ -1,5 +1,6 @@
import { useState, useEffect } from 'react';
import { keyListCall, Organization } from '../networking';
+import { Setter } from '@/types';
export interface Team {
team_id: string;
@@ -94,13 +95,14 @@ totalPages: number;
totalCount: number;
}
+
interface UseKeyListReturn {
keys: KeyResponse[];
isLoading: boolean;
error: Error | null;
pagination: PaginationData;
refresh: (params?: Record<string, unknown>) => Promise<void>;
-setKeys: (newKeysOrUpdater: KeyResponse[] | ((prevKeys: KeyResponse[]) => KeyResponse[])) => void;
+setKeys: Setter<KeyResponse[]>;
}
const useKeyList = ({
diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx
index ac79237fb8..025f0c72c4 100644
--- a/ui/litellm-dashboard/src/components/networking.tsx
+++ b/ui/litellm-dashboard/src/components/networking.tsx
@@ -4,6 +4,7 @@
import { all_admin_roles } from "@/utils/roles";
import { message } from "antd";
import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types";
+import { Team } from "./key_team_helpers/key_list";
const isLocal = process.env.NODE_ENV === "development";
export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
@@ -2983,7 +2984,7 @@ export const teamUpdateCall = async (
console.error("Error response from the server:", errorData);
throw new Error("Network response was not ok");
}
- const data = await response.json();
+ const data = await response.json() as { data: Team, team_id: string };
console.log("Update Team Response:", data);
return data;
// Handle success - you might want to update some state or UI based on the created key
diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx
index e04680b53a..20e9d23ccf 100644
--- a/ui/litellm-dashboard/src/components/team/team_info.tsx
+++ b/ui/litellm-dashboard/src/components/team/team_info.tsx
@@ -30,6 +30,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
import MemberModal from "./edit_membership";
import UserSearchModal from "@/components/common_components/user_search_modal";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
+import { Team } from "../key_team_helpers/key_list";
interface TeamData {
@@ -69,6 +70,7 @@ interface TeamInfoProps {
is_proxy_admin: boolean;
userModels: string[];
editTeam: boolean;
+ onUpdate?: (team: Team) => void
}
const TeamInfoView: React.FC<TeamInfoProps> = ({
@@ -78,7 +80,8 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
is_team_admin,
is_proxy_admin,
userModels,
- editTeam
+ editTeam,
+ onUpdate
}) => {
const [teamData, setTeamData] = useState<TeamData | null>(null);
const [loading, setLoading] = useState(true);
@@ -199,7 +202,10 @@ const TeamInfoView: React.FC = ({
};
const response = await teamUpdateCall(accessToken, updateData);
-
+ if (onUpdate) {
+ onUpdate(response.data)
+ }
+
message.success("Team settings updated successfully");
setIsEditing(false);
fetchTeamInfo();
diff --git a/ui/litellm-dashboard/src/components/teams.tsx b/ui/litellm-dashboard/src/components/teams.tsx
index 6f516f06e2..7e3b607267 100644
--- a/ui/litellm-dashboard/src/components/teams.tsx
+++ b/ui/litellm-dashboard/src/components/teams.tsx
@@ -84,6 +84,7 @@ import {
modelAvailableCall,
teamListCall
} from "./networking";
+import { updateExistingKeys } from "@/utils/dataUtils";
const getOrganizationModels = (organization: Organization | null, userModels: string[]) => {
let tempModelsToPick = [];
@@ -321,6 +322,22 @@ const Teams: React.FC = ({
{selectedTeamId ? (
+ onUpdate={(data) => {
+ setTeams(teams => {
+ if (teams == null) {
+ return teams;
+ }
+
+ return teams.map(team => {
+ if (data.team_id === team.team_id) {
+ return updateExistingKeys(team, data)
+ }
+
+ return team
+ })
+ })
+
+ }}
onClose={() => {
setSelectedTeamId(null);
setEditTeam(false);
diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx
index f3661c8c64..57467efa18 100644
--- a/ui/litellm-dashboard/src/components/view_key_table.tsx
+++ b/ui/litellm-dashboard/src/components/view_key_table.tsx
@@ -418,6 +418,7 @@ const ViewKeyTable: React.FC = ({
diff --git a/ui/litellm-dashboard/src/types.ts b/ui/litellm-dashboard/src/types.ts
new file mode 100644
--- /dev/null
+++ b/ui/litellm-dashboard/src/types.ts
@@ -0,0 +1 @@
+export type Setter<T> = (newValueOrUpdater: T | ((previousValue: T) => T)) => void
\ No newline at end of file
diff --git a/ui/litellm-dashboard/src/utils/dataUtils.ts b/ui/litellm-dashboard/src/utils/dataUtils.ts
new file mode 100644
index 0000000000..f51940f2ef
--- /dev/null
+++ b/ui/litellm-dashboard/src/utils/dataUtils.ts
@@ -0,0 +1,14 @@
+export function updateExistingKeys<Source extends Object>(
+ target: Source,
+ source: Object
+): Source {
+ const clonedTarget = structuredClone(target);
+
+ for (const [key, value] of Object.entries(source)) {
+ if (key in clonedTarget) {
+ (clonedTarget as any)[key] = value;
+ }
+ }
+
+ return clonedTarget;
+}
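+
+// Example (illustrative): updateExistingKeys({ token: "abc", spend: 0 }, { spend: 5, extra: 1 })
+// returns { token: "abc", spend: 5 }; keys not already present on the target ("extra") are ignored.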