Add Google AI Studio /v1/files upload API support (#9645)

* test: fix import for test

* fix: fix bad error string

* docs: cleanup files docs

* fix(files/main.py): cleanup error string

* style: initial commit with a provider/config pattern for files api

google ai studio files api onboarding

* fix: test

* feat(gemini/files/transformation.py): support gemini files api response transformation

* fix(gemini/files/transformation.py): return file id as gemini uri

allows id to be passed in to chat completion request, just like openai

* feat(llm_http_handler.py): support async route for files api on llm_http_handler

* fix: fix linting errors

* fix: fix model info check

* fix: fix ruff errors

* fix: fix linting errors

* Revert "fix: fix linting errors"

This reverts commit 926a5a527f.

* fix: fix linting errors

* test: fix test

* test: fix tests
Krish Dholakia 2025-04-02 08:56:58 -07:00 committed by GitHub
parent d1abb9b68b
commit 0519c0c507
40 changed files with 1006 additions and 245 deletions
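For orientation, a minimal sketch of the flow this PR enables: uploading a file to Google AI Studio through the new `litellm.create_file` path. The file tuple below is illustrative, and routing via `custom_llm_provider="gemini"` relies on the provider-config dispatch added in this diff; per the commit notes, the returned id is a Gemini URI that can then be passed into a chat completion request just like an OpenAI file id.

```python
# Illustrative sketch (not part of the diff). Assumes GEMINI_API_KEY is set.
import litellm

file_obj = litellm.create_file(
    file=("notes.txt", b"some text for the model", "text/plain"),
    purpose="batch",
    custom_llm_provider="gemini",
)
print(file_obj.id)  # a Gemini file URI, usable in a subsequent completion call
```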


@@ -14,48 +14,105 @@ Files are used to upload documents that can be used with features like Assistant
 - Delete File
 - Get File Content
 
 <Tabs>
 <TabItem value="proxy" label="LiteLLM PROXY Server">
 
-```bash
-$ export OPENAI_API_KEY="sk-..."
-
-$ litellm
-
-# RUNNING on http://0.0.0.0:4000
-```
-
-**Upload a File**
-```bash
-curl http://localhost:4000/v1/files \
-  -H "Authorization: Bearer sk-1234" \
-  -F purpose="fine-tune" \
-  -F file="@mydata.jsonl"
-```
-
-**List Files**
-```bash
-curl http://localhost:4000/v1/files \
-  -H "Authorization: Bearer sk-1234"
-```
-
-**Retrieve File Information**
-```bash
-curl http://localhost:4000/v1/files/file-abc123 \
-  -H "Authorization: Bearer sk-1234"
-```
-
-**Delete File**
-```bash
-curl http://localhost:4000/v1/files/file-abc123 \
-  -X DELETE \
-  -H "Authorization: Bearer sk-1234"
-```
-
-**Get File Content**
-```bash
-curl http://localhost:4000/v1/files/file-abc123/content \
-  -H "Authorization: Bearer sk-1234"
-```
+### 1. Setup config.yaml
+
+```
+# for /files endpoints
+files_settings:
+  - custom_llm_provider: azure
+    api_base: https://exampleopenaiendpoint-production.up.railway.app
+    api_key: fake-key
+    api_version: "2023-03-15-preview"
+  - custom_llm_provider: openai
+    api_key: os.environ/OPENAI_API_KEY
+```
+
+### 2. Start LiteLLM PROXY Server
+
+```bash
+litellm --config /path/to/config.yaml
+
+## RUNNING on http://0.0.0.0:4000
+```
+
+### 3. Use OpenAI's /files endpoints
+
+Upload a File
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-...",
+    base_url="http://0.0.0.0:4000/v1"
+)
+
+client.files.create(
+    file=wav_data,
+    purpose="user_data",
+    extra_body={"custom_llm_provider": "openai"}
+)
+```
+
+List Files
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-...",
+    base_url="http://0.0.0.0:4000/v1"
+)
+
+files = client.files.list(extra_body={"custom_llm_provider": "openai"})
+print("files=", files)
+```
+
+Retrieve File Information
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-...",
+    base_url="http://0.0.0.0:4000/v1"
+)
+
+file = client.files.retrieve(file_id="file-abc123", extra_body={"custom_llm_provider": "openai"})
+print("file=", file)
+```
+
+Delete File
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-...",
+    base_url="http://0.0.0.0:4000/v1"
+)
+
+response = client.files.delete(file_id="file-abc123", extra_body={"custom_llm_provider": "openai"})
+print("delete response=", response)
+```
+
+Get File Content
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-...",
+    base_url="http://0.0.0.0:4000/v1"
+)
+
+content = client.files.content(file_id="file-abc123", extra_body={"custom_llm_provider": "openai"})
+print("content=", content)
+```
 
 </TabItem>
@@ -120,7 +177,7 @@ print("file content=", content)
 ### [OpenAI](#quick-start)
 
-## [Azure OpenAI](./providers/azure#azure-batches-api)
+### [Azure OpenAI](./providers/azure#azure-batches-api)
 
 ### [Vertex AI](./providers/vertex#batch-apis)


@@ -15,7 +15,9 @@ import httpx
 import litellm
 from litellm import get_secret_str
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
 from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
 from litellm.types.llms.openai import (
@@ -23,9 +25,18 @@ from litellm.types.llms.openai import (
     FileContentRequest,
     FileTypes,
     HttpxBinaryResponseContent,
+    OpenAIFileObject,
 )
 from litellm.types.router import *
-from litellm.utils import get_litellm_params, supports_httpx_timeout
+from litellm.types.utils import LlmProviders
+from litellm.utils import (
+    ProviderConfigManager,
+    client,
+    get_litellm_params,
+    supports_httpx_timeout,
+)
+
+base_llm_http_handler = BaseLLMHTTPHandler()
 
 ####### ENVIRONMENT VARIABLES ###################
 openai_files_instance = OpenAIFilesAPI()
@@ -34,6 +45,224 @@ vertex_ai_files_instance = VertexAIFilesHandler()
#################################################
@client
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> OpenAIFileObject:
"""
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
"""
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
logging_obj = cast(
Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
)
if logging_obj is None:
raise ValueError("logging_obj is required")
client = kwargs.get("client")
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
response = base_llm_http_handler.create_file(
provider_config=provider_config,
litellm_params=litellm_params_dict,
create_file_data=_create_file_request,
headers=extra_headers or {},
api_base=optional_params.api_base,
api_key=optional_params.api_key,
logging_obj=logging_obj,
_is_async=_is_async,
client=client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None,
timeout=timeout,
)
elif custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai'] are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def afile_retrieve(
    file_id: str,
    custom_llm_provider: Literal["openai", "azure"] = "openai",

@@ -488,195 +717,6 @@ def file_list(
        raise e
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileObject:
"""
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
"""
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def afile_content(
    file_id: str,
    custom_llm_provider: Literal["openai", "azure"] = "openai",


@@ -27,6 +27,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -45,7 +45,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
         max_retries: Optional[int],
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
         litellm_params: Optional[dict] = None,
-    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
         openai_client: Optional[
             Union[AzureOpenAI, AsyncAzureOpenAI]
         ] = self.get_azure_openai_client(
@@ -69,8 +69,8 @@
             return self.acreate_file(  # type: ignore
                 create_file_data=create_file_data, openai_client=openai_client
             )
-        response = openai_client.files.create(**create_file_data)
-        return response
+        response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
+        return OpenAIFileObject(**response.model_dump())
 
     async def afile_content(
         self,


@@ -65,6 +65,7 @@ class AzureAIStudioConfig(OpenAIConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -28,6 +28,7 @@ class BaseAudioTranscriptionConfig(BaseConfig, ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -294,6 +294,7 @@ class BaseConfig(ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -29,6 +29,7 @@ class BaseTextCompletionConfig(BaseConfig, ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -43,6 +43,7 @@ class BaseEmbeddingConfig(BaseConfig, ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -0,0 +1,102 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
)
from litellm.types.utils import LlmProviders, ModelResponse
from ..chat.transformation import BaseConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseFilesConfig(BaseConfig):
@property
@abstractmethod
def custom_llm_provider(self) -> LlmProviders:
pass
@abstractmethod
def get_supported_openai_params(
self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
pass
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
@abstractmethod
def transform_create_file_request(
self,
model: str,
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> dict:
pass
@abstractmethod
def transform_create_file_response(
self,
model: Optional[str],
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> OpenAIFileObject:
pass
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
raise NotImplementedError(
"BaseFilesConfig does not need a request transformation for the files API"
)
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"BaseFilesConfig does not need a response transformation for the files API"
)
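A provider onboards to the files API by subclassing this config. A toy sketch of the contract (hypothetical provider, eliding the remaining `BaseConfig` abstract methods such as `validate_environment` and `get_error_class`; the two-key request shape mirrors the Gemini handler added later in this diff):

```python
# Hypothetical sketch, not part of the diff.
from litellm.llms.base_llm.files.transformation import BaseFilesConfig
from litellm.types.utils import LlmProviders


class ToyFilesConfig(BaseFilesConfig):
    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.GEMINI  # stand-in; a real config names its own provider

    def get_supported_openai_params(self, model: str):
        return []  # this toy maps no optional OpenAI params

    def transform_create_file_request(
        self, model, create_file_data, optional_params, litellm_params
    ) -> dict:
        # Two-key shape consumed by BaseLLMHTTPHandler.create_file:
        # an initial metadata request, then the raw-bytes upload request.
        return {
            "initial_request": {"headers": {}, "data": {"file": {}}},
            "upload_request": {"headers": {}, "data": b""},
        }

    def transform_create_file_response(
        self, model, raw_response, logging_obj, litellm_params
    ):
        raise NotImplementedError("map provider JSON -> OpenAIFileObject here")
```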


@@ -34,6 +34,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -74,6 +74,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -77,6 +77,7 @@ class CloudflareChatConfig(BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -230,6 +230,7 @@ class BaseLLMAIOHTTPHandler:
         api_base = provider_config.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             litellm_params=litellm_params,
@@ -480,6 +481,7 @@
         api_base = provider_config.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             litellm_params=litellm_params,
@@ -519,7 +521,6 @@
             data=data,
             headers=headers,
             model_response=model_response,
-            api_key=api_key,
             logging_obj=logging_obj,
             model=model,
             timeout=timeout,


@@ -7,11 +7,13 @@ import litellm
 import litellm.litellm_core_utils
 import litellm.types
 import litellm.types.utils
+from litellm._logging import verbose_logger
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.files.transformation import BaseFilesConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.llms.custom_httpx.http_handler import (
@@ -26,7 +28,12 @@ from litellm.responses.streaming_iterator import (
     ResponsesAPIStreamingIterator,
     SyncResponsesAPIStreamingIterator,
 )
-from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
+from litellm.types.llms.openai import (
+    CreateFileRequest,
+    OpenAIFileObject,
+    ResponseInputParam,
+    ResponsesAPIResponse,
+)
 from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.router import GenericLiteLLMParams
 from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
@@ -240,6 +247,7 @@ class BaseLLMHTTPHandler:
         api_base = provider_config.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             stream=stream,
@@ -611,6 +619,7 @@
         api_base = provider_config.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             litellm_params=litellm_params,
@@ -884,6 +893,7 @@
         complete_url = provider_config.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             litellm_params=litellm_params,
@@ -1185,6 +1195,188 @@
             logging_obj=logging_obj,
         )
def create_file(
self,
create_file_data: CreateFileRequest,
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
api_base: Optional[str],
api_key: Optional[str],
logging_obj: LiteLLMLoggingObj,
_is_async: bool = False,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
"""
Creates a file using Gemini's two-step upload process
"""
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model="",
messages=[],
optional_params={},
)
api_base = provider_config.get_complete_url(
api_base=api_base,
api_key=api_key,
model="",
optional_params={},
litellm_params=litellm_params,
)
# Get the transformed request data for both steps
transformed_request = provider_config.transform_create_file_request(
model="",
create_file_data=create_file_data,
litellm_params=litellm_params,
optional_params={},
)
if _is_async:
return self.async_create_file(
transformed_request=transformed_request,
litellm_params=litellm_params,
provider_config=provider_config,
headers=headers,
api_base=api_base,
logging_obj=logging_obj,
client=client,
timeout=timeout,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = sync_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = sync_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
async def async_create_file(
self,
transformed_request: dict,
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
api_base: str,
logging_obj: LiteLLMLoggingObj,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
"""
Creates a file using Gemini's two-step upload process
"""
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=provider_config.custom_llm_provider
)
else:
async_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = await async_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = await async_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
except Exception as e:
verbose_logger.exception(f"Error creating file: {e}")
raise self._handle_error(
e=e,
provider_config=provider_config,
)
def list_files(self):
"""
Lists all files
"""
pass
def delete_file(self):
"""
Deletes a file
"""
pass
def retrieve_file(self):
"""
Returns the metadata of the file
"""
pass
def retrieve_file_content(self):
"""
Returns the content of the file
"""
pass
    def _prepare_fake_stream_request(
        self,
        stream: bool,


@@ -151,6 +151,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -37,6 +37,7 @@ class DeepSeekChatConfig(OpenAIGPTConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -1,17 +1,41 @@
-from typing import List, Optional
+from typing import List, Optional, Union
+
+import httpx
 
 import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+
+
+class GeminiError(BaseLLMException):
+    pass
 
 
 class GeminiModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        """Google AI Studio sends api key in query params"""
+        return headers
+
+    @property
+    def api_version(self) -> str:
+        return "v1beta"
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return (
             api_base
             or get_secret_str("GEMINI_API_BASE")
-            or "https://generativelanguage.googleapis.com/v1beta"
+            or "https://generativelanguage.googleapis.com"
         )
 
     @staticmethod
@@ -27,13 +51,14 @@ class GeminiModelInfo(BaseLLMModelInfo):
     ) -> List[str]:
         api_base = GeminiModelInfo.get_api_base(api_base)
         api_key = GeminiModelInfo.get_api_key(api_key)
+        endpoint = f"/{self.api_version}/models"
         if api_base is None or api_key is None:
             raise ValueError(
                 "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
             )
 
         response = litellm.module_level_client.get(
-            url=f"{api_base}/models?key={api_key}",
+            url=f"{api_base}{endpoint}?key={api_key}",
         )
 
         if response.status_code != 200:
@@ -49,3 +74,10 @@ class GeminiModelInfo(BaseLLMModelInfo):
             litellm_model_name = "gemini/" + stripped_model_name
             litellm_model_names.append(litellm_model_name)
         return litellm_model_names
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return GeminiError(
+            status_code=status_code, message=error_message, headers=headers
+        )
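Note the default `api_base` now stops at the host; the version segment is appended per endpoint via `api_version`. A quick sketch of the resulting URL composition (values taken from the code above):

```python
# Sketch of the URL composition after this refactor:
api_base = "https://generativelanguage.googleapis.com"  # new default, no /v1beta suffix
endpoint = "/v1beta/models"  # i.e. f"/{self.api_version}/models"
url = f"{api_base}{endpoint}?key=<GEMINI_API_KEY>"  # key rides in the query string, not a header
```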


@@ -0,0 +1,207 @@
"""
Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
from typing import List, Mapping, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.gemini import GeminiCreateFilesResponseObject
from litellm.types.llms.openai import (
CreateFileRequest,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
)
from litellm.types.utils import LlmProviders
from ..common_utils import GeminiModelInfo
class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
def __init__(self):
pass
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.GEMINI
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
endpoint = "upload/v1beta/files"
api_base = self.get_api_base(api_base)
if not api_base:
raise ValueError("api_base is required")
if not api_key:
raise ValueError("api_key is required")
url = "{}/{}?key={}".format(api_base, endpoint, api_key)
return url
def get_supported_openai_params(
self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return optional_params
def transform_create_file_request(
self,
model: str,
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> dict:
"""
Transform the OpenAI-style file creation request into Gemini's format
Returns:
dict: Contains both request data and headers for the two-step upload
"""
# Extract the file information
file_data = create_file_data.get("file")
if file_data is None:
raise ValueError("File data is required")
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Handle the file content based on its type
import io
from os import PathLike
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Get file size
file_size = len(content)
# Use provided content type or guess based on filename
if not content_type:
import mimetypes
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
# Step 1: Initial resumable upload request
headers = {
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(file_size),
"X-Goog-Upload-Header-Content-Type": content_type,
"Content-Type": "application/json",
}
headers.update(file_headers) # Add any custom headers
# Initial metadata request body
initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
# Step 2: Actual file upload data
upload_headers = {
"Content-Length": str(file_size),
"X-Goog-Upload-Offset": "0",
"X-Goog-Upload-Command": "upload, finalize",
}
return {
"initial_request": {"headers": headers, "data": initial_data},
"upload_request": {"headers": upload_headers, "data": content},
}
def transform_create_file_response(
self,
model: Optional[str],
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> OpenAIFileObject:
"""
Transform Gemini's file upload response into OpenAI-style FileObject
"""
try:
response_json = raw_response.json()
response_object = GeminiCreateFilesResponseObject(
**response_json.get("file", {}) # type: ignore
)
# Extract file information from Gemini response
return OpenAIFileObject(
id=response_object["uri"], # Gemini uses URI as identifier
bytes=int(
response_object["sizeBytes"]
), # Gemini returns the file size as a string
created_at=int(
time.mktime(
time.strptime(
response_object["createTime"].replace("Z", "+00:00"),
"%Y-%m-%dT%H:%M:%S.%f%z",
)
)
),
filename=response_object["displayName"],
object="file",
purpose="user_data", # Default to assistants as that's the main use case
status="uploaded",
status_details=None,
)
except Exception as e:
verbose_logger.exception(f"Error parsing file upload response: {str(e)}")
raise ValueError(f"Error parsing file upload response: {str(e)}")
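To make the two-step contract concrete, an illustrative `transformed_request` as built above for a small text file (header names come from the code; the values are invented for illustration):

```python
# Illustrative only; shape follows transform_create_file_request above.
transformed_request = {
    "initial_request": {
        "headers": {
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": "11",  # len(b"hello world")
            "X-Goog-Upload-Header-Content-Type": "text/plain",
            "Content-Type": "application/json",
        },
        "data": {"file": {"display_name": "notes.txt"}},
    },
    "upload_request": {
        "headers": {
            "Content-Length": "11",
            "X-Goog-Upload-Offset": "0",
            "X-Goog-Upload-Command": "upload, finalize",
        },
        "data": b"hello world",
    },
}
```

The handler posts `initial_request` to the upload endpoint, reads `X-Goog-Upload-URL` from the response headers, then posts `upload_request` to that URL and hands the JSON body to `transform_create_file_response`.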


@@ -361,6 +361,7 @@ class OllamaConfig(BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -290,6 +290,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -1481,9 +1481,9 @@ class OpenAIFilesAPI(BaseLLM):
         self,
         create_file_data: CreateFileRequest,
         openai_client: AsyncOpenAI,
-    ) -> FileObject:
+    ) -> OpenAIFileObject:
         response = await openai_client.files.create(**create_file_data)
-        return response
+        return OpenAIFileObject(**response.model_dump())
 
     def create_file(
         self,
@@ -1495,7 +1495,7 @@
         max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
         openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1518,8 +1518,8 @@
             return self.acreate_file(  # type: ignore
                 create_file_data=create_file_data, openai_client=openai_client
             )
-        response = openai_client.files.create(**create_file_data)
-        return response
+        response = cast(OpenAI, openai_client).files.create(**create_file_data)
+        return OpenAIFileObject(**response.model_dump())
 
     async def afile_content(
         self,


@@ -170,6 +170,7 @@ def completion(
     prediction_url = replicate_config.get_complete_url(
         api_base=api_base,
+        api_key=api_key,
         model=model,
         optional_params=optional_params,
         litellm_params=litellm_params,
@@ -246,6 +247,7 @@
 ) -> Union[ModelResponse, CustomStreamWrapper]:
     prediction_url = replicate_config.get_complete_url(
         api_base=api_base,
+        api_key=api_key,
         model=model,
         optional_params=optional_params,
         litellm_params=litellm_params,


@@ -139,6 +139,7 @@ class ReplicateConfig(BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -135,6 +135,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -53,6 +53,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -70,6 +70,7 @@ class TritonConfig(BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -8,7 +8,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket_base import (
     GCSLoggingConfig,
 )
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
-from litellm.types.llms.openai import CreateFileRequest, FileObject
+from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
 from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
 from .transformation import VertexAIFilesTransformation
@@ -29,8 +29,6 @@ class VertexAIFilesHandler(GCSBucketBase):
             llm_provider=LlmProviders.VERTEX_AI,
         )
-        pass
-
 
     async def async_create_file(
         self,
         create_file_data: CreateFileRequest,
@@ -40,7 +38,7 @@ class VertexAIFilesHandler(GCSBucketBase):
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ):
+    ) -> OpenAIFileObject:
         gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
             kwargs={}
         )
@@ -77,7 +75,7 @@
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
         """
         Creates a file on VertexAI GCS Bucket


@@ -9,7 +9,12 @@ from litellm.llms.vertex_ai.gemini.transformation import _transform_request_body
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
 )
-from litellm.types.llms.openai import CreateFileRequest, FileObject, FileTypes, PathLike
+from litellm.types.llms.openai import (
+    CreateFileRequest,
+    FileTypes,
+    OpenAIFileObject,
+    PathLike,
+)
 
 
 class VertexAIFilesTransformation(VertexGeminiConfig):
@@ -142,7 +147,7 @@
     def transform_gcs_bucket_response_to_openai_file_object(
         self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
-    ) -> FileObject:
+    ) -> OpenAIFileObject:
         """
         Transforms GCS Bucket upload file response to OpenAI FileObject
         """
@@ -150,7 +155,7 @@
         # Remove the last numeric ID from the path
         gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
 
-        return FileObject(
+        return OpenAIFileObject(
             purpose=create_file_data.get("purpose", "batch"),
             id=f"gs://{gcs_id}",
             filename=gcs_upload_response.get("name", ""),


@@ -41,6 +41,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -61,6 +61,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
         ## GET API URL
         api_base = watsonx_chat_transformation.get_complete_url(
             api_base=api_base,
+            api_key=api_key,
             model=model,
             optional_params=optional_params,
             litellm_params=litellm_params,


@@ -80,6 +80,7 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -316,6 +316,7 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -52,6 +52,7 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,


@@ -31,8 +31,6 @@ litellm_settings:
   callbacks: ["prometheus"]
   # json_logs: true
 
-router_settings:
-  routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
+files_settings:
+  - custom_llm_provider: gemini
+    api_key: os.environ/GEMINI_API_KEY
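With `files_settings` pointing at Gemini, the proxy's `/v1/files` route can be exercised like the OpenAI examples in the docs diff above, swapping only the provider hint; a sketch assuming the proxy is running locally:

```python
# Sketch, assuming a proxy started with the files_settings above.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000/v1")

created = client.files.create(
    file=("notes.txt", b"hello gemini"),  # (filename, bytes) tuple
    purpose="user_data",
    extra_body={"custom_llm_provider": "gemini"},
)
print("file id =", created.id)  # surfaced as a Gemini file URI
```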


@@ -62,7 +62,7 @@ def get_files_provider_config(
     if custom_llm_provider == "vertex_ai":
         return None
     if files_config is None:
-        raise ValueError("files_config is not set, set it on your config.yaml file.")
+        raise ValueError("files_settings is not set, set it on your config.yaml file.")
     for setting in files_config:
         if setting.get("custom_llm_provider") == custom_llm_provider:
             return setting


@@ -0,0 +1,33 @@
from enum import Enum
from typing import Any, Dict, Iterable, List, Literal, Optional, Union
from typing_extensions import Required, TypedDict
class GeminiFilesState(Enum):
STATE_UNSPECIFIED = "STATE_UNSPECIFIED"
PROCESSING = "PROCESSING"
ACTIVE = "ACTIVE"
FAILED = "FAILED"
class GeminiFilesSource(Enum):
SOURCE_UNSPECIFIED = "SOURCE_UNSPECIFIED"
UPLOADED = "UPLOADED"
GENERATED = "GENERATED"
class GeminiCreateFilesResponseObject(TypedDict):
name: str
displayName: str
mimeType: str
sizeBytes: str
createTime: str
updateTime: str
expirationTime: str
sha256Hash: str
uri: str
state: GeminiFilesState
source: GeminiFilesSource
error: dict
metadata: dict
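For reference, an illustrative instance of the shape this TypedDict models — the object the Gemini transformation reads out of the upload response's `"file"` key (all values invented):

```python
# Illustrative only; keys follow GeminiCreateFilesResponseObject.
example: GeminiCreateFilesResponseObject = {
    "name": "files/abc-123",
    "displayName": "notes.txt",
    "mimeType": "text/plain",
    "sizeBytes": "11",
    "createTime": "2025-04-02T08:56:58.000000Z",
    "updateTime": "2025-04-02T08:56:58.000000Z",
    "expirationTime": "2025-04-04T08:56:58.000000Z",
    "sha256Hash": "<base64 digest>",
    "uri": "https://generativelanguage.googleapis.com/v1beta/files/abc-123",
    "state": GeminiFilesState.ACTIVE,
    "source": GeminiFilesSource.UPLOADED,
    "error": {},
    "metadata": {},
}
```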


@@ -234,6 +234,59 @@ class Thread(BaseModel):
    """The object type, which is always `thread`."""

OpenAICreateFileRequestOptionalParams = Literal["purpose",]

class OpenAIFileObject(BaseModel):
id: str
"""The file identifier, which can be referenced in the API endpoints."""
bytes: int
"""The size of the file, in bytes."""
created_at: int
"""The Unix timestamp (in seconds) for when the file was created."""
filename: str
"""The name of the file."""
object: Literal["file"]
"""The object type, which is always `file`."""
purpose: Literal[
"assistants",
"assistants_output",
"batch",
"batch_output",
"fine-tune",
"fine-tune-results",
"vision",
"user_data",
]
"""The intended purpose of the file.
Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`,
`fine-tune`, `fine-tune-results`, `vision`, and `user_data`.
"""
status: Literal["uploaded", "processed", "error"]
"""Deprecated.
The current status of the file, which can be either `uploaded`, `processed`, or
`error`.
"""
expires_at: Optional[int] = None
"""The Unix timestamp (in seconds) for when the file will expire."""
status_details: Optional[str] = None
"""Deprecated.
For details on why a fine-tuning training file failed validation, see the
`error` field on `fine_tuning.job`.
"""
# OpenAI Files Types
class CreateFileRequest(TypedDict, total=False):
    """


@@ -57,6 +57,8 @@ import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.audio_utils.utils
 import litellm.litellm_core_utils.json_validation_rule
+import litellm.llms
+import litellm.llms.gemini
 from litellm.caching._internal_lru_cache import lru_cache_wrapper
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
@@ -207,6 +209,7 @@ from litellm.llms.base_llm.base_utils import (
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.files.transformation import BaseFilesConfig
 from litellm.llms.base_llm.image_variations.transformation import (
     BaseImageVariationConfig,
 )
@@ -1259,6 +1262,7 @@ def client(original_function):  # noqa: PLR0915
             logging_obj, kwargs = function_setup(
                 original_function.__name__, rules_obj, start_time, *args, **kwargs
             )
+
             kwargs["litellm_logging_obj"] = logging_obj
             ## LOAD CREDENTIALS
             load_credentials_from_list(kwargs)
@@ -6426,6 +6430,19 @@ class ProviderConfigManager:
            return litellm.TopazImageVariationConfig()
        return None
@staticmethod
def get_provider_files_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseFilesConfig]:
if LlmProviders.GEMINI == provider:
from litellm.llms.gemini.files.transformation import (
GoogleAIStudioFilesHandler, # experimental approach, to reduce bloat on __init__.py
)
return GoogleAIStudioFilesHandler()
return None
def get_end_user_id_for_cost_tracking(
    litellm_params: dict,
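Usage mirrors the lookup `create_file` performs before dispatching to `base_llm_http_handler` (see the files/main.py diff above):

```python
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Same call create_file makes to decide whether a provider-config path exists:
config = ProviderConfigManager.get_provider_files_config(
    model="",
    provider=LlmProviders.GEMINI,
)
print(config)  # GoogleAIStudioFilesHandler for gemini; other providers return None
```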


@@ -25,6 +25,7 @@ def test_get_complete_url_basic(bedrock_transformer):
     """Test basic URL construction for non-streaming request"""
     url = bedrock_transformer.get_complete_url(
         api_base="https://bedrock-runtime.us-east-1.amazonaws.com",
+        api_key=None,
         model="anthropic.claude-v2",
         optional_params={},
         stream=False,
@@ -41,6 +42,7 @@ def test_get_complete_url_streaming(bedrock_transformer):
     """Test URL construction for streaming request"""
     url = bedrock_transformer.get_complete_url(
         api_base="https://bedrock-runtime.us-east-1.amazonaws.com",
+        api_key=None,
         model="anthropic.claude-v2",
         optional_params={},
         stream=True,