# litellm-mirror/litellm/proxy/hooks/managed_files.py

# What is this?
## This hook checks for LiteLLM managed file ids in the request body and replaces them with the corresponding model-specific file ids
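## Rough sketch of the flow this supports (ids and shapes below are illustrative, not literal values):
##   1. When a file is created through the proxy, return_unified_file_id (below) issues a
##      base64-encoded unified file id and caches the per-deployment provider file ids.
##   2. A chat request then references it, e.g. {"type": "file", "file": {"file_id": "<b64 unified id>"}}.
##   3. async_pre_call_hook decodes the id and attaches
##      data["model_file_id_mapping"] = {"<unified id>": {"<model_id>": "<provider file id>"}}
##      so each deployment can be called with its own provider-specific file id.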
import asyncio
import base64
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
CreateFileRequest,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import LLMResponseTypes, SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
class BaseFileEndpoints(ABC):
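    """Abstract interface for retrieve/list/delete operations on LiteLLM managed files."""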
@abstractmethod
async def afile_retrieve(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
) -> OpenAIFileObject:
pass
@abstractmethod
async def afile_list(
self, custom_llm_provider: str, **data: dict
) -> List[OpenAIFileObject]:
pass
@abstractmethod
async def afile_delete(
self, custom_llm_provider: str, file_id: str, **data: dict
) -> OpenAIFileObject:
pass
class _PROXY_LiteLLMManagedFiles(CustomLogger):
# Class variables or attributes
def __init__(
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
):
self.internal_usage_cache = internal_usage_cache
self.prisma_client = prisma_client
async def store_unified_file_id(
self,
file_id: str,
file_object: OpenAIFileObject,
litellm_parent_otel_span: Optional[Span],
) -> None:
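        # unified file objects are cached under a "litellm_proxy/<file_id>" namespaced key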
key = f"litellm_proxy/{file_id}"
verbose_logger.info(
f"Storing LiteLLM Managed File object with id={file_id} in cache"
)
await self.internal_usage_cache.async_set_cache(
key=key,
value=file_object,
litellm_parent_otel_span=litellm_parent_otel_span,
)
async def get_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> Optional[OpenAIFileObject]:
key = f"litellm_proxy/{file_id}"
return await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
)
async def delete_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> OpenAIFileObject:
key = f"litellm_proxy/{file_id}"
## get old value
old_value = await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
)
if old_value is None or not isinstance(old_value, OpenAIFileObject):
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
## delete old value
await self.internal_usage_cache.async_set_cache(
key=key,
value=None,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return old_value
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Union[Exception, str, Dict, None]:
"""
        - Detect litellm_proxy/ file_ids in the request messages
        - Add a mapping of each litellm_proxy/ file_id to its model-specific file ids onto the request body, e.g. data["model_file_id_mapping"] = {"litellm_proxy/file_id": {"model_id": "provider_file_id"}}
"""
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = (
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
messages
)
)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
self, messages: List[AllMessageValues]
) -> List[str]:
"""
        Gets file ids from user messages, decoding any base64-encoded unified file ids
        back to their unified form (both in the returned list and in the message content)
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
                            if file_id:
                                # decode once and reuse for both the returned list and the in-place rewrite
                                unified_file_id = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
                                    file_id
                                )
                                file_ids.append(unified_file_id)
                                file_object_file_field["file_id"] = unified_file_id
return file_ids
@staticmethod
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
@staticmethod
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
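        # Unified ids are urlsafe-base64 encoded with their "=" padding stripped at creation time
        # (see return_unified_file_id), so the padding is restored here before decoding.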
# Add padding back if needed
padded = b64_uid + "=" * (-len(b64_uid) % 4)
# Decode from base64
try:
decoded = base64.urlsafe_b64decode(padded).decode()
if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
return decoded
else:
return False
except Exception:
return False
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
async def get_model_file_id_mapping(
self, file_ids: List[str], litellm_parent_otel_span: Span
) -> dict:
"""
Get model-specific file IDs for a list of proxy file IDs.
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
        1. Get all the litellm_proxy/ file_ids from the messages (decoding base64-encoded ids)
        2. For each file_id, look up its stored model_id -> model_file_id mapping in the cache
        3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
Example:
{
"litellm_proxy/file_id": {
"model_id": "model_file_id"
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
            ## CHECK IF FILE ID IS MANAGED BY LITELLM
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(is_base64_unified_file_id)
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
            # Look up the stored model_id -> model_file_id mapping for each managed file_id
            for file_id in litellm_managed_file_ids:
                # the mapping was stored in the cache under the unified file_id key
cached_values = cast(
Dict[str, str],
await self.internal_usage_cache.async_get_cache(
key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
),
)
if cached_values:
file_id_mapping[file_id] = cached_values
return file_id_mapping
async def async_post_call_success_hook(
self,
data: Dict,
user_api_key_dict: UserAPIKeyAuth,
response: LLMResponseTypes,
) -> Any:
if isinstance(response, OpenAIFileObject):
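            # a file was just created - cache the unified id -> file object mapping without
            # blocking the response (fire-and-forget task)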
asyncio.create_task(
self.store_unified_file_id(
response.id, response, user_api_key_dict.parent_otel_span
)
)
return None
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
create_file_request: CreateFileRequest,
purpose: OpenAIFilesPurpose,
internal_usage_cache: InternalUsageCache,
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
file_data = extract_file_data(create_file_request["file"])
file_type = file_data["content_type"]
unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
file_type, str(uuid.uuid4())
)
# Convert to URL-safe base64 and strip padding
base64_unified_file_id = (
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
)
## CREATE RESPONSE OBJECT
response = OpenAIFileObject(
id=base64_unified_file_id,
object="file",
purpose=cast(OpenAIFilesPurpose, purpose),
created_at=file_objects[0].created_at,
bytes=file_objects[0].bytes,
filename=str(datetime.now().timestamp()),
status="uploaded",
)
        ## STORE MODEL_ID -> PROVIDER FILE_ID MAPPING IN CACHE
stored_values: Dict[str, str] = {}
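        # stored_values maps each deployment's model_id -> that deployment's provider file id;
        # it is cached under the (un-encoded) unified_file_id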
for file_object in file_objects:
model_id = file_object._hidden_params.get("model_id")
if model_id is None:
verbose_logger.warning(
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
)
continue
file_id = file_object.id
stored_values[model_id] = file_id
await internal_usage_cache.async_set_cache(
key=unified_file_id,
value=stored_values,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return response
async def afile_retrieve(
self, file_id: str, litellm_parent_otel_span: Optional[Span]
) -> OpenAIFileObject:
stored_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_list(
self, custom_llm_provider: str, **data: Dict
) -> List[OpenAIFileObject]:
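        # listing of managed files is not implemented here - always returns an empty list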
return []
async def afile_delete(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> OpenAIFileObject:
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
            # use a separate name for the provider-specific id so the unified file_id isn't shadowed
            for model_id, model_file_id in specific_model_file_id_mapping.items():
                await llm_router.afile_delete(model=model_id, file_id=model_file_id, **data)  # type: ignore
stored_file_object = await self.delete_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")