fix: initial commit for litellm_proxy support with CRUD Endpoints

This commit is contained in:
Krrish Dholakia 2025-04-11 12:57:54 -07:00
parent 6b04b48b17
commit 7fff83e441
6 changed files with 197 additions and 9 deletions

View file

@ -1,10 +1,12 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import asyncio
import base64
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from litellm import verbose_logger
from litellm.caching.caching import DualCache
@ -18,24 +20,58 @@ from litellm.types.llms.openai import (
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums
from litellm.types.utils import LLMResponseTypes, SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
class _PROXY_LiteLLMManagedFiles(CustomLogger):
class BaseFileEndpoints(ABC):
@abstractmethod
async def afile_retrieve(
self, custom_llm_provider: str, file_id: str, **data: dict
) -> OpenAIFileObject:
pass
@abstractmethod
async def afile_list(
self, custom_llm_provider: str, **data: dict
) -> List[OpenAIFileObject]:
pass
@abstractmethod
async def afile_delete(
self, custom_llm_provider: str, file_id: str, **data: dict
) -> OpenAIFileObject:
pass
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
def __init__(
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
):
self.internal_usage_cache = internal_usage_cache
self.prisma_client = prisma_client
async def store_unified_file_id(
self, file_id: str, file_object: OpenAIFileObject
) -> None:
pass
async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
return None
async def async_pre_call_hook(
self,
@ -176,6 +212,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
return file_id_mapping
async def post_call_success_hook(
self,
data: Dict,
response: LLMResponseTypes,
user_api_key_dict: UserAPIKeyAuth,
) -> Any:
if isinstance(response, OpenAIFileObject):
asyncio.create_task(self.store_unified_file_id(response.id, response))
return None
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
@ -228,3 +275,22 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
)
return response
async def afile_retrieve(
self, custom_llm_provider: str, file_id: str, **data: Dict
) -> OpenAIFileObject:
stored_file_object = await self.get_unified_file_id(file_id)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"File object with id={file_id} not found")
async def afile_list(
self, custom_llm_provider: str, **data: Dict
) -> List[OpenAIFileObject]:
return []
async def afile_delete(
self, custom_llm_provider: str, file_id: str, **data: Dict
) -> OpenAIFileObject:
raise NotImplementedError("afile_delete not implemented")