diff --git a/litellm/proxy/hooks/managed_files.py b/litellm/proxy/hooks/managed_files.py
index 7fbcf192e3..bf70c1f1a0 100644
--- a/litellm/proxy/hooks/managed_files.py
+++ b/litellm/proxy/hooks/managed_files.py
@@ -336,7 +336,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
 
     async def afile_list(
-        self, custom_llm_provider: str, **data: Dict
+        self,
+        purpose: Optional[OpenAIFilesPurpose],
+        litellm_parent_otel_span: Optional[Span],
+        **data: Dict,
     ) -> List[OpenAIFileObject]:
         return []
 
@@ -347,6 +350,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         llm_router: Router,
         **data: Dict,
     ) -> OpenAIFileObject:
+        file_id = self.convert_b64_uid_to_unified_uid(file_id)
         model_file_id_mapping = await self.get_model_file_id_mapping(
             [file_id], litellm_parent_otel_span
         )
@@ -362,3 +366,34 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             return stored_file_object
         else:
             raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
+
+    async def afile_content(
+        self,
+        file_id: str,
+        litellm_parent_otel_span: Optional[Span],
+        llm_router: Router,
+        **data: Dict,
+    ) -> str:
+        """
+        Get the content of a file from first model that has it
+        """
+        initial_file_id = file_id
+        unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
+        model_file_id_mapping = await self.get_model_file_id_mapping(
+            [unified_file_id], litellm_parent_otel_span
+        )
+        specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
+        if specific_model_file_id_mapping:
+            exception_dict = {}
+            for model_id, file_id in specific_model_file_id_mapping.items():
+                try:
+                    return await llm_router.afile_content(model=model_id, file_id=file_id, **data)  # type: ignore
+                except Exception as e:
+                    exception_dict[model_id] = str(e)
+            raise Exception(
+                f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
+            )
+        else:
+            raise Exception(
+                f"LiteLLM Managed File object with id={initial_file_id} not found"
+            )
diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py
index 83a55385d9..604bdd6b42 100644
--- a/litellm/proxy/openai_files_endpoints/files_endpoints.py
+++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py
@@ -399,6 +399,7 @@ async def get_file_content(
     from litellm.proxy.proxy_server import (
         add_litellm_data_to_request,
         general_settings,
+        llm_router,
         proxy_config,
         proxy_logging_obj,
         version,
@@ -421,9 +422,40 @@ async def get_file_content(
         or await get_custom_llm_provider_from_request_body(request=request)
         or "openai"
     )
-    response = await litellm.afile_content(
-        custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
+
+    ## check if file_id is a litellm managed file
+    is_base64_unified_file_id = (
+        _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
     )
+    if is_base64_unified_file_id:
+        managed_files_obj = cast(
+            Optional[_PROXY_LiteLLMManagedFiles],
+            proxy_logging_obj.get_proxy_hook("managed_files"),
+        )
+        if managed_files_obj is None:
+            raise ProxyException(
+                message="Managed files hook not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        if llm_router is None:
+            raise ProxyException(
+                message="LLM Router not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        response = await managed_files_obj.afile_content(
+            file_id=file_id,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            llm_router=llm_router,
+            **data,
+        )
+    else:
+        response = await litellm.afile_content(
+            custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
+        )
 
     ### ALERTING ###
     asyncio.create_task(
diff --git a/litellm/router.py b/litellm/router.py
index bf469617ce..8ceb89f251 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -732,6 +732,9 @@ class Router:
         self.afile_delete = self.factory_function(
             litellm.afile_delete, call_type="afile_delete"
         )
+        self.afile_content = self.factory_function(
+            litellm.afile_content, call_type="afile_content"
+        )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
 
     def validate_fallbacks(self, fallback_param: Optional[List]):
@@ -3070,6 +3073,7 @@ class Router:
             "aresponses",
             "responses",
             "afile_delete",
+            "afile_content",
         ] = "assistants",
     ):
         """
@@ -3114,7 +3118,12 @@ class Router:
             return await self._pass_through_moderation_endpoint_factory(
                 original_function=original_function, **kwargs
             )
-        elif call_type in ("anthropic_messages", "aresponses", "afile_delete"):
+        elif call_type in (
+            "anthropic_messages",
+            "aresponses",
+            "afile_delete",
+            "afile_content",
+        ):
             return await self._ageneric_api_call_with_fallbacks(
                 original_function=original_function,
                 **kwargs,
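
For context, a minimal sketch of how the new managed-file content path could be exercised end to end against a running proxy. The base URL, API key, upload purpose, and file name below are illustrative assumptions, not values taken from this patch; only the routing behavior (unified file ids resolved via the managed-files hook and Router.afile_content) comes from the change itself.

    # Sketch: exercise GET /v1/files/{file_id}/content through the LiteLLM proxy.
    # Assumes a proxy on localhost:4000 with the managed_files hook enabled and a
    # virtual key "sk-1234"; these are placeholders for illustration only.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

    # Uploading through the proxy returns a base64-encoded unified file id
    # when the file is LiteLLM-managed.
    uploaded = client.files.create(file=open("batch_input.jsonl", "rb"), purpose="batch")

    # With this patch, content retrieval for that unified id is dispatched to
    # _PROXY_LiteLLMManagedFiles.afile_content, which tries each deployment that
    # holds the file via Router.afile_content instead of a single provider call.
    content = client.files.content(file_id=uploaded.id)
    print(content.text)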