fix(managed_files.py): apply decoded file id transformation

This commit is contained in:
Krrish Dholakia 2025-04-11 12:18:04 -07:00
parent 4993d9aa50
commit 59fdb7f59a
3 changed files with 68 additions and 32 deletions

View file

@ -306,27 +306,6 @@ def get_completion_messages(
return messages
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
    """
    Gets file ids from messages

    Only messages with role "user" whose content is a list of parts are
    inspected. For every part of type "file", the "file_id" stored under
    its "file" field is collected, in order of appearance.

    Args:
        messages: chat messages to scan. They are only read, never modified.

    Returns:
        The truthy file ids that were found (duplicates are kept). String
        content, non-file parts, malformed parts and file parts without a
        "file_id" contribute nothing.
    """
    file_ids: List[str] = []
    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content")
        # Missing or plain-string content cannot carry file parts.
        if not content or isinstance(content, str):
            continue
        for c in content:
            # A part that is not a dict or has no "type" key is skipped
            # instead of raising TypeError/KeyError.
            if not isinstance(c, dict) or c.get("type") != "file":
                continue
            file_object = cast(ChatCompletionFileObject, c)
            file_object_file_field = file_object.get("file")
            # A file part without a usable "file" mapping has no id to offer.
            if not isinstance(file_object_file_field, dict):
                continue
            file_id = file_object_file_field.get("file_id")
            if file_id:
                file_ids.append(file_id)
    return file_ids
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
"""
Gets format from file id

View file

@ -9,18 +9,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_file_data,
get_file_ids_from_messages,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
CreateFileRequest,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums
from litellm.utils import is_base64_encoded
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -62,19 +60,54 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = get_file_ids_from_messages(messages)
file_ids = (
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
messages
)
)
if file_ids:
file_ids = [
self.convert_b64_uid_to_unified_uid(file_id)
for file_id in file_ids
]
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
    self, messages: List[AllMessageValues]
) -> List[str]:
    """
    Gets file ids from messages, passing each one through
    `_convert_b64_uid_to_unified_uid`.

    Only messages with role "user" whose content is a list of parts are
    inspected. For every part of type "file" that carries a truthy
    "file_id", the id is converted once and the converted value is both
    returned and written back.

    NOTE: `messages` is modified in place - the "file_id" of each matching
    file part is overwritten with its converted value.

    Args:
        messages: chat messages to scan and update.

    Returns:
        The converted file ids, in order of appearance (duplicates are
        kept). String content, non-file parts, malformed parts and file
        parts without a "file_id" contribute nothing.
    """
    file_ids: List[str] = []
    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content")
        # Missing or plain-string content cannot carry file parts.
        if not content or isinstance(content, str):
            continue
        for c in content:
            # A part that is not a dict or has no "type" key is skipped
            # instead of raising TypeError/KeyError.
            if not isinstance(c, dict) or c.get("type") != "file":
                continue
            file_object = cast(ChatCompletionFileObject, c)
            file_object_file_field = file_object.get("file")
            if not isinstance(file_object_file_field, dict):
                continue
            file_id = file_object_file_field.get("file_id")
            if not file_id:
                continue
            # Convert once and reuse the result for both the returned list
            # and the in-place rewrite, so the two can never disagree.
            unified_file_id = self._convert_b64_uid_to_unified_uid(file_id)
            file_object_file_field["file_id"] = unified_file_id
            file_ids.append(unified_file_id)
    return file_ids
def _convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
    """
    Return the unified file id for `b64_uid`, or `b64_uid` itself.

    The input is handed to `_is_base64_encoded_unified_file_id`; a truthy
    result from that check is returned as-is, while a falsy result means
    the original value is passed through unchanged.
    """
    decoded_uid = self._is_base64_encoded_unified_file_id(b64_uid)
    return decoded_uid or b64_uid
def _is_base64_encoded_unified_file_id(
self, b64_uid: str
) -> Union[str, Literal[False]]:
@ -115,12 +148,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
## CHECK IF FILE ID IS MANAGED BY LITELLM
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(is_base64_unified_file_id)
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:

View file

@ -0,0 +1,19 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
# def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles()
# messages = [
# {
# "role": "user",
# "content": [