mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
VertexAI non-jsonl file storage support (#9781)
* test: add initial e2e test
* fix(vertex_ai/files): initial commit adding sync file create support
* refactor: initial commit of vertex ai non-jsonl files reaching gcp endpoint
* fix(vertex_ai/files/transformation.py): initial working commit of non-jsonl file call reaching backend endpoint
* fix(vertex_ai/files/transformation.py): working e2e non-jsonl file upload
* test: working e2e jsonl call
* test: unit testing for jsonl file creation
* fix(vertex_ai/transformation.py): reset file pointer after read, allow multiple reads on same file object
* fix: fix linting errors
* fix: fix ruff linting errors
* fix: fix import
* fix: fix linting error
* fix: fix linting error
* fix(vertex_ai/files/transformation.py): fix linting error
* test: update test
* test: update tests
* fix: fix linting errors
* fix: fix test
* fix: fix linting error
This commit is contained in:
parent
93532e00db
commit
6ba3c4a4f8
64 changed files with 780 additions and 185 deletions
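For context, the headline change is that the files API can now push non-JSONL files (and, as before, JSONL batch files) to a GCS bucket for Vertex AI. A minimal usage sketch based on the test added in this commit; the bucket name and file path are placeholders, and `litellm.acreate_file` is assumed to be the async files entrypoint:

```python
import asyncio
import os

import litellm

# Placeholder bucket; the Vertex files config reads GCS_BUCKET_NAME (or litellm_params["bucket_name"]).
os.environ["GCS_BUCKET_NAME"] = "my-litellm-bucket"


async def main():
    file_obj = await litellm.acreate_file(
        file=open("./vertex_batch_completions.jsonl", "rb"),  # placeholder path
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )
    # OpenAI-style file object whose id is a gs:// URI for the uploaded object.
    print(file_obj)


asyncio.run(main())
```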
@@ -110,5 +110,8 @@ def get_litellm_params(
        "azure_password": kwargs.get("azure_password"),
        "max_retries": max_retries,
        "timeout": kwargs.get("timeout"),
        "bucket_name": kwargs.get("bucket_name"),
        "vertex_credentials": kwargs.get("vertex_credentials"),
        "vertex_project": kwargs.get("vertex_project"),
    }
    return litellm_params
@@ -2,7 +2,10 @@
Common utility functions used for translating messages across providers
"""

from typing import Dict, List, Literal, Optional, Union, cast
import io
import mimetypes
from os import PathLike
from typing import Dict, List, Literal, Mapping, Optional, Union, cast

from litellm.types.llms.openai import (
    AllMessageValues,
@@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
    ChatCompletionFileObject,
    ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
from litellm.types.utils import (
    Choices,
    ExtractedFileData,
    FileTypes,
    ModelResponse,
    StreamingChoices,
)

DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
    content="Please continue.", role="user"
@@ -350,6 +359,68 @@ def update_messages_with_model_file_ids(
    return messages


def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
    """
    Extracts and processes file data from various input formats.

    Args:
        file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content

    Returns:
        ExtractedFileData containing:
        - filename: Name of the file if provided
        - content: The file content in bytes
        - content_type: MIME type of the file
        - headers: Any additional headers
    """
    # Parse the file_data based on its type
    filename = None
    file_content = None
    content_type = None
    file_headers: Mapping[str, str] = {}

    if isinstance(file_data, tuple):
        if len(file_data) == 2:
            filename, file_content = file_data
        elif len(file_data) == 3:
            filename, file_content, content_type = file_data
        elif len(file_data) == 4:
            filename, file_content, content_type, file_headers = file_data
    else:
        file_content = file_data
    # Convert content to bytes
    if isinstance(file_content, (str, PathLike)):
        # If it's a path, open and read the file
        with open(file_content, "rb") as f:
            content = f.read()
    elif isinstance(file_content, io.IOBase):
        # If it's a file-like object
        content = file_content.read()

        if isinstance(content, str):
            content = content.encode("utf-8")
        # Reset file pointer to beginning
        file_content.seek(0)
    elif isinstance(file_content, bytes):
        content = file_content
    else:
        raise ValueError(f"Unsupported file content type: {type(file_content)}")

    # Use provided content type or guess based on filename
    if not content_type:
        content_type = (
            mimetypes.guess_type(filename)[0]
            if filename
            else "application/octet-stream"
        )

    return ExtractedFileData(
        filename=filename,
        content=content,
        content_type=content_type,
        headers=file_headers,
    )


def unpack_defs(schema, defs):
    properties = schema.get("properties", None)
    if properties is None:
@@ -381,3 +452,4 @@ def unpack_defs(schema, defs):
            unpack_defs(ref, defs)
            value["items"] = ref
            continue
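A small sketch of how the new `extract_file_data` helper behaves for the tuple form of file input (the values below are illustrative only):

```python
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data

# Tuple form: (filename, content, content_type); content is normalized to bytes.
extracted = extract_file_data(("report.pdf", b"%PDF-1.4 ...", "application/pdf"))
print(extracted["filename"])       # report.pdf
print(extracted["content_type"])   # application/pdf
print(type(extracted["content"]))  # <class 'bytes'>
```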
@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
config = ProviderConfigManager.get_provider_chat_config(
|
||||
|
|
|
@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -262,6 +262,7 @@ class BaseConfig(ABC):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

@@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        pass

    def get_complete_url(
    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""
        data: CreateFileRequest,
    ):
        return self.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
        )

    @abstractmethod
    def transform_create_file_request(
@@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> dict:
    ) -> Union[dict, str, bytes]:
        pass

    @abstractmethod
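With the base class reshaped as above, a provider files config implements roughly the following surface; the provider name and bodies here are hypothetical, and the other abstract hooks (response transformation, error class, etc.) are omitted for brevity:

```python
from typing import Optional, Union

from litellm.llms.base_llm.files.transformation import BaseFilesConfig
from litellm.types.llms.openai import CreateFileRequest


class MyProviderFilesConfig(BaseFilesConfig):  # hypothetical provider
    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        data: CreateFileRequest,  # new: the file request is available when building the URL
    ) -> str:
        return f"{api_base}/upload"

    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[dict, str, bytes]:
        # Non-JSONL uploads may now be returned as raw bytes instead of a dict.
        return b"..."
```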
|
|
@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
|
|||
model=model,
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
|
|
@@ -192,7 +192,7 @@ class AsyncHTTPHandler:
    async def post(
        self,
        url: str,
        data: Optional[Union[dict, str]] = None,  # type: ignore
        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
        json: Optional[dict] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
@@ -427,7 +427,7 @@
        self,
        url: str,
        client: httpx.AsyncClient,
        data: Optional[Union[dict, str]] = None,  # type: ignore
        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
        json: Optional[dict] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
@@ -527,7 +527,7 @@ class HTTPHandler:
    def post(
        self,
        url: str,
        data: Optional[Union[dict, str]] = None,
        data: Optional[Union[dict, str, bytes]] = None,
        json: Optional[Union[dict, str, List]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
@@ -573,7 +573,6 @@ class HTTPHandler:
                setattr(e, "text", error_text)

                setattr(e, "status_code", e.response.status_code)

                raise e
        except Exception as e:
            raise e
|
|
@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
|
@ -625,6 +626,7 @@ class BaseLLMHTTPHandler:
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
|
@ -896,6 +898,7 @@ class BaseLLMHTTPHandler:
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
|
@@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler:
            model="",
            messages=[],
            optional_params={},
            litellm_params=litellm_params,
        )

        api_base = provider_config.get_complete_url(
        api_base = provider_config.get_complete_file_url(
            api_base=api_base,
            api_key=api_key,
            model="",
            optional_params={},
            litellm_params=litellm_params,
            data=create_file_data,
        )
        if api_base is None:
            raise ValueError("api_base is required for create_file")

        # Get the transformed request data for both steps
        transformed_request = provider_config.transform_create_file_request(
@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
sync_httpx_client = client
|
||||
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
if isinstance(transformed_request, str) or isinstance(
|
||||
transformed_request, bytes
|
||||
):
|
||||
upload_response = sync_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=transformed_request,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
upload_response = sync_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
async def async_create_file(
|
||||
self,
|
||||
transformed_request: dict,
|
||||
transformed_request: Union[bytes, str, dict],
|
||||
litellm_params: dict,
|
||||
provider_config: BaseFilesConfig,
|
||||
headers: dict,
|
||||
|
@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
if isinstance(transformed_request, str) or isinstance(
|
||||
transformed_request, bytes
|
||||
):
|
||||
upload_response = await async_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=transformed_request,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating file: {e}")
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
upload_response = await async_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating file: {e}")
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
def list_files(self):
|
||||
"""
|
||||
|
|
|
@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -41,6 +41,7 @@ class FireworksAIMixin:
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
from typing import List, Mapping, Optional
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import (
    BaseFilesConfig,
    LiteLLMLoggingObj,
@@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
        if file_data is None:
            raise ValueError("File data is required")

        # Parse the file_data based on its type
        filename = None
        file_content = None
        content_type = None
        file_headers: Mapping[str, str] = {}

        if isinstance(file_data, tuple):
            if len(file_data) == 2:
                filename, file_content = file_data
            elif len(file_data) == 3:
                filename, file_content, content_type = file_data
            elif len(file_data) == 4:
                filename, file_content, content_type, file_headers = file_data
        else:
            file_content = file_data

        # Handle the file content based on its type
        import io
        from os import PathLike

        # Convert content to bytes
        if isinstance(file_content, (str, PathLike)):
            # If it's a path, open and read the file
            with open(file_content, "rb") as f:
                content = f.read()
        elif isinstance(file_content, io.IOBase):
            # If it's a file-like object
            content = file_content.read()
            if isinstance(content, str):
                content = content.encode("utf-8")
        elif isinstance(file_content, bytes):
            content = file_content
        else:
            raise ValueError(f"Unsupported file content type: {type(file_content)}")
        # Use the common utility function to extract file data
        extracted_data = extract_file_data(file_data)

        # Get file size
        file_size = len(content)

        # Use provided content type or guess based on filename
        if not content_type:
            import mimetypes

            content_type = (
                mimetypes.guess_type(filename)[0]
                if filename
                else "application/octet-stream"
            )
        file_size = len(extracted_data["content"])

        # Step 1: Initial resumable upload request
        headers = {
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": str(file_size),
            "X-Goog-Upload-Header-Content-Type": content_type,
            "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
            "Content-Type": "application/json",
        }
        headers.update(file_headers)  # Add any custom headers
        headers.update(extracted_data["headers"])  # Add any custom headers

        # Initial metadata request body
        initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
        initial_data = {
            "file": {
                "display_name": extracted_data["filename"] or str(int(time.time()))
            }
        }

        # Step 2: Actual file upload data
        upload_headers = {
@@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):

        return {
            "initial_request": {"headers": headers, "data": initial_data},
            "upload_request": {"headers": upload_headers, "data": content},
            "upload_request": {
                "headers": upload_headers,
                "data": extracted_data["content"],
            },
        }

    def transform_create_file_response(
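For reference, the two-step resumable upload that these request/response transforms feed into looks roughly like the following; the endpoint is a placeholder and the real calls go through LiteLLM's HTTP handlers:

```python
import httpx

# Step 1: start a resumable upload session (headers mirror those built above).
initial = httpx.post(
    "https://example.googleapis.com/upload/v1beta/files",  # placeholder endpoint
    headers={
        "X-Goog-Upload-Protocol": "resumable",
        "X-Goog-Upload-Command": "start",
        "X-Goog-Upload-Header-Content-Length": "12",
        "X-Goog-Upload-Header-Content-Type": "text/plain",
        "Content-Type": "application/json",
    },
    json={"file": {"display_name": "example.txt"}},
)

# Step 2: the session URL comes back in a response header; the raw bytes go there.
upload_url = initial.headers.get("X-Goog-Upload-URL")
if upload_url:
    httpx.post(upload_url, content=b"hello, world")
```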
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = "https://router.huggingface.co"
|
||||
|
@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
|
||||
return HuggingFaceError(
|
||||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
|
@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
if api_base is not None:
|
||||
complete_url = api_base
|
||||
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
|
||||
complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
|
||||
complete_url = str(os.getenv("HF_API_BASE")) or str(
|
||||
os.getenv("HUGGINGFACE_API_BASE")
|
||||
)
|
||||
elif model.startswith(("http://", "https://")):
|
||||
complete_url = model
|
||||
# 4. Default construction with provider
|
||||
|
@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
)
|
||||
mapped_model = provider_mapping["providerId"]
|
||||
messages = self._transform_messages(messages=messages, model=mapped_model)
|
||||
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
|
||||
return dict(
|
||||
ChatCompletionRequest(
|
||||
model=mapped_model, messages=messages, **optional_params
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,15 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
|
|||
]
|
||||
|
||||
|
||||
|
||||
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
def get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
|
|||
return pipeline_tag
|
||||
|
||||
|
||||
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
async def async_get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: List,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = await async_get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
|
||||
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||
|
||||
|
@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
task_type = optional_params.pop("input_type", None)
|
||||
|
||||
if call_type == "sync":
|
||||
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
elif call_type == "async":
|
||||
return self._async_transform_input(
|
||||
model=model, task_type=task_type, embed_url=embed_url, input=input
|
||||
|
@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
encoding: Callable,
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
model=model,
|
||||
optional_params=optional_params,
|
||||
messages=[],
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
task_type = optional_params.pop("input_type", None)
|
||||
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
# print_verbose(f"{model}, {task}")
|
||||
embed_url = ""
|
||||
if "https" in model:
|
||||
|
@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||
else:
|
||||
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
embed_url = (
|
||||
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
)
|
||||
|
||||
## ROUTING ##
|
||||
if aembedding is True:
|
||||
|
|
|
@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -36,6 +36,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## Load Config
|
||||
|
|
|
@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -32,6 +32,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if "https" in model:
|
||||
completion_url = model
|
||||
|
@ -123,6 +124,7 @@ def embedding(
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params={},
|
||||
)
|
||||
response = litellm.module_level_client.post(
|
||||
embeddings_url, headers=headers, json=data
|
||||
|
|
|
@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -228,10 +228,10 @@ class PredibaseChatCompletion:
|
|||
api_key: str,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
tenant_id: str,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
@ -241,6 +241,7 @@ class PredibaseChatCompletion:
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
|
|
|
@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -141,6 +141,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = replicate_config.model_to_version_id(model)
|
||||
|
|
|
@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
model: str,
|
||||
data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
litellm_params: dict,
|
||||
optional_params: dict,
|
||||
aws_region_name: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
|
@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
request = AWSRequest(
|
||||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
|
@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": _data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
|
|
@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Coroutine, Optional, Union

import httpx
@@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES

from .transformation import VertexAIFilesTransformation
from .transformation import VertexAIJsonlFilesTransformation

vertex_ai_files_transformation = VertexAIFilesTransformation()
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()


class VertexAIFilesHandler(GCSBucketBase):
@@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
                timeout=timeout,
                max_retries=max_retries,
            )

        return None  # type: ignore
        else:
            return asyncio.run(
                self.async_create_file(
                    create_file_data=create_file_data,
                    api_base=api_base,
                    vertex_credentials=vertex_credentials,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )
@@ -1,7 +1,17 @@
import json
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

from httpx import Headers, Response

from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
    BaseFilesConfig,
    LiteLLMLoggingObj,
)
from litellm.llms.vertex_ai.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
@@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    CreateFileRequest,
    FileTypes,
    OpenAICreateFileRequestOptionalParams,
    OpenAIFileObject,
    PathLike,
)
from litellm.types.llms.vertex_ai import GcsBucketResponse
from litellm.types.utils import ExtractedFileData, LlmProviders

from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase


class VertexAIFilesTransformation(VertexGeminiConfig):
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
    """
    Config for VertexAI Files
    """

    def __init__(self):
        self.jsonl_transformation = VertexAIJsonlFilesTransformation()
        super().__init__()

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.VERTEX_AI

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if not api_key:
            api_key, _ = self.get_access_token(
                credentials=litellm_params.get("vertex_credentials"),
                project_id=litellm_params.get("vertex_project"),
            )
            if not api_key:
                raise ValueError("api_key is required")
        headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content can be decoded
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def _get_gcs_object_name_from_batch_jsonl(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-{model}-{uuid}
        """
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def get_object_name(
        self, extracted_file_data: ExtractedFileData, purpose: str
    ) -> str:
        """
        Get the object name for the request
        """
        extracted_file_data_content = extracted_file_data.get("content")

        if extracted_file_data_content is None:
            raise ValueError("file content is required")

        if purpose == "batch":
            ## 1. If jsonl, check if there's a model name
            file_content = self._get_content_from_openai_file(
                extracted_file_data_content
            )

            # Split into lines and parse each line as JSON
            openai_jsonl_content = [
                json.loads(line) for line in file_content.splitlines() if line.strip()
            ]
            if len(openai_jsonl_content) > 0:
                return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)

        ## 2. If not jsonl, return the filename
        filename = extracted_file_data.get("filename")
        if filename:
            return filename
        ## 3. If no file name, return timestamp
        return str(int(time.time()))

    def get_complete_file_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateFileRequest,
    ) -> str:
        """
        Get the complete url for the request
        """
        bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
        if not bucket_name:
            raise ValueError("GCS bucket_name is required")
        file_data = data.get("file")
        purpose = data.get("purpose")
        if file_data is None:
            raise ValueError("file is required")
        if purpose is None:
            raise ValueError("purpose is required")
        extracted_file_data = extract_file_data(file_data)
        object_name = self.get_object_name(extracted_file_data, purpose)
        endpoint = (
            f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
        )
        api_base = api_base or "https://storage.googleapis.com"
        if not api_base:
            raise ValueError("api_base is required")

        return f"{api_base}/{endpoint}"

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return optional_params

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
            VertexGeminiConfig,
        )

        config = VertexGeminiConfig()
        _model = openai_request_body.get("model", "")
        vertex_params = config.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}
        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """

        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            openai_request_body = _openai_jsonl_content.get("body") or {}
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, dict]:
        """
        2 Cases:
        1. Handle basic file upload
        2. Handle batch file upload (.jsonl)
        """
        file_data = create_file_data.get("file")
        if file_data is None:
            raise ValueError("file is required")
        extracted_file_data = extract_file_data(file_data)
        extracted_file_data_content = extracted_file_data.get("content")
        if (
            create_file_data.get("purpose") == "batch"
            and extracted_file_data.get("content_type") == "application/jsonl"
            and extracted_file_data_content is not None
        ):
            ## 1. If jsonl, check if there's a model name
            file_content = self._get_content_from_openai_file(
                extracted_file_data_content
            )

            # Split into lines and parse each line as JSON
            openai_jsonl_content = [
                json.loads(line) for line in file_content.splitlines() if line.strip()
            ]
            vertex_jsonl_content = (
                self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                    openai_jsonl_content
                )
            )
            return json.dumps(vertex_jsonl_content)
        elif isinstance(extracted_file_data_content, bytes):
            return extracted_file_data_content
        else:
            raise ValueError("Unsupported file content type")

    def transform_create_file_response(
        self,
        model: Optional[str],
        raw_response: Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """
        Transform VertexAI File upload response into OpenAI-style FileObject
        """
        response_json = raw_response.json()

        try:
            response_object = GcsBucketResponse(**response_json)  # type: ignore
        except Exception as e:
            raise VertexAIError(
                status_code=raw_response.status_code,
                message=f"Error reading GCS bucket response: {e}",
                headers=raw_response.headers,
            )

        gcs_id = response_object.get("id", "")
        # Remove the last numeric ID from the path
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return OpenAIFileObject(
            purpose=response_object.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=response_object.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response_object.get("timeCreated", "")
            ),
            status="uploaded",
            bytes=int(response_object.get("size", 0)),
            object="file",
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return VertexAIError(
            status_code=status_code, message=error_message, headers=headers
        )


class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
    """
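To make the upload target concrete: for a non-JSONL file, `get_complete_file_url` above resolves to a Google Cloud Storage simple-upload URL. A rough illustration with made-up values:

```python
import os

# Hypothetical inputs mirroring the logic in get_complete_file_url.
bucket_name = os.getenv("GCS_BUCKET_NAME", "my-litellm-bucket")  # placeholder bucket
object_name = "my-training-data.csv"  # non-JSONL files keep their filename

endpoint = f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
print(f"https://storage.googleapis.com/{endpoint}")
# https://storage.googleapis.com/upload/storage/v1/b/my-litellm-bucket/o?uploadType=media&name=my-training-data.csv
```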
|
|
@ -905,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
@ -1022,7 +1023,7 @@ class VertexLLM(VertexBase):
|
|||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
litellm_params: dict,
|
||||
logger_fn=None,
|
||||
api_base: Optional[str] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
|
@ -1063,6 +1064,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
@ -1149,6 +1151,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
request_body = await async_transform_request_body(**data) # type: ignore
|
||||
|
@ -1322,6 +1325,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## TRANSFORMATION ##
|
||||
|
|
|
@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
|
|||
optional_params=optional_params,
|
||||
api_key=auth_header,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
|
@ -22,7 +21,7 @@ else:
|
|||
GoogleCredentialsObject = Any
|
||||
|
||||
|
||||
class VertexBase(BaseLLM):
|
||||
class VertexBase:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.access_token: Optional[str] = None
|
||||
|
|
|
@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## UPDATE PAYLOAD (optional params)
|
||||
|
|
|
@ -165,6 +165,7 @@ class IBMWatsonXMixin:
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
|
|||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if isinstance(input, str):
|
||||
|
|
|
@@ -498,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
    gcsDestination: GcsDestination


class GcsBucketResponse(TypedDict):
    """
    TypedDict for GCS bucket upload response

    Attributes:
        kind: The kind of item this is. For objects, this is always storage#object
        id: The ID of the object
        selfLink: The link to this object
        mediaLink: The link to download the object
        name: The name of the object
        bucket: The name of the bucket containing this object
        generation: The content generation of this object
        metageneration: The metadata generation of this object
        contentType: The content type of the object
        storageClass: The storage class of the object
        size: The size of the object in bytes
        md5Hash: The MD5 hash of the object
        crc32c: The CRC32c checksum of the object
        etag: The ETag of the object
        timeCreated: The creation time of the object
        updated: The last update time of the object
        timeStorageClassUpdated: The time the storage class was last updated
        timeFinalized: The time the object was finalized
    """

    kind: Literal["storage#object"]
    id: str
    selfLink: str
    mediaLink: str
    name: str
    bucket: str
    generation: str
    metageneration: str
    contentType: str
    storageClass: str
    size: str
    md5Hash: str
    crc32c: str
    etag: str
    timeCreated: str
    updated: str
    timeStorageClassUpdated: str
    timeFinalized: str


class VertexAIBatchPredictionJob(TypedDict):
    displayName: str
    model: str
@@ -2,7 +2,7 @@ import json
import time
import uuid
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union

from aiohttp import FormData
from openai._models import BaseModel as OpenAIObject
@@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
        if not values.get("credential_values") and not values.get("model_id"):
            raise ValueError("Either credential_values or model_id must be set")
        return values


class ExtractedFileData(TypedDict):
    """
    TypedDict for storing processed file data

    Attributes:
        filename: Name of the file if provided
        content: The file content in bytes
        content_type: MIME type of the file
        headers: Any additional headers for the file
    """

    filename: Optional[str]
    content: bytes
    content_type: Optional[str]
    headers: Mapping[str, str]
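A small illustration, not taken from the diff, of populating ExtractedFileData for a non-JSONL upload; it assumes the TypedDict lands in litellm.types.utils, and the file name and bytes mirror the audio fixture used in the tests further down:

from litellm.types.utils import ExtractedFileData

extracted = ExtractedFileData(
    filename="test.wav",
    content=b"test audio content",
    content_type="audio/wav",
    headers={},
)
print(extracted["content_type"], len(extracted["content"]))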
@@ -6517,6 +6517,10 @@ class ProviderConfigManager:
            )

            return GoogleAIStudioFilesHandler()
        elif LlmProviders.VERTEX_AI == provider:
            from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig

            return VertexAIFilesConfig()
        return None
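A hedged usage sketch of what this new branch enables, mirroring the e2e tests later in this diff: create_file routed through the Vertex AI files config. GCS credentials and a bucket are assumed to be configured in the environment.

import litellm

# custom_llm_provider="vertex_ai" makes the lookup above return
# VertexAIFilesConfig, which handles the GCS-backed upload.
file_obj = litellm.create_file(
    file=("test.wav", b"test audio content", "audio/wav"),
    purpose="user_data",
    custom_llm_provider="vertex_ai",
)
print(file_obj)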
@@ -423,25 +423,35 @@ mock_vertex_batch_response = {


@pytest.mark.asyncio
async def test_avertex_batch_prediction():
    with patch(
async def test_avertex_batch_prediction(monkeypatch):
    monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    client = AsyncHTTPHandler()

    async def mock_side_effect(*args, **kwargs):
        print("args", args, "kwargs", kwargs)
        url = kwargs.get("url", "")
        if "files" in url:
            mock_response.json.return_value = mock_file_response
        elif "batch" in url:
            mock_response.json.return_value = mock_vertex_batch_response
        mock_response.status_code = 200
        return mock_response

    with patch.object(
        client, "post", side_effect=mock_side_effect
    ) as mock_post, patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
    ) as mock_post:
    ) as mock_global_post:
        # Configure mock responses
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None

        # Set up different responses for different API calls
        async def mock_side_effect(*args, **kwargs):
            url = kwargs.get("url", "")
            if "files" in url:
                mock_response.json.return_value = mock_file_response
            elif "batch" in url:
                mock_response.json.return_value = mock_vertex_batch_response
            mock_response.status_code = 200
            return mock_response

        mock_post.side_effect = mock_side_effect
        mock_global_post.side_effect = mock_side_effect

        # load_vertex_ai_credentials()
        litellm.set_verbose = True
@@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
            file=open(file_path, "rb"),
            purpose="batch",
            custom_llm_provider="vertex_ai",
            client=client
        )
        print("Response from creating file=", file_obj)
@@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
            model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
            messages=[{"role": "user", "content": "Hello"}],
            optional_params={},
            api_key="test_api_key"
            api_key="test_api_key",
            litellm_params={}
        )

        assert headers["Authorization"] == "Bearer test_api_key"
tests/local_testing/example.jsonl (new file, 2 lines)
@@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
@@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
    StandardLoggingPayload,
)
from litellm.types.utils import StandardCallbackDynamicParams

from unittest.mock import patch

verbose_logger.setLevel(logging.DEBUG)
@@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
    # clean up
    if old_bucket_name is not None:
        os.environ["GCS_BUCKET_NAME"] = old_bucket_name


@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e():
    """
    Asserts 'create_file' is called with the correct arguments
    """
    load_vertex_ai_credentials()
    test_file_content = b"test audio content"
    test_file = ("test.wav", test_file_content, "audio/wav")

    from litellm import create_file
    response = create_file(
        file=test_file,
        purpose="user_data",
        custom_llm_provider="vertex_ai",
    )
    print("response", response)
    assert response is not None


@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e_jsonl():
    """
    Asserts 'create_file' is called with the correct arguments
    """
    load_vertex_ai_credentials()
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()

    example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]

    # Create and write to the file
    file_path = "example.jsonl"
    with open(file_path, "w") as f:
        for item in example_jsonl:
            f.write(json.dumps(item) + "\n")

    # Verify file content
    with open(file_path, "r") as f:
        content = f.read()
        print("File content:", content)
        assert len(content) > 0, "File is empty"

    from litellm import create_file
    with patch.object(client, "post") as mock_create_file:
        try:
            response = create_file(
                file=open(file_path, "rb"),
                purpose="user_data",
                custom_llm_provider="vertex_ai",
                client=client,
            )
        except Exception as e:
            print("error", e)

        mock_create_file.assert_called_once()

        print(f"kwargs: {mock_create_file.call_args.kwargs}")

        assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0