diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index ddb8094285..a7471f32f4 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
 from litellm.types.utils import (
     AdapterCompletionStreamWrapper,
-    EmbeddingResponse,
-    ImageResponse,
+    LLMResponseTypes,
     ModelResponse,
     ModelResponseStream,
     StandardCallbackDynamicParams,
@@ -223,7 +222,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
     ) -> Any:
         pass
 
diff --git a/litellm/proxy/hooks/managed_files.py b/litellm/proxy/hooks/managed_files.py
index 2a459ad4f6..6e78c31efe 100644
--- a/litellm/proxy/hooks/managed_files.py
+++ b/litellm/proxy/hooks/managed_files.py
@@ -1,10 +1,12 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
+import asyncio
 import base64
 import uuid
+from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
 
 from litellm import verbose_logger
 from litellm.caching.caching import DualCache
@@ -18,24 +20,58 @@ from litellm.types.llms.openai import (
     OpenAIFileObject,
     OpenAIFilesPurpose,
 )
-from litellm.types.utils import SpecialEnums
+from litellm.types.utils import LLMResponseTypes, SpecialEnums
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
     from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+    from litellm.proxy.utils import PrismaClient as _PrismaClient
 
     Span = Union[_Span, Any]
     InternalUsageCache = _InternalUsageCache
+    PrismaClient = _PrismaClient
 else:
     Span = Any
     InternalUsageCache = Any
+    PrismaClient = Any
 
 
-class _PROXY_LiteLLMManagedFiles(CustomLogger):
+class BaseFileEndpoints(ABC):
+    @abstractmethod
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+    @abstractmethod
+    async def afile_list(
+        self, custom_llm_provider: str, **data: dict
+    ) -> List[OpenAIFileObject]:
+        pass
+
+    @abstractmethod
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+
+class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
     # Class variables or attributes
-    def __init__(self, internal_usage_cache: InternalUsageCache):
+    def __init__(
+        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
+    ):
         self.internal_usage_cache = internal_usage_cache
+        self.prisma_client = prisma_client
+
+    async def store_unified_file_id(
+        self, file_id: str, file_object: OpenAIFileObject
+    ) -> None:
+        pass
+
+    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
+        return None
 
     async def async_pre_call_hook(
         self,
@@ -176,6 +212,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
 
         return file_id_mapping
 
+    async def post_call_success_hook(
+        self,
+        data: Dict,
+        response: LLMResponseTypes,
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> Any:
+        if isinstance(response, OpenAIFileObject):
+            asyncio.create_task(self.store_unified_file_id(response.id, response))
+
+        return None
+
     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],
@@ -228,3 +275,22 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         )
 
         return response
+
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        stored_file_object = await self.get_unified_file_id(file_id)
+        if stored_file_object:
+            return stored_file_object
+        else:
+            raise Exception(f"File object with id={file_id} not found")
+
+    async def afile_list(
+        self, custom_llm_provider: str, **data: Dict
+    ) -> List[OpenAIFileObject]:
+        return []
+
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        raise NotImplementedError("afile_delete not implemented")
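Note that in this patch `store_unified_file_id` and `get_unified_file_id` are stubs: the first is a no-op and the second always returns None, so `afile_retrieve` raises for every id until a real store lands. The `prisma_client` threaded through `__init__` suggests the eventual backing store is the database; the class below is a hypothetical in-memory stand-in (not part of this patch) with the same method shapes, shown only to illustrate the contract the stubs define.

from typing import Dict, Optional

from litellm.types.llms.openai import OpenAIFileObject


class InMemoryManagedFileStore:
    """Hypothetical in-memory stand-in for the Prisma-backed store."""

    def __init__(self) -> None:
        # Maps unified `litellm_proxy` file ids to the returned file objects.
        self._files: Dict[str, OpenAIFileObject] = {}

    async def store_unified_file_id(
        self, file_id: str, file_object: OpenAIFileObject
    ) -> None:
        self._files[file_id] = file_object

    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
        return self._files.get(file_id)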
diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py
index 77458d9889..ba1cabd3ef 100644
--- a/litellm/proxy/openai_files_endpoints/files_endpoints.py
+++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py
@@ -311,6 +311,13 @@ async def create_file(
             )
         )
 
+        ## POST CALL HOOKS ###
+        _response = await proxy_logging_obj.post_call_success_hook(
+            data=data, user_api_key_dict=user_api_key_dict, response=response
+        )
+        if _response is not None and isinstance(_response, OpenAIFileObject):
+            response = _response
+
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
@@ -539,6 +546,8 @@ async def get_file(
             version=version,
             proxy_config=proxy_config,
         )
+
+        ## check if file_id is a litellm managed file
         response = await litellm.afile_retrieve(
             custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
         )
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index b1a32b3c45..f9ea3fbac7 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
 from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
-from litellm.types.utils import CallTypes, LoggedLiteLLMParams
+from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -940,7 +940,7 @@ class ProxyLogging:
     async def post_call_success_hook(
         self,
         data: dict,
-        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
         user_api_key_dict: UserAPIKeyAuth,
     ):
         """
@@ -948,6 +948,9 @@
         Covers:
 
         1. /chat/completions
+        2. /embeddings
+        3. /image/generation
+        4. /files
         """
 
         for callback in litellm.callbacks:
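With this wiring, `create_file` mirrors what `/chat/completions` already does: after the provider call it hands the response to `ProxyLogging.post_call_success_hook`, and if a callback returns a replacement `OpenAIFileObject`, that replaces the response. The sketch below is a simplified rendering of that dispatch, not the verbatim loop in litellm/proxy/utils.py (which also resolves string callback aliases); it only illustrates the observe-or-replace semantics the endpoints rely on.

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import LLMResponseTypes


async def run_post_call_success_hooks(
    data: dict,
    response: LLMResponseTypes,
    user_api_key_dict: UserAPIKeyAuth,
) -> LLMResponseTypes:
    # Each CustomLogger-derived callback may inspect the response and,
    # optionally, hand back a modified one; later hooks see the update.
    for callback in litellm.callbacks:
        if isinstance(callback, CustomLogger):
            result = await callback.async_post_call_success_hook(
                data=data, user_api_key_dict=user_api_key_dict, response=response
            )
            if result is not None:
                response = result
    return response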
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 1bbec44b82..6511c5a133 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIFileObject,
     OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
@@ -2226,3 +2227,8 @@ class ExtractedFileData(TypedDict):
 class SpecialEnums(Enum):
     LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
     LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
+
+
+LLMResponseTypes = Union[
+    ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
+]
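Because `LLMResponseTypes` is a plain typing.Union, consumers discriminate its members with isinstance checks, exactly as the managed-files hook does when it singles out `OpenAIFileObject` responses. An illustrative (hypothetical) helper:

from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.utils import (
    EmbeddingResponse,
    LLMResponseTypes,
    ModelResponse,
)


def describe_response(response: LLMResponseTypes) -> str:
    # Narrow the union one member at a time.
    if isinstance(response, OpenAIFileObject):
        return f"file upload: {response.id}"
    if isinstance(response, EmbeddingResponse):
        return f"embedding response ({len(response.data)} vectors)"
    if isinstance(response, ModelResponse):
        return f"chat completion: {response.id}"
    return type(response).__name__  # e.g. ImageResponse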
diff --git a/tests/litellm/proxy/hooks/test_managed_files.py b/tests/litellm/proxy/hooks/test_managed_files.py
index 6f675fe843..006a99d324 100644
--- a/tests/litellm/proxy/hooks/test_managed_files.py
+++ b/tests/litellm/proxy/hooks/test_managed_files.py
@@ -43,3 +43,108 @@ def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
     assert messages[0]["content"][1]["file"]["file_id"].startswith(
         SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
     )
+
+
+# def test_list_managed_files():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create some test files
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="assistants"
+#     )
+
+#     # List all files
+#     files = proxy_managed_files.list_files()
+
+#     # Verify response
+#     assert len(files) == 2
+#     assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
+#     assert any(f.filename == "test1.txt" for f in files)
+#     assert any(f.filename == "test2.pdf" for f in files)
+#     assert all(f.purpose == "assistants" for f in files)
+
+# def test_retrieve_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     test_content = b"test content for retrieve"
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", test_content, "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Retrieve the file
+#     retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify response
+#     assert retrieved_file.id == created_file.id
+#     assert retrieved_file.filename == "test.txt"
+#     assert retrieved_file.purpose == "assistants"
+#     assert retrieved_file.bytes == len(test_content)
+#     assert retrieved_file.status == "uploaded"
+
+# def test_delete_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", b"test content", "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Delete the file
+#     deleted_file = proxy_managed_files.delete_file(created_file.id)
+
+#     # Verify deletion
+#     assert deleted_file.id == created_file.id
+#     assert deleted_file.deleted == True
+
+#     # Verify file is no longer retrievable
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify file is not in list
+#     files = proxy_managed_files.list_files()
+#     assert created_file.id not in [f.id for f in files]
+
+# def test_retrieve_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to retrieve a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file("nonexistent-file-id")
+
+# def test_delete_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to delete a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.delete_file("nonexistent-file-id")
+
+# def test_list_files_with_purpose_filter():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create files with different purposes
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="batch"
+#     )
+
+#     # List files with purpose filter
+#     assistant_files = proxy_managed_files.list_files(purpose="assistants")
+#     batch_files = proxy_managed_files.list_files(purpose="batch")
+
+#     # Verify filtering
+#     assert len(assistant_files) == 1
+#     assert len(batch_files) == 1
+#     assert assistant_files[0].id == file1.id
+#     assert batch_files[0].id == file2.id
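The commented-out tests above target a richer create/list/delete surface (a synchronous `create_file`/`list_files`/`retrieve_file` API) than this patch implements, and they also construct the hook with the old single-argument constructor. Only the not-found branch of `afile_retrieve` is exercisable against the code as written; a sketch of such a test follows, assuming pytest-asyncio is available and using MagicMock stand-ins for the cache and Prisma client, whose concrete wiring is outside this diff.

from unittest.mock import MagicMock

import pytest

from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles


@pytest.mark.asyncio
async def test_afile_retrieve_raises_for_unknown_id():
    proxy_managed_files = _PROXY_LiteLLMManagedFiles(
        internal_usage_cache=MagicMock(), prisma_client=MagicMock()
    )
    # get_unified_file_id is still a stub that returns None, so every
    # lookup should surface the "not found" error from afile_retrieve.
    with pytest.raises(Exception, match="not found"):
        await proxy_managed_files.afile_retrieve(
            custom_llm_provider="openai", file_id="nonexistent-file-id"
        )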