fix: initial commit for litellm_proxy support with CRUD Endpoints

Krrish Dholakia 2025-04-11 12:57:54 -07:00
parent 6b04b48b17
commit 7fff83e441
6 changed files with 197 additions and 9 deletions

View file

@@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.utils import (
    AdapterCompletionStreamWrapper,
    EmbeddingResponse,
    ImageResponse,
    LLMResponseTypes,
    ModelResponse,
    ModelResponseStream,
    StandardCallbackDynamicParams,
@@ -223,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
        response: LLMResponseTypes,
    ) -> Any:
        pass
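
Note (illustration, not part of this commit): the narrowed annotation lets one proxy callback handle file uploads alongside chat/embedding/image responses. A minimal sketch, assuming the method whose signature is changed above is CustomLogger.async_post_call_success_hook:

from typing import Any

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.utils import LLMResponseTypes


class MyFileAwareCallback(CustomLogger):  # hypothetical subclass name
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: LLMResponseTypes,
    ) -> Any:
        # LLMResponseTypes now also covers OpenAIFileObject, so file uploads
        # flow through the same hook as chat/embedding/image responses.
        if isinstance(response, OpenAIFileObject):
            print(f"file uploaded: {response.id}")
        return response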

View file

@@ -1,10 +1,12 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import asyncio
import base64
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from litellm import verbose_logger
from litellm.caching.caching import DualCache
@@ -18,24 +20,58 @@ from litellm.types.llms.openai import (
    OpenAIFileObject,
    OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums
from litellm.types.utils import LLMResponseTypes, SpecialEnums
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.proxy.utils import PrismaClient as _PrismaClient
    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    PrismaClient = _PrismaClient
else:
    Span = Any
    InternalUsageCache = Any
    PrismaClient = Any
class _PROXY_LiteLLMManagedFiles(CustomLogger):
class BaseFileEndpoints(ABC):
    @abstractmethod
    async def afile_retrieve(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        pass
    @abstractmethod
    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        pass
    @abstractmethod
    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        pass
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
    # Class variables or attributes
    def __init__(self, internal_usage_cache: InternalUsageCache):
    def __init__(
        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client
    async def store_unified_file_id(
        self, file_id: str, file_object: OpenAIFileObject
    ) -> None:
        pass
    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
        return None
    async def async_pre_call_hook(
        self,
@@ -176,6 +212,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
        return file_id_mapping
    async def post_call_success_hook(
        self,
        data: Dict,
        response: LLMResponseTypes,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Any:
        if isinstance(response, OpenAIFileObject):
            asyncio.create_task(self.store_unified_file_id(response.id, response))
        return None
    @staticmethod
    async def return_unified_file_id(
        file_objects: List[OpenAIFileObject],
@@ -228,3 +275,22 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
        )
        return response
    async def afile_retrieve(
        self, custom_llm_provider: str, file_id: str, **data: Dict
    ) -> OpenAIFileObject:
        stored_file_object = await self.get_unified_file_id(file_id)
        if stored_file_object:
            return stored_file_object
        else:
            raise Exception(f"File object with id={file_id} not found")
    async def afile_list(
        self, custom_llm_provider: str, **data: Dict
    ) -> List[OpenAIFileObject]:
        return []
    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: Dict
    ) -> OpenAIFileObject:
        raise NotImplementedError("afile_delete not implemented")
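
Note (illustration, not part of this commit): a rough sketch of how the new hook is constructed and exercised. _PROXY_LiteLLMManagedFiles, the InternalUsageCache, the PrismaClient, and the UserAPIKeyAuth value are assumed to be supplied by the proxy at startup; their wiring and module paths are not shown in this diff.

from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import OpenAIFileObject


async def demo(
    internal_usage_cache,
    prisma_client,
    user_api_key: UserAPIKeyAuth,
    file_object: OpenAIFileObject,
) -> None:
    managed_files_hook = _PROXY_LiteLLMManagedFiles(  # assumed to be in scope
        internal_usage_cache=internal_usage_cache,
        prisma_client=prisma_client,
    )

    # A successful POST /files hands the OpenAIFileObject to this hook, which
    # schedules store_unified_file_id in the background (still a no-op stub here).
    await managed_files_hook.post_call_success_hook(
        data={},
        response=file_object,
        user_api_key_dict=user_api_key,
    )

    # Until storage is implemented, lookups return None and listing is empty.
    assert await managed_files_hook.get_unified_file_id(file_object.id) is None
    assert await managed_files_hook.afile_list(custom_llm_provider="openai") == []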

View file

@@ -311,6 +311,13 @@ async def create_file(
            )
        )
        ## POST CALL HOOKS ###
        _response = await proxy_logging_obj.post_call_success_hook(
            data=data, user_api_key_dict=user_api_key_dict, response=response
        )
        if _response is not None and isinstance(_response, OpenAIFileObject):
            response = _response
        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
@@ -539,6 +546,8 @@ async def get_file(
            version=version,
            proxy_config=proxy_config,
        )
        ## check if file_id is a litellm managed file
        response = await litellm.afile_retrieve(
            custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
        )
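
Note (illustration, not part of this commit): since the proxy keeps the OpenAI-compatible /files surface, the create/retrieve wiring above can be exercised with the OpenAI SDK pointed at a running proxy. The base URL, API key, file name, and purpose below are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder proxy

# POST /files -> create_file now runs the post-call success hook shown above,
# letting _PROXY_LiteLLMManagedFiles record the returned OpenAIFileObject.
created = client.files.create(file=open("data.jsonl", "rb"), purpose="batch")

# GET /files/{file_id} -> get_file first checks whether the id is LiteLLM managed.
retrieved = client.files.retrieve(created.id)
print(retrieved.id, retrieved.purpose)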

View file

@@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
@@ -940,7 +940,7 @@ class ProxyLogging:
    async def post_call_success_hook(
        self,
        data: dict,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
        response: LLMResponseTypes,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """
@@ -948,6 +948,9 @@
        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation
        4. /files
        """
        for callback in litellm.callbacks:
View file

@@ -34,6 +34,7 @@ from .llms.openai import (
    ChatCompletionUsageBlock,
    FileSearchTool,
    OpenAIChatCompletionChunk,
    OpenAIFileObject,
    OpenAIRealtimeStreamList,
    WebSearchOptions,
)
@@ -2226,3 +2227,8 @@ class ExtractedFileData(TypedDict):
class SpecialEnums(Enum):
    LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
    LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
LLMResponseTypes = Union[
    ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
]
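
Note (illustration, not part of this commit): a hypothetical sketch of how the SpecialEnums format string can yield a unified managed-file id. The hook's actual encoding is not shown in this diff; the provider name and uuid placeholders are assumptions.

import base64
import uuid

from litellm.types.utils import SpecialEnums

complete_str = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
    "openai", uuid.uuid4()  # assumed placeholders: provider name, unified id
)
unified_file_id = base64.b64encode(complete_str.encode()).decode()

# Decoding recovers a string carrying the litellm_proxy prefix, which is what
# the existing test below asserts on after extracting file ids from messages.
decoded = base64.b64decode(unified_file_id).decode()
assert decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value)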

View file

@@ -43,3 +43,108 @@ def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
    assert messages[0]["content"][1]["file"]["file_id"].startswith(
        SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
    )
# def test_list_managed_files():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create some test files
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="assistants"
# )
# # List all files
# files = proxy_managed_files.list_files()
# # Verify response
# assert len(files) == 2
# assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
# assert any(f.filename == "test1.txt" for f in files)
# assert any(f.filename == "test2.pdf" for f in files)
# assert all(f.purpose == "assistants" for f in files)
# def test_retrieve_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# test_content = b"test content for retrieve"
# created_file = proxy_managed_files.create_file(
# file=("test.txt", test_content, "text/plain"),
# purpose="assistants"
# )
# # Retrieve the file
# retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
# # Verify response
# assert retrieved_file.id == created_file.id
# assert retrieved_file.filename == "test.txt"
# assert retrieved_file.purpose == "assistants"
# assert retrieved_file.bytes == len(test_content)
# assert retrieved_file.status == "uploaded"
# def test_delete_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# created_file = proxy_managed_files.create_file(
# file=("test.txt", b"test content", "text/plain"),
# purpose="assistants"
# )
# # Delete the file
# deleted_file = proxy_managed_files.delete_file(created_file.id)
# # Verify deletion
# assert deleted_file.id == created_file.id
# assert deleted_file.deleted == True
# # Verify file is no longer retrievable
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file(created_file.id)
# # Verify file is not in list
# files = proxy_managed_files.list_files()
# assert created_file.id not in [f.id for f in files]
# def test_retrieve_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to retrieve a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file("nonexistent-file-id")
# def test_delete_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to delete a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.delete_file("nonexistent-file-id")
# def test_list_files_with_purpose_filter():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create files with different purposes
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="batch"
# )
# # List files with purpose filter
# assistant_files = proxy_managed_files.list_files(purpose="assistants")
# batch_files = proxy_managed_files.list_files(purpose="batch")
# # Verify filtering
# assert len(assistant_files) == 1
# assert len(batch_files) == 1
# assert assistant_files[0].id == file1.id
# assert batch_files[0].id == file2.id
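
Note (illustration, not part of this commit): the commented-out tests above assume a synchronous create/list/retrieve/delete API that this initial commit does not implement. A sketch of what the async stubs can already be checked for, assuming pytest-asyncio and mocked constructor dependencies:

import pytest
from unittest.mock import MagicMock


@pytest.mark.asyncio
async def test_afile_stubs_initial_behavior():
    # _PROXY_LiteLLMManagedFiles is assumed to be importable from the hook module.
    proxy_managed_files = _PROXY_LiteLLMManagedFiles(
        internal_usage_cache=MagicMock(), prisma_client=MagicMock()
    )

    # afile_list is stubbed to return an empty list
    assert await proxy_managed_files.afile_list(custom_llm_provider="openai") == []

    # afile_retrieve raises for ids that were never stored
    with pytest.raises(Exception):
        await proxy_managed_files.afile_retrieve(
            custom_llm_provider="openai", file_id="nonexistent-file-id"
        )

    # afile_delete is not implemented yet
    with pytest.raises(NotImplementedError):
        await proxy_managed_files.afile_delete(
            custom_llm_provider="openai", file_id="nonexistent-file-id"
        )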