Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Support CRUD endpoints for Managed Files (#9924)
* fix(openai.py): ensure openai file object shows up on logs
* fix(managed_files.py): return unified file id as b64 str
  allows retrieve file id to work as expected
* fix(managed_files.py): apply decoded file id transformation
* fix: add unit test for file id + decode logic
* fix: initial commit for litellm_proxy support with CRUD Endpoints
* fix(managed_files.py): support retrieve file operation
* fix(managed_files.py): support for DELETE endpoint for files
* fix(managed_files.py): retrieve file content support
  supports retrieve file content api from openai
* fix: fix linting error
* test: update tests
* fix: fix linting error
* fix(files/main.py): pass litellm params to azure route
* test: fix test
Parent: 3e427e26c9
Commit: 3ca82c22b6
14 changed files with 783 additions and 86 deletions
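The central change is that the proxy now hands back a URL-safe base64 encoding of its internal "unified" file id on upload, and decodes it again on retrieve/delete/content calls. Below is a minimal standalone sketch of that roundtrip; the prefix string and id layout are placeholders for illustration (the real values come from SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX and return_unified_file_id in the diff that follows), but the padding and decode steps mirror _is_base64_encoded_unified_file_id.

import base64
import uuid

# Placeholder prefix and id layout -- the real values live in
# SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX and return_unified_file_id.
MANAGED_FILE_ID_PREFIX = "litellm_proxy:managed_file"
unified_file_id = f"{MANAGED_FILE_ID_PREFIX};{uuid.uuid4()}"

# Encode: URL-safe base64 with padding stripped, as the upload path now does.
b64_uid = base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")

# Decode: restore padding, decode, and verify the managed-file prefix,
# as _is_base64_encoded_unified_file_id does on the read/delete paths.
padded = b64_uid + "=" * (-len(b64_uid) % 4)
decoded = base64.urlsafe_b64decode(padded).decode()
assert decoded == unified_file_id and decoded.startswith(MANAGED_FILE_ID_PREFIX)

Stripping the "=" padding keeps the id safe to use in URL path segments, which is why the decode side restores it before calling urlsafe_b64decode.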
@@ -1,41 +1,114 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
 import asyncio
 import base64
 import uuid
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
 
-from litellm import verbose_logger
+from litellm import Router, verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.litellm_core_utils.prompt_templates.common_utils import (
-    extract_file_data,
-    get_file_ids_from_messages,
-)
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
 from litellm.proxy._types import CallTypes, UserAPIKeyAuth
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionFileObject,
     CreateFileRequest,
     OpenAIFileObject,
     OpenAIFilesPurpose,
 )
-from litellm.types.utils import SpecialEnums
+from litellm.types.utils import LLMResponseTypes, SpecialEnums
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
     from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+    from litellm.proxy.utils import PrismaClient as _PrismaClient
 
     Span = Union[_Span, Any]
     InternalUsageCache = _InternalUsageCache
+    PrismaClient = _PrismaClient
 else:
     Span = Any
     InternalUsageCache = Any
+    PrismaClient = Any
+
+
+class BaseFileEndpoints(ABC):
+    @abstractmethod
+    async def afile_retrieve(
+        self,
+        file_id: str,
+        litellm_parent_otel_span: Optional[Span],
+    ) -> OpenAIFileObject:
+        pass
+
+    @abstractmethod
+    async def afile_list(
+        self, custom_llm_provider: str, **data: dict
+    ) -> List[OpenAIFileObject]:
+        pass
+
+    @abstractmethod
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
 
 
 class _PROXY_LiteLLMManagedFiles(CustomLogger):
     # Class variables or attributes
-    def __init__(self, internal_usage_cache: InternalUsageCache):
+    def __init__(
+        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
+    ):
         self.internal_usage_cache = internal_usage_cache
+        self.prisma_client = prisma_client
+
+    async def store_unified_file_id(
+        self,
+        file_id: str,
+        file_object: OpenAIFileObject,
+        litellm_parent_otel_span: Optional[Span],
+    ) -> None:
+        key = f"litellm_proxy/{file_id}"
+        verbose_logger.info(
+            f"Storing LiteLLM Managed File object with id={file_id} in cache"
+        )
+        await self.internal_usage_cache.async_set_cache(
+            key=key,
+            value=file_object,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+
+    async def get_unified_file_id(
+        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
+    ) -> Optional[OpenAIFileObject]:
+        key = f"litellm_proxy/{file_id}"
+        return await self.internal_usage_cache.async_get_cache(
+            key=key,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+
+    async def delete_unified_file_id(
+        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
+    ) -> OpenAIFileObject:
+        key = f"litellm_proxy/{file_id}"
+        ## get old value
+        old_value = await self.internal_usage_cache.async_get_cache(
+            key=key,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+        if old_value is None or not isinstance(old_value, OpenAIFileObject):
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
+        ## delete old value
+        await self.internal_usage_cache.async_set_cache(
+            key=key,
+            value=None,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+        return old_value
 
     async def async_pre_call_hook(
         self,
@@ -60,15 +133,82 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         if call_type == CallTypes.completion.value:
             messages = data.get("messages")
             if messages:
-                file_ids = get_file_ids_from_messages(messages)
+                file_ids = (
+                    self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
+                        messages
+                    )
+                )
                 if file_ids:
                     model_file_id_mapping = await self.get_model_file_id_mapping(
                         file_ids, user_api_key_dict.parent_otel_span
                     )
 
                     data["model_file_id_mapping"] = model_file_id_mapping
 
         return data
 
+    def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[str]:
+        """
+        Gets file ids from messages
+        """
+        file_ids = []
+        for message in messages:
+            if message.get("role") == "user":
+                content = message.get("content")
+                if content:
+                    if isinstance(content, str):
+                        continue
+                    for c in content:
+                        if c["type"] == "file":
+                            file_object = cast(ChatCompletionFileObject, c)
+                            file_object_file_field = file_object["file"]
+                            file_id = file_object_file_field.get("file_id")
+                            if file_id:
+                                file_ids.append(
+                                    _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
+                                        file_id
+                                    )
+                                )
+                                file_object_file_field[
+                                    "file_id"
+                                ] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
+                                    file_id
+                                )
+        return file_ids
+
+    @staticmethod
+    def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
+        is_base64_unified_file_id = (
+            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
+        )
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
+    @staticmethod
+    def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
+        # Add padding back if needed
+        padded = b64_uid + "=" * (-len(b64_uid) % 4)
+        # Decode from base64
+        try:
+            decoded = base64.urlsafe_b64decode(padded).decode()
+            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                return decoded
+            else:
+                return False
+        except Exception:
+            return False
+
+    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
+        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
     async def get_model_file_id_mapping(
         self, file_ids: List[str], litellm_parent_otel_span: Span
     ) -> dict:
@@ -87,12 +227,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             }
         }
         """
 
         file_id_mapping: Dict[str, Dict[str, str]] = {}
         litellm_managed_file_ids = []
 
         for file_id in file_ids:
             ## CHECK IF FILE ID IS MANAGED BY LITELM
-            if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+            is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
+
+            if is_base64_unified_file_id:
+                litellm_managed_file_ids.append(is_base64_unified_file_id)
+            elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
                 litellm_managed_file_ids.append(file_id)
 
         if litellm_managed_file_ids:
@@ -107,8 +252,24 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                 )
                 if cached_values:
                     file_id_mapping[file_id] = cached_values
 
         return file_id_mapping
 
+    async def async_post_call_success_hook(
+        self,
+        data: Dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: LLMResponseTypes,
+    ) -> Any:
+        if isinstance(response, OpenAIFileObject):
+            asyncio.create_task(
+                self.store_unified_file_id(
+                    response.id, response, user_api_key_dict.parent_otel_span
+                )
+            )
+
+        return None
+
     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],
@@ -126,15 +287,20 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_type, str(uuid.uuid4())
         )
 
+        # Convert to URL-safe base64 and strip padding
+        base64_unified_file_id = (
+            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
+        )
+
         ## CREATE RESPONSE OBJECT
 
         response = OpenAIFileObject(
-            id=unified_file_id,
+            id=base64_unified_file_id,
             object="file",
             purpose=cast(OpenAIFilesPurpose, purpose),
             created_at=file_objects[0].created_at,
-            bytes=1234,
-            filename=str(datetime.now().timestamp()),
+            bytes=file_objects[0].bytes,
+            filename=file_objects[0].filename,
             status="uploaded",
         )
 
@@ -156,3 +322,77 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         )
 
         return response
+
+    async def afile_retrieve(
+        self, file_id: str, litellm_parent_otel_span: Optional[Span]
+    ) -> OpenAIFileObject:
+        stored_file_object = await self.get_unified_file_id(
+            file_id, litellm_parent_otel_span
+        )
+        if stored_file_object:
+            return stored_file_object
+        else:
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
+
+    async def afile_list(
+        self,
+        purpose: Optional[OpenAIFilesPurpose],
+        litellm_parent_otel_span: Optional[Span],
+        **data: Dict,
+    ) -> List[OpenAIFileObject]:
+        return []
+
+    async def afile_delete(
+        self,
+        file_id: str,
+        litellm_parent_otel_span: Optional[Span],
+        llm_router: Router,
+        **data: Dict,
+    ) -> OpenAIFileObject:
+        file_id = self.convert_b64_uid_to_unified_uid(file_id)
+        model_file_id_mapping = await self.get_model_file_id_mapping(
+            [file_id], litellm_parent_otel_span
+        )
+        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
+        if specific_model_file_id_mapping:
+            for model_id, file_id in specific_model_file_id_mapping.items():
+                await llm_router.afile_delete(model=model_id, file_id=file_id, **data)  # type: ignore
+
+        stored_file_object = await self.delete_unified_file_id(
+            file_id, litellm_parent_otel_span
+        )
+        if stored_file_object:
+            return stored_file_object
+        else:
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
+
+    async def afile_content(
+        self,
+        file_id: str,
+        litellm_parent_otel_span: Optional[Span],
+        llm_router: Router,
+        **data: Dict,
+    ) -> str:
+        """
+        Get the content of a file from first model that has it
+        """
+        initial_file_id = file_id
+        unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
+        model_file_id_mapping = await self.get_model_file_id_mapping(
+            [unified_file_id], litellm_parent_otel_span
+        )
+        specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
+        if specific_model_file_id_mapping:
+            exception_dict = {}
+            for model_id, file_id in specific_model_file_id_mapping.items():
+                try:
+                    return await llm_router.afile_content(model=model_id, file_id=file_id, **data)  # type: ignore
+                except Exception as e:
+                    exception_dict[model_id] = str(e)
+            raise Exception(
+                f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
+            )
+        else:
+            raise Exception(
+                f"LiteLLM Managed File object with id={initial_file_id} not found"
+            )
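Taken together, the hooks above back the standard OpenAI files routes on the proxy, so a client can exercise the new CRUD operations through the regular SDK. A rough usage sketch follows, assuming a local LiteLLM proxy at http://localhost:4000 with a virtual key; the URL, key, file name, and purpose are illustrative, not taken from this PR.

from openai import OpenAI

# Hypothetical deployment values -- point the SDK at your own proxy and key.
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000/v1")

# Upload: the proxy returns a unified, base64-encoded managed file id.
created = client.files.create(
    file=open("training_data.jsonl", "rb"), purpose="fine-tune"
)

# The id round-trips through the operations added in this commit.
retrieved = client.files.retrieve(created.id)   # afile_retrieve
content = client.files.content(created.id)      # afile_content
deleted = client.files.delete(created.id)       # afile_delete
print(retrieved.id, deleted.deleted)

Note that afile_list is stubbed to return an empty list in this commit, so a files list call is not expected to surface managed files yet.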