Support CRUD endpoints for Managed Files (#9924)

* fix(openai.py): ensure openai file object shows up in logs

* fix(managed_files.py): return unified file id as b64 str

allows retrieving a file by id to work as expected (see the encode/decode sketch after these notes)

* fix(managed_files.py): apply decoded file id transformation

* fix: add unit test for file id + decode logic

* fix: initial commit for litellm_proxy support with CRUD endpoints

* fix(managed_files.py): support retrieve file operation

* fix(managed_files.py): support for DELETE endpoint for files

* fix(managed_files.py): retrieve file content support

supports OpenAI's retrieve file content API

* fix: fix linting error

* test: update tests

* fix: fix linting error

* fix(files/main.py): pass litellm params to azure route

* test: fix test
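
The unified file id work in the notes above boils down to one idea: instead of returning the provider's file id, the proxy returns its own provider-agnostic id, base64-encoded so it round-trips cleanly through OpenAI-compatible clients, and decodes it again on retrieve/delete. A rough sketch of that encode/decode step, with hypothetical helper names and id format (the real check lives in _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id):

import base64
import uuid


def encode_unified_file_id() -> str:
    # Hypothetical: mint a provider-agnostic id and base64-encode it so it is
    # safe to hand back through the OpenAI-compatible /v1/files surface.
    unified_id = f"litellm_proxy;unified_id,{uuid.uuid4()}"
    return base64.urlsafe_b64encode(unified_id.encode("utf-8")).decode("utf-8")


def is_base64_unified_file_id(file_id: str) -> bool:
    # Mirrors the spirit of the managed-files check: decode and look for the
    # litellm_proxy prefix; anything that fails to decode is a provider id.
    try:
        decoded = base64.urlsafe_b64decode(file_id.encode("utf-8")).decode("utf-8")
    except Exception:
        return False
    return decoded.startswith("litellm_proxy")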
Krish Dholakia 2025-04-11 21:48:27 -07:00 committed by GitHub
parent 3e427e26c9
commit 3ca82c22b6
14 changed files with 783 additions and 86 deletions


@@ -159,6 +159,51 @@ async def create_file_for_each_model(
     return response
 
 
+async def route_create_file(
+    llm_router: Optional[Router],
+    _create_file_request: CreateFileRequest,
+    purpose: OpenAIFilesPurpose,
+    proxy_logging_obj: ProxyLogging,
+    user_api_key_dict: UserAPIKeyAuth,
+    target_model_names_list: List[str],
+    is_router_model: bool,
+    router_model: Optional[str],
+    custom_llm_provider: str,
+) -> OpenAIFileObject:
+    if (
+        litellm.enable_loadbalancing_on_batch_endpoints is True
+        and is_router_model
+        and router_model is not None
+    ):
+        response = await _deprecated_loadbalanced_create_file(
+            llm_router=llm_router,
+            router_model=router_model,
+            _create_file_request=_create_file_request,
+        )
+    elif target_model_names_list:
+        response = await create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=_create_file_request,
+            target_model_names_list=target_model_names_list,
+            purpose=purpose,
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+        )
+    else:
+        # get configs for custom_llm_provider
+        llm_provider_config = get_files_provider_config(
+            custom_llm_provider=custom_llm_provider
+        )
+        if llm_provider_config is not None:
+            # add llm_provider_config to data
+            _create_file_request.update(llm_provider_config)
+        _create_file_request.pop("custom_llm_provider", None)  # type: ignore
+        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider)  # type: ignore
+    return response
+
+
 @router.post(
     "/{provider}/v1/files",
     dependencies=[Depends(user_api_key_auth)],
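
For a sense of how route_create_file gets exercised end to end, here is a hedged client-side sketch of uploading a file through the proxy with the OpenAI SDK. The base_url, API key, and the target_model_names extra_body field are assumptions about a typical proxy setup, not something this diff defines:

from openai import OpenAI

# Assumed local proxy endpoint and virtual key.
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

file_obj = client.files.create(
    file=open("batch_input.jsonl", "rb"),
    purpose="batch",
    # Assumed field: asking the proxy to create the file against several
    # deployments, which would take the create_file_for_each_model branch above.
    extra_body={"target_model_names": "gpt-4o-mini, azure-gpt-4o-mini"},
)
print(file_obj.id)  # a base64-encoded unified id when managed files are enabled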
@@ -267,37 +312,17 @@ async def create_file(
             file=file_data, purpose=cast(CREATE_FILE_REQUESTS_PURPOSE, purpose), **data
         )
 
         response: Optional[OpenAIFileObject] = None
-        if (
-            litellm.enable_loadbalancing_on_batch_endpoints is True
-            and is_router_model
-            and router_model is not None
-        ):
-            response = await _deprecated_loadbalanced_create_file(
-                llm_router=llm_router,
-                router_model=router_model,
-                _create_file_request=_create_file_request,
-            )
-        elif target_model_names_list:
-            response = await create_file_for_each_model(
-                llm_router=llm_router,
-                _create_file_request=_create_file_request,
-                target_model_names_list=target_model_names_list,
-                purpose=purpose,
-                proxy_logging_obj=proxy_logging_obj,
-                user_api_key_dict=user_api_key_dict,
-            )
-        else:
-            # get configs for custom_llm_provider
-            llm_provider_config = get_files_provider_config(
-                custom_llm_provider=custom_llm_provider
-            )
-            if llm_provider_config is not None:
-                # add llm_provider_config to data
-                _create_file_request.update(llm_provider_config)
-            _create_file_request.pop("custom_llm_provider", None)  # type: ignore
-            # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
-            response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider)  # type: ignore
-
+        response = await route_create_file(
+            llm_router=llm_router,
+            _create_file_request=_create_file_request,
+            purpose=purpose,
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+            target_model_names_list=target_model_names_list,
+            is_router_model=is_router_model,
+            router_model=router_model,
+            custom_llm_provider=custom_llm_provider,
+        )
         if response is None:
             raise HTTPException(
@@ -311,6 +336,13 @@ async def create_file(
                 )
             )
 
+        ## POST CALL HOOKS ###
+        _response = await proxy_logging_obj.post_call_success_hook(
+            data=data, user_api_key_dict=user_api_key_dict, response=response
+        )
+        if _response is not None and isinstance(_response, OpenAIFileObject):
+            response = _response
+
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
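
This post-call hook block is the seam that lets the managed-files hook rewrite the provider's file id into the unified id before the response leaves create_file. A hedged sketch of a hook with the same shape, assuming litellm's CustomLogger exposes async_post_call_success_hook with this signature and that OpenAIFileObject imports from litellm.types.llms.openai (the real managed-files hook also persists the id mapping):

import base64

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import OpenAIFileObject


class RewriteFileIdHook(CustomLogger):
    # Illustrative only: return a modified OpenAIFileObject and create_file
    # adopts it via the isinstance check shown above.
    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        if isinstance(response, OpenAIFileObject):
            # Stand-in for the real unified-id encoding.
            response.id = base64.urlsafe_b64encode(
                f"litellm_proxy:{response.id}".encode("utf-8")
            ).decode("utf-8")
            return response
        return None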
@@ -392,6 +424,7 @@ async def get_file_content(
     from litellm.proxy.proxy_server import (
         add_litellm_data_to_request,
         general_settings,
+        llm_router,
         proxy_config,
         proxy_logging_obj,
         version,
@@ -414,9 +447,40 @@ async def get_file_content(
             or await get_custom_llm_provider_from_request_body(request=request)
             or "openai"
         )
-        response = await litellm.afile_content(
-            custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
-        )
+        ## check if file_id is a litellm managed file
+        is_base64_unified_file_id = (
+            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
+        )
+
+        if is_base64_unified_file_id:
+            managed_files_obj = cast(
+                Optional[_PROXY_LiteLLMManagedFiles],
+                proxy_logging_obj.get_proxy_hook("managed_files"),
+            )
+            if managed_files_obj is None:
+                raise ProxyException(
+                    message="Managed files hook not found",
+                    type="None",
+                    param="None",
+                    code=500,
+                )
+            if llm_router is None:
+                raise ProxyException(
+                    message="LLM Router not found",
+                    type="None",
+                    param="None",
+                    code=500,
+                )
+            response = await managed_files_obj.afile_content(
+                file_id=file_id,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+                llm_router=llm_router,
+                **data,
+            )
+        else:
+            response = await litellm.afile_content(
+                custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
+            )
 
         ### ALERTING ###
         asyncio.create_task(
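
On the retrieval side, the same unified id is what the client passes back in. A hedged usage sketch of pulling file content through the proxy (assumed base_url, key, and placeholder id; unified ids are detected by the base64 check above, while plain provider ids fall through to litellm.afile_content):

from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

unified_file_id = "bGl0ZWxsbV9wcm94eTt1bmlmaWVkX2lk..."  # placeholder unified id from upload

# GET /v1/files/{file_id}/content through the proxy.
file_content = client.files.content(unified_file_id)
print(file_content.content[:200])  # raw bytes of the stored file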
@@ -539,10 +603,33 @@ async def get_file(
             version=version,
             proxy_config=proxy_config,
         )
-        response = await litellm.afile_retrieve(
-            custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
-        )
+
+        ## check if file_id is a litellm managed file
+        is_base64_unified_file_id = (
+            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
+        )
+
+        if is_base64_unified_file_id:
+            managed_files_obj = cast(
+                Optional[_PROXY_LiteLLMManagedFiles],
+                proxy_logging_obj.get_proxy_hook("managed_files"),
+            )
+            if managed_files_obj is None:
+                raise ProxyException(
+                    message="Managed files hook not found",
+                    type="None",
+                    param="None",
+                    code=500,
+                )
+            response = await managed_files_obj.afile_retrieve(
+                file_id=file_id,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )
+        else:
+            response = await litellm.afile_retrieve(
+                custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
+            )
 
         ### ALERTING ###
         asyncio.create_task(
             proxy_logging_obj.update_request_status(
@@ -634,6 +721,7 @@ async def delete_file(
     from litellm.proxy.proxy_server import (
         add_litellm_data_to_request,
         general_settings,
+        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
@@ -656,10 +744,41 @@ async def delete_file(
             proxy_config=proxy_config,
         )
 
-        response = await litellm.afile_delete(
-            custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
-        )
+        ## check if file_id is a litellm managed file
+        is_base64_unified_file_id = (
+            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
+        )
+
+        if is_base64_unified_file_id:
+            managed_files_obj = cast(
+                Optional[_PROXY_LiteLLMManagedFiles],
+                proxy_logging_obj.get_proxy_hook("managed_files"),
+            )
+            if managed_files_obj is None:
+                raise ProxyException(
+                    message="Managed files hook not found",
+                    type="None",
+                    param="None",
+                    code=500,
+                )
+            if llm_router is None:
+                raise ProxyException(
+                    message="LLM Router not found",
+                    type="None",
+                    param="None",
+                    code=500,
+                )
+            response = await managed_files_obj.afile_delete(
+                file_id=file_id,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+                llm_router=llm_router,
+                **data,
+            )
+        else:
+            response = await litellm.afile_delete(
+                custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
+            )
 
         ### ALERTING ###
         asyncio.create_task(
             proxy_logging_obj.update_request_status(
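
Retrieving file metadata and deleting a file follow the same branch structure as the content endpoint (the delete path also requires llm_router, since the managed-files hook resolves the unified id through the router). A single hedged client sketch covers both, again with an assumed proxy URL, key, and placeholder id:

from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

unified_file_id = "bGl0ZWxsbV9wcm94eTt1bmlmaWVkX2lk..."  # placeholder unified id

# GET /v1/files/{file_id}: unified ids are resolved by the managed-files hook.
file_obj = client.files.retrieve(unified_file_id)
print(file_obj.purpose, file_obj.bytes)

# DELETE /v1/files/{file_id}: unified ids go through the managed-files hook and
# the router; plain provider ids go straight to litellm.afile_delete.
deleted = client.files.delete(unified_file_id)
print(deleted.deleted)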