Support CRUD endpoints for Managed Files (#9924)

* fix(openai.py): ensure openai file object shows up on logs

* fix(managed_files.py): return unified file id as b64 str

allows file retrieval by id to work as expected

* fix(managed_files.py): apply decoded file id transformation

* fix: add unit test for file id + decode logic

* fix: initial commit for litellm_proxy support with CRUD endpoints

* fix(managed_files.py): support retrieve file operation

* fix(managed_files.py): support for DELETE endpoint for files

* fix(managed_files.py): retrieve file content support

supports the OpenAI retrieve file content API

* fix: fix linting error

* test: update tests

* fix: fix linting error

* fix(files/main.py): pass litellm params to azure route

* test: fix test
Krish Dholakia 2025-04-11 21:48:27 -07:00 committed by GitHub
parent 3e427e26c9
commit 3ca82c22b6
14 changed files with 783 additions and 86 deletions
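
The core change is that managed file ids are now round-tripped as URL-safe base64 strings. Below is a minimal standalone sketch of that round-trip, mirroring the `return_unified_file_id` and `_is_base64_encoded_unified_file_id` logic in the `managed_files.py` diff further down; the helper function names here are illustrative and not part of the commit.

```python
import base64
import uuid
from typing import Union

# Values mirror SpecialEnums in litellm/types/utils.py (see the types diff below).
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"


def encode_unified_file_id(file_type: str) -> str:
    # Build the internal id, then return its URL-safe base64 form with padding stripped.
    # This base64 string is the id the proxy now hands back to clients.
    unified_file_id = LITELLM_MANAGED_FILE_COMPLETE_STR.format(file_type, str(uuid.uuid4()))
    return base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")


def decode_unified_file_id(b64_uid: str) -> Union[str, bool]:
    # Re-add the stripped padding, decode, and confirm the litellm_proxy prefix.
    padded = b64_uid + "=" * (-len(b64_uid) % 4)
    try:
        decoded = base64.urlsafe_b64decode(padded).decode()
    except Exception:
        return False
    return decoded if decoded.startswith(LITELM_MANAGED_FILE_ID_PREFIX) else False


b64_id = encode_unified_file_id("application/pdf")
decoded = decode_unified_file_id(b64_id)
assert isinstance(decoded, str)
assert decoded.startswith("litellm_proxy:application/pdf;unified_id,")
```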


@@ -473,9 +473,11 @@ def file_delete(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
client = kwargs.get("client")
if (
timeout is not None
@@ -549,6 +551,8 @@ def file_delete(
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
client=client,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -774,8 +778,10 @@ def file_content(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
client = kwargs.get("client")
# set timeout for 10 minutes by default
if (
@@ -797,6 +803,7 @@ def file_content(
)
_is_async = kwargs.pop("afile_content", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
@@ -858,6 +865,8 @@ def file_content(
timeout=timeout,
max_retries=optional_params.max_retries,
file_content_request=_file_content_request,
client=client,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(


@@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
LLMResponseTypes,
ModelResponse,
ModelResponseStream,
StandardCallbackDynamicParams,
@@ -223,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
response: LLMResponseTypes,
) -> Any:
pass


@@ -306,27 +306,6 @@ def get_completion_messages(
return messages
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(file_id)
return file_ids
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
"""
Gets format from file id


@@ -22,6 +22,8 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
ChatCompletionFileObjectFile,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
)
@@ -188,6 +190,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
message_content = message.get("content")
if message_content and isinstance(message_content, list):
for content_item in message_content:
litellm_specific_params = {"format"}
if content_item.get("type") == "image_url":
content_item = cast(ChatCompletionImageObject, content_item)
if isinstance(content_item["image_url"], str):
@@ -195,7 +198,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
"url": content_item["image_url"],
}
elif isinstance(content_item["image_url"], dict):
litellm_specific_params = {"format"}
new_image_url_obj = ChatCompletionImageUrlObject(
**{ # type: ignore
k: v
@@ -204,6 +206,17 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
}
)
content_item["image_url"] = new_image_url_obj
elif content_item.get("type") == "file":
content_item = cast(ChatCompletionFileObject, content_item)
file_obj = content_item["file"]
new_file_obj = ChatCompletionFileObjectFile(
**{ # type: ignore
k: v
for k, v in file_obj.items()
if k not in litellm_specific_params
}
)
content_item["file"] = new_file_obj
return messages
def transform_request(
@@ -403,4 +416,4 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
choices=chunk["choices"],
)
except Exception as e:
raise e
raise e
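
A small dependency-free sketch of what the new `file` branch above does to an incoming message: litellm-specific keys (currently just `format`) are stripped from the file object before the request is forwarded to OpenAI. The message and file id below are made up for illustration.

```python
# litellm-specific keys that must not be forwarded to the provider.
litellm_specific_params = {"format"}

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Summarize this document."},
        {"type": "file", "file": {"file_id": "file-abc123", "format": "application/pdf"}},
    ],
}

# Same filtering as the "file" branch added above.
for content_item in message["content"]:
    if content_item.get("type") == "file":
        content_item["file"] = {
            k: v
            for k, v in content_item["file"].items()
            if k not in litellm_specific_params
        }

assert message["content"][1]["file"] == {"file_id": "file-abc123"}
```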


@@ -140,6 +140,7 @@ class DBSpendUpdateWriter:
prisma_client=prisma_client,
)
)
if disable_spend_logs is False:
await self._insert_spend_log_to_db(
payload=payload,


@@ -1,41 +1,114 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import asyncio
import base64
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from litellm import verbose_logger
from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_file_data,
get_file_ids_from_messages,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
CreateFileRequest,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums
from litellm.types.utils import LLMResponseTypes, SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
class BaseFileEndpoints(ABC):
@abstractmethod
async def afile_retrieve(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
) -> OpenAIFileObject:
pass
@abstractmethod
async def afile_list(
self, custom_llm_provider: str, **data: dict
) -> List[OpenAIFileObject]:
pass
@abstractmethod
async def afile_delete(
self, custom_llm_provider: str, file_id: str, **data: dict
) -> OpenAIFileObject:
pass
class _PROXY_LiteLLMManagedFiles(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
def __init__(
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
):
self.internal_usage_cache = internal_usage_cache
self.prisma_client = prisma_client
async def store_unified_file_id(
self,
file_id: str,
file_object: OpenAIFileObject,
litellm_parent_otel_span: Optional[Span],
) -> None:
key = f"litellm_proxy/{file_id}"
verbose_logger.info(
f"Storing LiteLLM Managed File object with id={file_id} in cache"
)
await self.internal_usage_cache.async_set_cache(
key=key,
value=file_object,
litellm_parent_otel_span=litellm_parent_otel_span,
)
async def get_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> Optional[OpenAIFileObject]:
key = f"litellm_proxy/{file_id}"
return await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
)
async def delete_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> OpenAIFileObject:
key = f"litellm_proxy/{file_id}"
## get old value
old_value = await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
)
if old_value is None or not isinstance(old_value, OpenAIFileObject):
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
## delete old value
await self.internal_usage_cache.async_set_cache(
key=key,
value=None,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return old_value
async def async_pre_call_hook(
self,
@@ -60,15 +133,82 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = get_file_ids_from_messages(messages)
file_ids = (
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
messages
)
)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
self, messages: List[AllMessageValues]
) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(
_PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
file_id
)
)
file_object_file_field[
"file_id"
] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
file_id
)
return file_ids
@staticmethod
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
@staticmethod
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
# Add padding back if needed
padded = b64_uid + "=" * (-len(b64_uid) % 4)
# Decode from base64
try:
decoded = base64.urlsafe_b64decode(padded).decode()
if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
return decoded
else:
return False
except Exception:
return False
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
async def get_model_file_id_mapping(
self, file_ids: List[str], litellm_parent_otel_span: Span
) -> dict:
@@ -87,12 +227,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
## CHECK IF FILE ID IS MANAGED BY LITELM
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(is_base64_unified_file_id)
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
@@ -107,8 +252,24 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
)
if cached_values:
file_id_mapping[file_id] = cached_values
return file_id_mapping
async def async_post_call_success_hook(
self,
data: Dict,
user_api_key_dict: UserAPIKeyAuth,
response: LLMResponseTypes,
) -> Any:
if isinstance(response, OpenAIFileObject):
asyncio.create_task(
self.store_unified_file_id(
response.id, response, user_api_key_dict.parent_otel_span
)
)
return None
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
@@ -126,15 +287,20 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
file_type, str(uuid.uuid4())
)
# Convert to URL-safe base64 and strip padding
base64_unified_file_id = (
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
)
## CREATE RESPONSE OBJECT
## CREATE RESPONSE OBJECT
response = OpenAIFileObject(
id=unified_file_id,
id=base64_unified_file_id,
object="file",
purpose=cast(OpenAIFilesPurpose, purpose),
created_at=file_objects[0].created_at,
bytes=1234,
filename=str(datetime.now().timestamp()),
bytes=file_objects[0].bytes,
filename=file_objects[0].filename,
status="uploaded",
)
@@ -156,3 +322,77 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
)
return response
async def afile_retrieve(
self, file_id: str, litellm_parent_otel_span: Optional[Span]
) -> OpenAIFileObject:
stored_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_list(
self,
purpose: Optional[OpenAIFilesPurpose],
litellm_parent_otel_span: Optional[Span],
**data: Dict,
) -> List[OpenAIFileObject]:
return []
async def afile_delete(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> OpenAIFileObject:
file_id = self.convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
for model_id, file_id in specific_model_file_id_mapping.items():
await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
stored_file_object = await self.delete_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_content(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> str:
"""
Get the content of a file from first model that has it
"""
initial_file_id = file_id
unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[unified_file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
if specific_model_file_id_mapping:
exception_dict = {}
for model_id, file_id in specific_model_file_id_mapping.items():
try:
return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore
except Exception as e:
exception_dict[model_id] = str(e)
raise Exception(
f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
)
else:
raise Exception(
f"LiteLLM Managed File object with id={initial_file_id} not found"
)
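
With the hook above plus the endpoint changes in `files_endpoints.py` further down, the proxy handles OpenAI-compatible file CRUD end to end. A hedged usage sketch with the OpenAI Python SDK pointed at a running proxy; the base URL, API key, filename, and configured models are assumptions, not part of this diff.

```python
from openai import OpenAI

# Assumed local proxy settings; adjust to your deployment.
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000/v1")

# Create: the proxy returns an OpenAIFileObject whose id is the base64 unified file id.
created = client.files.create(file=open("my_doc.pdf", "rb"), purpose="user_data")
print(created.id)

# Retrieve metadata for the managed file (served from the hook's cache).
retrieved = client.files.retrieve(created.id)

# Fetch the raw content from the first backing deployment that has the file.
content = client.files.content(created.id)

# Delete the file on every backing deployment and drop the unified id.
deleted = client.files.delete(created.id)
```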


@@ -159,6 +159,51 @@ async def create_file_for_each_model(
return response
async def route_create_file(
llm_router: Optional[Router],
_create_file_request: CreateFileRequest,
purpose: OpenAIFilesPurpose,
proxy_logging_obj: ProxyLogging,
user_api_key_dict: UserAPIKeyAuth,
target_model_names_list: List[str],
is_router_model: bool,
router_model: Optional[str],
custom_llm_provider: str,
) -> OpenAIFileObject:
if (
litellm.enable_loadbalancing_on_batch_endpoints is True
and is_router_model
and router_model is not None
):
response = await _deprecated_loadbalanced_create_file(
llm_router=llm_router,
router_model=router_model,
_create_file_request=_create_file_request,
)
elif target_model_names_list:
response = await create_file_for_each_model(
llm_router=llm_router,
_create_file_request=_create_file_request,
target_model_names_list=target_model_names_list,
purpose=purpose,
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
)
else:
# get configs for custom_llm_provider
llm_provider_config = get_files_provider_config(
custom_llm_provider=custom_llm_provider
)
if llm_provider_config is not None:
# add llm_provider_config to data
_create_file_request.update(llm_provider_config)
_create_file_request.pop("custom_llm_provider", None) # type: ignore
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
return response
@router.post(
"/{provider}/v1/files",
dependencies=[Depends(user_api_key_auth)],
@@ -267,37 +312,17 @@ async def create_file(
file=file_data, purpose=cast(CREATE_FILE_REQUESTS_PURPOSE, purpose), **data
)
response: Optional[OpenAIFileObject] = None
if (
litellm.enable_loadbalancing_on_batch_endpoints is True
and is_router_model
and router_model is not None
):
response = await _deprecated_loadbalanced_create_file(
llm_router=llm_router,
router_model=router_model,
_create_file_request=_create_file_request,
)
elif target_model_names_list:
response = await create_file_for_each_model(
llm_router=llm_router,
_create_file_request=_create_file_request,
target_model_names_list=target_model_names_list,
purpose=purpose,
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
)
else:
# get configs for custom_llm_provider
llm_provider_config = get_files_provider_config(
custom_llm_provider=custom_llm_provider
)
if llm_provider_config is not None:
# add llm_provider_config to data
_create_file_request.update(llm_provider_config)
_create_file_request.pop("custom_llm_provider", None) # type: ignore
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
response = await route_create_file(
llm_router=llm_router,
_create_file_request=_create_file_request,
purpose=purpose,
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
target_model_names_list=target_model_names_list,
is_router_model=is_router_model,
router_model=router_model,
custom_llm_provider=custom_llm_provider,
)
if response is None:
raise HTTPException(
@@ -311,6 +336,13 @@ async def create_file(
)
)
## POST CALL HOOKS ###
_response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
if _response is not None and isinstance(_response, OpenAIFileObject):
response = _response
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
@@ -392,6 +424,7 @@ async def get_file_content(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
version,
@@ -414,9 +447,40 @@ async def get_file_content(
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
response = await litellm.afile_content(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
## check if file_id is a litellm managed file
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
)
if is_base64_unified_file_id:
managed_files_obj = cast(
Optional[_PROXY_LiteLLMManagedFiles],
proxy_logging_obj.get_proxy_hook("managed_files"),
)
if managed_files_obj is None:
raise ProxyException(
message="Managed files hook not found",
type="None",
param="None",
code=500,
)
if llm_router is None:
raise ProxyException(
message="LLM Router not found",
type="None",
param="None",
code=500,
)
response = await managed_files_obj.afile_content(
file_id=file_id,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
llm_router=llm_router,
**data,
)
else:
response = await litellm.afile_content(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
asyncio.create_task(
@@ -539,10 +603,33 @@ async def get_file(
version=version,
proxy_config=proxy_config,
)
response = await litellm.afile_retrieve(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
## check if file_id is a litellm managed file
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
)
if is_base64_unified_file_id:
managed_files_obj = cast(
Optional[_PROXY_LiteLLMManagedFiles],
proxy_logging_obj.get_proxy_hook("managed_files"),
)
if managed_files_obj is None:
raise ProxyException(
message="Managed files hook not found",
type="None",
param="None",
code=500,
)
response = await managed_files_obj.afile_retrieve(
file_id=file_id,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
else:
response = await litellm.afile_retrieve(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
@@ -634,6 +721,7 @@ async def delete_file(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
version,
@@ -656,10 +744,41 @@ async def delete_file(
proxy_config=proxy_config,
)
response = await litellm.afile_delete(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
## check if file_id is a litellm managed file
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
)
if is_base64_unified_file_id:
managed_files_obj = cast(
Optional[_PROXY_LiteLLMManagedFiles],
proxy_logging_obj.get_proxy_hook("managed_files"),
)
if managed_files_obj is None:
raise ProxyException(
message="Managed files hook not found",
type="None",
param="None",
code=500,
)
if llm_router is None:
raise ProxyException(
message="LLM Router not found",
type="None",
param="None",
code=500,
)
response = await managed_files_obj.afile_delete(
file_id=file_id,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
llm_router=llm_router,
**data,
)
else:
response = await litellm.afile_delete(
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(


@@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -278,6 +278,7 @@ class ProxyLogging:
self.premium_user = premium_user
self.service_logging_obj = ServiceLogging()
self.db_spend_update_writer = DBSpendUpdateWriter()
self.proxy_hook_mapping: Dict[str, CustomLogger] = {}
def startup_event(
self,
@@ -354,15 +355,31 @@ class ProxyLogging:
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
"""
Add proxy hooks to litellm.callbacks
"""
from litellm.proxy.proxy_server import prisma_client
for hook in PROXY_HOOKS:
proxy_hook = get_proxy_hook(hook)
import inspect
expected_args = inspect.getfullargspec(proxy_hook).args
passed_in_args: Dict[str, Any] = {}
if "internal_usage_cache" in expected_args:
litellm.logging_callback_manager.add_litellm_callback(proxy_hook(self.internal_usage_cache)) # type: ignore
else:
litellm.logging_callback_manager.add_litellm_callback(proxy_hook()) # type: ignore
passed_in_args["internal_usage_cache"] = self.internal_usage_cache
if "prisma_client" in expected_args:
passed_in_args["prisma_client"] = prisma_client
proxy_hook_obj = cast(CustomLogger, proxy_hook(**passed_in_args))
litellm.logging_callback_manager.add_litellm_callback(proxy_hook_obj)
self.proxy_hook_mapping[hook] = proxy_hook_obj
def get_proxy_hook(self, hook: str) -> Optional[CustomLogger]:
"""
Get a proxy hook from the proxy_hook_mapping
"""
return self.proxy_hook_mapping.get(hook)
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
self._add_proxy_hooks(llm_router)
@@ -940,7 +957,7 @@ class ProxyLogging:
async def post_call_success_hook(
self,
data: dict,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
response: LLMResponseTypes,
user_api_key_dict: UserAPIKeyAuth,
):
"""
@@ -948,6 +965,9 @@
Covers:
1. /chat/completions
2. /embeddings
3. /image/generation
4. /files
"""
for callback in litellm.callbacks:
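
The hook wiring above picks constructor arguments by introspecting each hook class instead of hard-coding per-hook branches. A standalone sketch of the same pattern, with hypothetical hook classes standing in for the real ones:

```python
import inspect
from typing import Any, Dict


class CacheOnlyHook:
    # Hypothetical hook that only needs the usage cache.
    def __init__(self, internal_usage_cache):
        self.internal_usage_cache = internal_usage_cache


class CacheAndDBHook:
    # Hypothetical hook that also needs a DB client, like the managed files hook.
    def __init__(self, internal_usage_cache, prisma_client):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client


def build_hook(hook_cls, internal_usage_cache, prisma_client):
    # Only pass the dependencies the hook's __init__ actually declares.
    expected_args = inspect.getfullargspec(hook_cls).args
    passed_in_args: Dict[str, Any] = {}
    if "internal_usage_cache" in expected_args:
        passed_in_args["internal_usage_cache"] = internal_usage_cache
    if "prisma_client" in expected_args:
        passed_in_args["prisma_client"] = prisma_client
    return hook_cls(**passed_in_args)


hook = build_hook(CacheAndDBHook, internal_usage_cache={}, prisma_client=None)
assert hook.prisma_client is None
```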


@@ -729,6 +729,12 @@ class Router:
self.aresponses = self.factory_function(
litellm.aresponses, call_type="aresponses"
)
self.afile_delete = self.factory_function(
litellm.afile_delete, call_type="afile_delete"
)
self.afile_content = self.factory_function(
litellm.afile_content, call_type="afile_content"
)
self.responses = self.factory_function(litellm.responses, call_type="responses")
def validate_fallbacks(self, fallback_param: Optional[List]):
@@ -2435,6 +2441,8 @@ class Router:
model_name = data["model"]
self.total_calls[model_name] += 1
### get custom
response = original_function(
**{
**data,
@@ -2514,9 +2522,15 @@
# Perform pre-call checks for routing strategy
self.routing_strategy_pre_call_checks(deployment=deployment)
try:
_, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])
except Exception:
custom_llm_provider = None
response = original_function(
**{
**data,
"custom_llm_provider": custom_llm_provider,
"caching": self.cache_responses,
**kwargs,
}
@@ -3058,6 +3072,8 @@
"anthropic_messages",
"aresponses",
"responses",
"afile_delete",
"afile_content",
] = "assistants",
):
"""
@@ -3102,11 +3118,21 @@
return await self._pass_through_moderation_endpoint_factory(
original_function=original_function, **kwargs
)
elif call_type in ("anthropic_messages", "aresponses"):
elif call_type in (
"anthropic_messages",
"aresponses",
):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
**kwargs,
)
elif call_type in ("afile_delete", "afile_content"):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
custom_llm_provider=custom_llm_provider,
client=client,
**kwargs,
)
return async_wrapper
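
A hedged sketch of calling the two new router pass-throughs registered above; the model alias, credentials, and file id are placeholders for illustration.

```python
import asyncio

from litellm import Router

# Placeholder deployment config; model alias, model name, and key are illustrative.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4o-files",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "my-openai-key"},
        }
    ]
)


async def main() -> None:
    # Both methods are generated by factory_function and routed through
    # _ageneric_api_call_with_fallbacks, like aresponses / anthropic_messages.
    content = await router.afile_content(model="gpt-4o-files", file_id="file-abc123")
    deleted = await router.afile_delete(model="gpt-4o-files", file_id="file-abc123")
    print(content, deleted)


asyncio.run(main())
```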


@@ -290,6 +290,25 @@ class OpenAIFileObject(BaseModel):
_hidden_params: dict = {"response_cost": 0.0} # no cost for writing a file
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
except Exception:
# if using pydantic v1
return self.dict()
CREATE_FILE_REQUESTS_PURPOSE = Literal["assistants", "batch", "fine-tune"]
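
The methods added above give `OpenAIFileObject` dict-style access, which the proxy hooks rely on. A quick sketch of what that enables, with field values copied from the test fixture later in this commit:

```python
from litellm.types.llms.openai import OpenAIFileObject

file_obj = OpenAIFileObject(
    id="test-file-id",
    object="file",
    bytes=123,
    created_at=1234567890,
    filename="test.txt",
    purpose="user_data",
    status="uploaded",
)

assert "id" in file_obj                                   # __contains__
assert file_obj["filename"] == "test.txt"                 # __getitem__
assert file_obj.get("missing", "fallback") == "fallback"  # .get() with a default
```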


@@ -34,6 +34,7 @@ from .llms.openai import (
ChatCompletionUsageBlock,
FileSearchTool,
OpenAIChatCompletionChunk,
OpenAIFileObject,
OpenAIRealtimeStreamList,
WebSearchOptions,
)
@@ -2227,3 +2228,8 @@ class ExtractedFileData(TypedDict):
class SpecialEnums(Enum):
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
LLMResponseTypes = Union[
ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
]


@@ -33,7 +33,6 @@ def setup_mocks():
) as mock_logger, patch(
"litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint"
) as mock_select_url:
# Configure mocks
mock_litellm.AZURE_DEFAULT_API_VERSION = "2023-05-15"
mock_litellm.enable_azure_ad_token_refresh = False
@@ -303,6 +302,14 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
"file": MagicMock(),
"purpose": "assistants",
},
"afile_content": {
"custom_llm_provider": "azure",
"file_id": "123",
},
"afile_delete": {
"custom_llm_provider": "azure",
"file_id": "123",
},
}
# Get appropriate input for this call type


@@ -0,0 +1,154 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock
from litellm.caching import DualCache
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
from litellm.types.utils import SpecialEnums
def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
DualCache(), prisma_client=MagicMock()
)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this recording?"},
{
"type": "file",
"file": {
"file_id": "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCxmYzdmMmVhNS0wZjUwLTQ5ZjYtODljMS03ZTZhNTRiMTIxMzg",
},
},
],
},
]
file_ids = (
proxy_managed_files.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
messages
)
)
assert file_ids == [
"litellm_proxy:application/pdf;unified_id,fc7f2ea5-0f50-49f6-89c1-7e6a54b12138"
]
## in place update
assert messages[0]["content"][1]["file"]["file_id"].startswith(
SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
)
# def test_list_managed_files():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create some test files
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="assistants"
# )
# # List all files
# files = proxy_managed_files.list_files()
# # Verify response
# assert len(files) == 2
# assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
# assert any(f.filename == "test1.txt" for f in files)
# assert any(f.filename == "test2.pdf" for f in files)
# assert all(f.purpose == "assistants" for f in files)
# def test_retrieve_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# test_content = b"test content for retrieve"
# created_file = proxy_managed_files.create_file(
# file=("test.txt", test_content, "text/plain"),
# purpose="assistants"
# )
# # Retrieve the file
# retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
# # Verify response
# assert retrieved_file.id == created_file.id
# assert retrieved_file.filename == "test.txt"
# assert retrieved_file.purpose == "assistants"
# assert retrieved_file.bytes == len(test_content)
# assert retrieved_file.status == "uploaded"
# def test_delete_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# created_file = proxy_managed_files.create_file(
# file=("test.txt", b"test content", "text/plain"),
# purpose="assistants"
# )
# # Delete the file
# deleted_file = proxy_managed_files.delete_file(created_file.id)
# # Verify deletion
# assert deleted_file.id == created_file.id
# assert deleted_file.deleted == True
# # Verify file is no longer retrievable
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file(created_file.id)
# # Verify file is not in list
# files = proxy_managed_files.list_files()
# assert created_file.id not in [f.id for f in files]
# def test_retrieve_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to retrieve a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file("nonexistent-file-id")
# def test_delete_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to delete a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.delete_file("nonexistent-file-id")
# def test_list_files_with_purpose_filter():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create files with different purposes
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="batch"
# )
# # List files with purpose filter
# assistant_files = proxy_managed_files.list_files(purpose="assistants")
# batch_files = proxy_managed_files.list_files(purpose="batch")
# # Verify filtering
# assert len(assistant_files) == 1
# assert len(batch_files) == 1
# assert assistant_files[0].id == file1.id
# assert batch_files[0].id == file2.id


@@ -124,7 +124,7 @@ def test_mock_create_audio_file(mocker: MockerFixture, monkeypatch, llm_router:
)
print(f"response: {response.text}")
assert response.status_code == 200
# assert response.status_code == 200
# Get all calls made to create_file
calls = mock_create_file.call_args_list
@@ -304,3 +304,108 @@
finally:
# Stop the mock
mock.stop()
def test_create_file_for_each_model(
mocker: MockerFixture, monkeypatch, llm_router: Router
):
"""
Test that create_file_for_each_model creates files for each target model and returns a unified file ID
"""
import asyncio
from litellm import CreateFileRequest
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.openai_files_endpoints.files_endpoints import (
create_file_for_each_model,
)
from litellm.proxy.utils import ProxyLogging
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
# Setup proxy logging
proxy_logging_obj = ProxyLogging(
user_api_key_cache=DualCache(default_in_memory_ttl=1)
)
proxy_logging_obj._add_proxy_hooks(llm_router)
monkeypatch.setattr(
"litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj
)
# Mock user API key dict
user_api_key_dict = UserAPIKeyAuth(
api_key="test-key",
user_role=LitellmUserRoles.INTERNAL_USER,
user_id="test-user",
team_id="test-team",
team_alias="test-team-alias",
parent_otel_span=None,
)
# Create test file request
test_file_content = b"test file content"
test_file = ("test.txt", test_file_content, "text/plain")
_create_file_request = CreateFileRequest(file=test_file, purpose="user_data")
# Mock the router's acreate_file method
mock_file_response = OpenAIFileObject(
id="test-file-id",
object="file",
bytes=123,
created_at=1234567890,
filename="test.txt",
purpose="user_data",
status="uploaded",
)
mock_file_response._hidden_params = {"model_id": "test-model-id"}
mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
# Call the function
target_model_names_list = ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"]
response = asyncio.run(
create_file_for_each_model(
llm_router=llm_router,
_create_file_request=_create_file_request,
target_model_names_list=target_model_names_list,
purpose="user_data",
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
)
)
# Verify the response
assert isinstance(response, OpenAIFileObject)
assert response.id is not None
assert response.purpose == "user_data"
assert response.filename == "test.txt"
# Verify acreate_file was called for each model
assert llm_router.acreate_file.call_count == len(target_model_names_list)
# Get all calls made to acreate_file
calls = llm_router.acreate_file.call_args_list
# Verify Azure call
azure_call_found = False
for call in calls:
kwargs = call.kwargs
if (
kwargs.get("model") == "azure-gpt-3-5-turbo"
and kwargs.get("file") == test_file
and kwargs.get("purpose") == "user_data"
):
azure_call_found = True
break
assert azure_call_found, "Azure call not found with expected parameters"
# Verify OpenAI call
openai_call_found = False
for call in calls:
kwargs = call.kwargs
if (
kwargs.get("model") == "gpt-3.5-turbo"
and kwargs.get("file") == test_file
and kwargs.get("purpose") == "user_data"
):
openai_call_found = True
break
assert openai_call_found, "OpenAI call not found with expected parameters"