VertexAI non-jsonl file storage support (#9781)

* test: add initial e2e test

* fix(vertex_ai/files): initial commit adding sync file create support

* refactor: initial commit of vertex ai non-jsonl files reaching gcp endpoint

* fix(vertex_ai/files/transformation.py): initial working commit of non-jsonl file call reaching backend endpoint

* fix(vertex_ai/files/transformation.py): working e2e non-jsonl file upload

* test: working e2e jsonl call

* test: unit testing for jsonl file creation

* fix(vertex_ai/transformation.py): reset file pointer after read

allow multiple reads on same file object

* fix: fix linting errors

* fix: fix ruff linting errors

* fix: fix import

* fix: fix linting error

* fix: fix linting error

* fix(vertex_ai/files/transformation.py): fix linting error

* test: update test

* test: update tests

* fix: fix linting errors

* fix: fix test

* fix: fix linting error
Krish Dholakia, 2025-04-09 14:01:48 -07:00, committed by GitHub
parent 93532e00db
commit 6ba3c4a4f8
64 changed files with 780 additions and 185 deletions
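
In short: `create_file` with `custom_llm_provider="vertex_ai"` can now store arbitrary (non-JSONL) files in GCS, not just JSONL batch inputs. A minimal sketch, mirroring the e2e test added below (valid Vertex AI credentials and a bucket set via GCS_BUCKET_NAME are assumed):

import litellm

# Upload a non-JSONL file (audio) through the OpenAI-style files API;
# with custom_llm_provider="vertex_ai" it lands in the configured GCS bucket.
response = litellm.create_file(
    file=("test.wav", b"test audio content", "audio/wav"),
    purpose="user_data",
    custom_llm_provider="vertex_ai",
)
print(response.id)  # a gs:// URI for the uploaded object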

View file

@ -110,5 +110,8 @@ def get_litellm_params(
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
"bucket_name": kwargs.get("bucket_name"),
"vertex_credentials": kwargs.get("vertex_credentials"),
"vertex_project": kwargs.get("vertex_project"),
}
return litellm_params
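
These three new keys (`bucket_name`, `vertex_credentials`, `vertex_project`) are read straight from call kwargs, so the GCS/Vertex settings can be supplied per request. A hedged sketch — all values illustrative, and it assumes `create_file` forwards extra kwargs through `get_litellm_params`:

import litellm

litellm.create_file(
    file=open("example.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="vertex_ai",
    bucket_name="litellm-local",                         # new litellm param
    vertex_project="my-gcp-project",                     # illustrative
    vertex_credentials="path/to/service_account.json",   # illustrative
)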

View file

@ -2,7 +2,10 @@
Common utility functions used for translating messages across providers
"""
from typing import Dict, List, Literal, Optional, Union, cast
import io
import mimetypes
from os import PathLike
from typing import Dict, List, Literal, Mapping, Optional, Union, cast
from litellm.types.llms.openai import (
AllMessageValues,
@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
ChatCompletionFileObject,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
from litellm.types.utils import (
Choices,
ExtractedFileData,
FileTypes,
ModelResponse,
StreamingChoices,
)
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
content="Please continue.", role="user"
@ -350,6 +359,68 @@ def update_messages_with_model_file_ids(
return messages
def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
"""
Extracts and processes file data from various input formats.
Args:
file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
Returns:
ExtractedFileData containing:
- filename: Name of the file if provided
- content: The file content in bytes
- content_type: MIME type of the file
- headers: Any additional headers
"""
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
# Reset file pointer to beginning
file_content.seek(0)
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Use provided content type or guess based on filename
if not content_type:
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
return ExtractedFileData(
filename=filename,
content=content,
content_type=content_type,
headers=file_headers,
)
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
@ -381,3 +452,4 @@ def unpack_defs(schema, defs):
unpack_defs(ref, defs)
value["items"] = ref
continue
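
A usage sketch for the new helper, covering two of the input shapes its docstring lists (not part of the diff itself):

from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data

# 3-tuple form: (filename, content, content_type)
data = extract_file_data(("batch.jsonl", b'{"custom_id": "request-1"}\n', "application/jsonl"))
assert data["content"] == b'{"custom_id": "request-1"}\n'
assert data["content_type"] == "application/jsonl"

# Bare bytes: no filename, so the MIME type falls back to the default
data = extract_file_data(b"raw bytes")
assert data["filename"] is None
assert data["content_type"] == "application/octet-stream"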

View file

@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
litellm_params=litellm_params,
)
config = ProviderConfigManager.get_provider_chat_config(

View file

@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -262,6 +262,7 @@ class BaseConfig(ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
) -> List[OpenAICreateFileRequestOptionalParams]:
pass
def get_complete_url(
def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
data: CreateFileRequest,
):
return self.get_complete_url(
api_base=api_base,
api_key=api_key,
model=model,
optional_params=optional_params,
litellm_params=litellm_params,
)
@abstractmethod
def transform_create_file_request(
@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> dict:
) -> Union[dict, str, bytes]:
pass
@abstractmethod

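The rename is not cosmetic: unlike `get_complete_url`, the new hook also receives the `CreateFileRequest`, so a files config can derive the upload URL from the file being uploaded (the Vertex AI config below uses this to compute the GCS object name); the default implementation just delegates to `get_complete_url`. A hypothetical override, for illustration only:

from litellm.llms.base_llm.files.transformation import BaseFilesConfig
from litellm.types.llms.openai import CreateFileRequest

class MyProviderFilesConfig(BaseFilesConfig):
    # Hypothetical: build the upload URL from the request payload itself.
    def get_complete_file_url(
        self, api_base, api_key, model, optional_params, litellm_params,
        data: CreateFileRequest,
    ) -> str:
        file_tuple = data.get("file")
        filename = file_tuple[0] if isinstance(file_tuple, tuple) else "upload.bin"
        return f"{api_base}/files?name={filename}"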
View file

@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
api_base=api_base,
)
@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=[{"role": "user", "content": "test"}],
optional_params=optional_params,
litellm_params=litellm_params,
api_base=api_base,
)

View file

@ -192,7 +192,7 @@ class AsyncHTTPHandler:
async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -427,7 +427,7 @@ class AsyncHTTPHandler:
self,
url: str,
client: httpx.AsyncClient,
data: Optional[Union[dict, str]] = None, # type: ignore
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -527,7 +527,7 @@ class HTTPHandler:
def post(
self,
url: str,
data: Optional[Union[dict, str]] = None,
data: Optional[Union[dict, str, bytes]] = None,
json: Optional[Union[dict, str, List]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -573,7 +573,6 @@ class HTTPHandler:
setattr(e, "text", error_text)
setattr(e, "status_code", e.response.status_code)
raise e
except Exception as e:
raise e

View file

@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
messages=messages,
optional_params=optional_params,
api_base=api_base,
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@ -625,6 +626,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@ -896,6 +898,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params,
)
if client is None or not isinstance(client, HTTPHandler):
@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler:
model="",
messages=[],
optional_params={},
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
api_base = provider_config.get_complete_file_url(
api_base=api_base,
api_key=api_key,
model="",
optional_params={},
litellm_params=litellm_params,
data=create_file_data,
)
if api_base is None:
raise ValueError("api_base is required for create_file")
# Get the transformed request data for both steps
transformed_request = provider_config.transform_create_file_request(
@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler:
else:
sync_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = sync_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
if isinstance(transformed_request, str) or isinstance(
transformed_request, bytes
):
upload_response = sync_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
url=api_base,
headers=headers,
data=transformed_request,
timeout=timeout,
)
else:
try:
# Step 1: Initial request to get upload URL
initial_response = sync_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = sync_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
async def async_create_file(
self,
transformed_request: dict,
transformed_request: Union[bytes, str, dict],
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = await async_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
if isinstance(transformed_request, str) or isinstance(
transformed_request, bytes
):
upload_response = await async_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
url=api_base,
headers=headers,
data=transformed_request,
timeout=timeout,
)
else:
try:
# Step 1: Initial request to get upload URL
initial_response = await async_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
except Exception as e:
verbose_logger.exception(f"Error creating file: {e}")
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = await async_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
except Exception as e:
verbose_logger.exception(f"Error creating file: {e}")
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
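
Both the sync and async paths now branch on the transformed request's type: raw str/bytes means a one-shot direct upload (the new Vertex AI -> GCS path), while a dict keeps the existing two-step resumable flow (Google AI Studio). The control flow in miniature, using plain httpx for illustration:

import json
import httpx

def upload_file(api_base: str, headers: dict, transformed_request):
    if isinstance(transformed_request, (str, bytes)):
        # One-shot: the request body is the file content itself (GCS media upload)
        return httpx.post(api_base, headers=headers, data=transformed_request)
    # Two-step resumable flow: the start request returns the real upload URL
    initial = httpx.post(
        api_base,
        headers={**headers, **transformed_request["initial_request"]["headers"]},
        data=json.dumps(transformed_request["initial_request"]["data"]),
    )
    upload_url = initial.headers["X-Goog-Upload-URL"]
    return httpx.post(
        upload_url,
        headers=transformed_request["upload_request"]["headers"],
        data=transformed_request["upload_request"]["data"],
    )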
def list_files(self):
"""

View file

@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -41,6 +41,7 @@ class FireworksAIMixin:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
from typing import List, Mapping, Optional
from typing import List, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
if file_data is None:
raise ValueError("File data is required")
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Handle the file content based on its type
import io
from os import PathLike
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Use the common utility function to extract file data
extracted_data = extract_file_data(file_data)
# Get file size
file_size = len(content)
# Use provided content type or guess based on filename
if not content_type:
import mimetypes
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
file_size = len(extracted_data["content"])
# Step 1: Initial resumable upload request
headers = {
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(file_size),
"X-Goog-Upload-Header-Content-Type": content_type,
"X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
"Content-Type": "application/json",
}
headers.update(file_headers) # Add any custom headers
headers.update(extracted_data["headers"]) # Add any custom headers
# Initial metadata request body
initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
initial_data = {
"file": {
"display_name": extracted_data["filename"] or str(int(time.time()))
}
}
# Step 2: Actual file upload data
upload_headers = {
@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
return {
"initial_request": {"headers": headers, "data": initial_data},
"upload_request": {"headers": upload_headers, "data": content},
"upload_request": {
"headers": upload_headers,
"data": extracted_data["content"],
},
}
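
The dict returned above is exactly what the str/bytes-vs-dict dispatch in the HTTP handler keys on. Its shape, with illustrative values:

transformed_request = {
    "initial_request": {
        "headers": {
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": "18",
            "X-Goog-Upload-Header-Content-Type": "audio/wav",
            "Content-Type": "application/json",
        },
        "data": {"file": {"display_name": "test.wav"}},
    },
    "upload_request": {
        "headers": {},  # upload/finalize headers (elided in the hunk above)
        "data": b"test audio content",
    },
}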
def transform_create_file_response(

View file

@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
if api_base is not None:
complete_url = api_base
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
complete_url = str(os.getenv("HF_API_BASE")) or str(
os.getenv("HUGGINGFACE_API_BASE")
)
elif model.startswith(("http://", "https://")):
complete_url = model
# 4. Default construction with provider
@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
return dict(
ChatCompletionRequest(
model=mapped_model, messages=messages, **optional_params
)
)

View file

@ -1,15 +1,6 @@
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Union,
get_args,
)
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
import httpx
@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
]
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
return pipeline_tag
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
litellm_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
model=model,
optional_params=optional_params,
messages=[],
litellm_params=litellm_params,
)
task_type = optional_params.pop("input_type", None)
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
embed_url = (
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
)
## ROUTING ##
if aembedding is True:

View file

@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -36,6 +36,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## Load Config

View file

@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -32,6 +32,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
if "https" in model:
completion_url = model
@ -123,6 +124,7 @@ def embedding(
model=model,
messages=[],
optional_params=optional_params,
litellm_params={},
)
response = litellm.module_level_client.post(
embeddings_url, headers=headers, json=data

View file

@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -228,10 +228,10 @@ class PredibaseChatCompletion:
api_key: str,
logging_obj,
optional_params: dict,
litellm_params: dict,
tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
@ -241,6 +241,7 @@ class PredibaseChatCompletion:
messages=messages,
optional_params=optional_params,
model=model,
litellm_params=litellm_params,
)
completion_url = ""
input_text = ""

View file

@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -141,6 +141,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
# Start a prediction and get the prediction URL
version_id = replicate_config.model_to_version_id(model)

View file

@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
model: str,
data: dict,
messages: List[AllMessageValues],
litellm_params: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": _data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,

View file

@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -1,3 +1,4 @@
import asyncio
from typing import Any, Coroutine, Optional, Union
import httpx
@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .transformation import VertexAIFilesTransformation
from .transformation import VertexAIJsonlFilesTransformation
vertex_ai_files_transformation = VertexAIFilesTransformation()
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
timeout=timeout,
max_retries=max_retries,
)
return None # type: ignore
else:
return asyncio.run(
self.async_create_file(
create_file_data=create_file_data,
api_base=api_base,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
timeout=timeout,
max_retries=max_retries,
)
)

View file

@ -1,7 +1,17 @@
import json
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from httpx import Headers, Response
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
FileTypes,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
from litellm.types.llms.vertex_ai import GcsBucketResponse
from litellm.types.utils import ExtractedFileData, LlmProviders
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
class VertexAIFilesTransformation(VertexGeminiConfig):
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
"""
Config for VertexAI Files
"""
def __init__(self):
self.jsonl_transformation = VertexAIJsonlFilesTransformation()
super().__init__()
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.VERTEX_AI
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if not api_key:
api_key, _ = self.get_access_token(
credentials=litellm_params.get("vertex_credentials"),
project_id=litellm_params.get("vertex_project"),
)
if not api_key:
raise ValueError("api_key is required")
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
"""
Helper to extract content from various OpenAI file types and return as string.
Handles:
- Direct content (str, bytes, IO[bytes])
- Tuple formats: (filename, content, [content_type], [headers])
- PathLike objects
"""
content: Union[str, bytes] = b""
# Extract file content from tuple if necessary
if isinstance(openai_file_content, tuple):
# Take the second element which is always the file content
file_content = openai_file_content[1]
else:
file_content = openai_file_content
# Handle different file content types
if isinstance(file_content, str):
# String content can be used directly
content = file_content
elif isinstance(file_content, bytes):
# Bytes content is decoded to str below
content = file_content
elif isinstance(file_content, PathLike): # PathLike
with open(str(file_content), "rb") as f:
content = f.read()
elif hasattr(file_content, "read"): # IO[bytes]
# File-like objects need to be read
content = file_content.read()
# Ensure content is string
if isinstance(content, bytes):
content = content.decode("utf-8")
return content
def _get_gcs_object_name_from_batch_jsonl(
self,
openai_jsonl_content: List[Dict[str, Any]],
) -> str:
"""
Gets a unique GCS object name for the VertexAI batch prediction job
named as: litellm-vertex-files/{model}/{uuid}
"""
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
if "publishers/google/models" not in _model:
_model = f"publishers/google/models/{_model}"
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
return object_name
def get_object_name(
self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
"""
Get the object name for the request
"""
extracted_file_data_content = extracted_file_data.get("content")
if extracted_file_data_content is None:
raise ValueError("file content is required")
if purpose == "batch":
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
if len(openai_jsonl_content) > 0:
return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
## 2. If not jsonl, return the filename
filename = extracted_file_data.get("filename")
if filename:
return filename
## 3. If no file name, return timestamp
return str(int(time.time()))
def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: Dict,
litellm_params: Dict,
data: CreateFileRequest,
) -> str:
"""
Get the complete url for the request
"""
bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
if not bucket_name:
raise ValueError("GCS bucket_name is required")
file_data = data.get("file")
purpose = data.get("purpose")
if file_data is None:
raise ValueError("file is required")
if purpose is None:
raise ValueError("purpose is required")
extracted_file_data = extract_file_data(file_data)
object_name = self.get_object_name(extracted_file_data, purpose)
endpoint = (
f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
)
api_base = api_base or "https://storage.googleapis.com"
if not api_base:
raise ValueError("api_base is required")
return f"{api_base}/{endpoint}"
def get_supported_openai_params(
self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return optional_params
def _map_openai_to_vertex_params(
self,
openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
wrapper to call VertexGeminiConfig.map_openai_params
"""
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
config = VertexGeminiConfig()
_model = openai_request_body.get("model", "")
vertex_params = config.map_openai_params(
model=_model,
non_default_params=openai_request_body,
optional_params={},
drop_params=False,
)
return vertex_params
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
self, openai_jsonl_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Transforms OpenAI JSONL content to VertexAI JSONL content
jsonl body for vertex is {"request": <request_body>}
Example Vertex jsonl
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
"""
vertex_jsonl_content = []
for _openai_jsonl_content in openai_jsonl_content:
openai_request_body = _openai_jsonl_content.get("body") or {}
vertex_request_body = _transform_request_body(
messages=openai_request_body.get("messages", []),
model=openai_request_body.get("model", ""),
optional_params=self._map_openai_to_vertex_params(openai_request_body),
custom_llm_provider="vertex_ai",
litellm_params={},
cached_content=None,
)
vertex_jsonl_content.append({"request": vertex_request_body})
return vertex_jsonl_content
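
An illustrative input/output pair for this transformation (the exact Vertex body comes from _transform_request_body, so treat the output as approximate):

openai_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gemini-1.5-flash-001",
        "messages": [{"role": "user", "content": "Hello world!"}],
        "max_tokens": 10,
    },
}

# becomes, roughly:
vertex_line = {
    "request": {
        "contents": [{"role": "user", "parts": [{"text": "Hello world!"}]}],
        "generationConfig": {"maxOutputTokens": 10},
    }
}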
def transform_create_file_request(
self,
model: str,
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> Union[bytes, str, dict]:
"""
2 Cases:
1. Handle basic file upload
2. Handle batch file upload (.jsonl)
"""
file_data = create_file_data.get("file")
if file_data is None:
raise ValueError("file is required")
extracted_file_data = extract_file_data(file_data)
extracted_file_data_content = extracted_file_data.get("content")
if (
create_file_data.get("purpose") == "batch"
and extracted_file_data.get("content_type") == "application/jsonl"
and extracted_file_data_content is not None
):
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
vertex_jsonl_content = (
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
openai_jsonl_content
)
)
return json.dumps(vertex_jsonl_content)
elif isinstance(extracted_file_data_content, bytes):
return extracted_file_data_content
else:
raise ValueError("Unsupported file content type")
def transform_create_file_response(
self,
model: Optional[str],
raw_response: Response,
logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> OpenAIFileObject:
"""
Transform VertexAI File upload response into OpenAI-style FileObject
"""
response_json = raw_response.json()
try:
response_object = GcsBucketResponse(**response_json) # type: ignore
except Exception as e:
raise VertexAIError(
status_code=raw_response.status_code,
message=f"Error reading GCS bucket response: {e}",
headers=raw_response.headers,
)
gcs_id = response_object.get("id", "")
# Remove the last numeric ID from the path
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
return OpenAIFileObject(
purpose=response_object.get("purpose", "batch"),
id=f"gs://{gcs_id}",
filename=response_object.get("name", ""),
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response_object.get("timeCreated", "")
),
status="uploaded",
bytes=int(response_object.get("size", 0)),
object="file",
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
return VertexAIError(
status_code=status_code, message=error_message, headers=headers
)
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
"""
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
"""

View file

@ -905,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
@ -1022,7 +1023,7 @@ class VertexLLM(VertexBase):
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
litellm_params: dict,
logger_fn=None,
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
@ -1063,6 +1064,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## LOGGING
@ -1149,6 +1151,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
request_body = await async_transform_request_body(**data) # type: ignore
@ -1322,6 +1325,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## TRANSFORMATION ##

View file

@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
optional_params=optional_params,
api_key=auth_header,
api_base=api_base,
litellm_params=litellm_params,
)
## LOGGING

View file

@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
@ -22,7 +21,7 @@ else:
GoogleCredentialsObject = Any
class VertexBase(BaseLLM):
class VertexBase:
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None

View file

@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
messages=messages,
optional_params=optional_params,
api_key=api_key,
litellm_params=litellm_params,
)
## UPDATE PAYLOAD (optional params)

View file

@ -165,6 +165,7 @@ class IBMWatsonXMixin:
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "bedrock":
if isinstance(input, str):

View file

@ -498,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
gcsDestination: GcsDestination
class GcsBucketResponse(TypedDict):
"""
TypedDict for GCS bucket upload response
Attributes:
kind: The kind of item this is. For objects, this is always storage#object
id: The ID of the object
selfLink: The link to this object
mediaLink: The link to download the object
name: The name of the object
bucket: The name of the bucket containing this object
generation: The content generation of this object
metageneration: The metadata generation of this object
contentType: The content type of the object
storageClass: The storage class of the object
size: The size of the object in bytes
md5Hash: The MD5 hash of the object
crc32c: The CRC32c checksum of the object
etag: The ETag of the object
timeCreated: The creation time of the object
updated: The last update time of the object
timeStorageClassUpdated: The time the storage class was last updated
timeFinalized: The time the object was finalized
"""
kind: Literal["storage#object"]
id: str
selfLink: str
mediaLink: str
name: str
bucket: str
generation: str
metageneration: str
contentType: str
storageClass: str
size: str
md5Hash: str
crc32c: str
etag: str
timeCreated: str
updated: str
timeStorageClassUpdated: str
timeFinalized: str
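
For orientation, a trimmed example of the GCS JSON this TypedDict models (values illustrative), limited to the fields the Vertex transform_create_file_response above reads:

gcs_response = {
    "kind": "storage#object",
    # "id" is "<bucket>/<object>/<generation>"; the transform strips the
    # trailing generation and prefixes "gs://" to build the OpenAI file id
    "id": "litellm-local/litellm-vertex-files/my-object/1744232508000000",
    "name": "litellm-vertex-files/my-object",
    "bucket": "litellm-local",
    "size": "780",
    "timeCreated": "2025-04-09T21:01:48.000Z",
    # remaining GcsBucketResponse fields omitted
}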
class VertexAIBatchPredictionJob(TypedDict):
displayName: str
model: str

View file

@ -2,7 +2,7 @@ import json
import time
import uuid
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
from aiohttp import FormData
from openai._models import BaseModel as OpenAIObject
@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
if not values.get("credential_values") and not values.get("model_id"):
raise ValueError("Either credential_values or model_id must be set")
return values
class ExtractedFileData(TypedDict):
"""
TypedDict for storing processed file data
Attributes:
filename: Name of the file if provided
content: The file content in bytes
content_type: MIME type of the file
headers: Any additional headers for the file
"""
filename: Optional[str]
content: bytes
content_type: Optional[str]
headers: Mapping[str, str]

View file

@ -6517,6 +6517,10 @@ class ProviderConfigManager:
)
return GoogleAIStudioFilesHandler()
elif LlmProviders.VERTEX_AI == provider:
from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
return VertexAIFilesConfig()
return None

View file

@ -423,25 +423,35 @@ mock_vertex_batch_response = {
@pytest.mark.asyncio
async def test_avertex_batch_prediction():
with patch(
async def test_avertex_batch_prediction(monkeypatch):
monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
client = AsyncHTTPHandler()
async def mock_side_effect(*args, **kwargs):
print("args", args, "kwargs", kwargs)
url = kwargs.get("url", "")
if "files" in url:
mock_response.json.return_value = mock_file_response
elif "batch" in url:
mock_response.json.return_value = mock_vertex_batch_response
mock_response.status_code = 200
return mock_response
with patch.object(
client, "post", side_effect=mock_side_effect
) as mock_post, patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
) as mock_post:
) as mock_global_post:
# Configure mock responses
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
# Set up different responses for different API calls
async def mock_side_effect(*args, **kwargs):
url = kwargs.get("url", "")
if "files" in url:
mock_response.json.return_value = mock_file_response
elif "batch" in url:
mock_response.json.return_value = mock_vertex_batch_response
mock_response.status_code = 200
return mock_response
mock_post.side_effect = mock_side_effect
mock_global_post.side_effect = mock_side_effect
# load_vertex_ai_credentials()
litellm.set_verbose = True
@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="vertex_ai",
client=client
)
print("Response from creating file=", file_obj)

View file

@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}],
optional_params={},
api_key="test_api_key"
api_key="test_api_key",
litellm_params={}
)
assert headers["Authorization"] == "Bearer test_api_key"

View file

@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}

View file

@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
StandardLoggingPayload,
)
from litellm.types.utils import StandardCallbackDynamicParams
from unittest.mock import patch
verbose_logger.setLevel(logging.DEBUG)
@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
# clean up
if old_bucket_name is not None:
os.environ["GCS_BUCKET_NAME"] = old_bucket_name
@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e():
"""
E2E check that a non-JSONL ('user_data') file upload via 'create_file' succeeds
"""
load_vertex_ai_credentials()
test_file_content = b"test audio content"
test_file = ("test.wav", test_file_content, "audio/wav")
from litellm import create_file
response = create_file(
file=test_file,
purpose="user_data",
custom_llm_provider="vertex_ai",
)
print("response", response)
assert response is not None
@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e_jsonl():
"""
Asserts 'create_file' is called with the correct arguments
"""
load_vertex_ai_credentials()
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]
# Create and write to the file
file_path = "example.jsonl"
with open(file_path, "w") as f:
for item in example_jsonl:
f.write(json.dumps(item) + "\n")
# Verify file content
with open(file_path, "r") as f:
content = f.read()
print("File content:", content)
assert len(content) > 0, "File is empty"
from litellm import create_file
with patch.object(client, "post") as mock_create_file:
try:
response = create_file(
file=open(file_path, "rb"),
purpose="user_data",
custom_llm_provider="vertex_ai",
client=client,
)
except Exception as e:
print("error", e)
mock_create_file.assert_called_once()
print(f"kwargs: {mock_create_file.call_args.kwargs}")
assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0