mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
fix(managed_files.py): apply decoded file id transformation
This commit is contained in:
parent
4993d9aa50
commit
59fdb7f59a
3 changed files with 68 additions and 32 deletions
|
@ -306,27 +306,6 @@ def get_completion_messages(
|
|||
return messages
|
||||
|
||||
|
||||
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
|
||||
"""
|
||||
Gets file ids from messages
|
||||
"""
|
||||
file_ids = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
continue
|
||||
for c in content:
|
||||
if c["type"] == "file":
|
||||
file_object = cast(ChatCompletionFileObject, c)
|
||||
file_object_file_field = file_object["file"]
|
||||
file_id = file_object_file_field.get("file_id")
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
return file_ids
|
||||
|
||||
|
||||
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Gets format from file id
|
||||
|
|
|
@ -9,18 +9,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
|
|||
from litellm import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
extract_file_data,
|
||||
get_file_ids_from_messages,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionFileObject,
|
||||
CreateFileRequest,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilesPurpose,
|
||||
)
|
||||
from litellm.types.utils import SpecialEnums
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -62,19 +60,54 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
|||
if call_type == CallTypes.completion.value:
|
||||
messages = data.get("messages")
|
||||
if messages:
|
||||
file_ids = get_file_ids_from_messages(messages)
|
||||
file_ids = (
|
||||
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||
messages
|
||||
)
|
||||
)
|
||||
if file_ids:
|
||||
file_ids = [
|
||||
self.convert_b64_uid_to_unified_uid(file_id)
|
||||
for file_id in file_ids
|
||||
]
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
file_ids, user_api_key_dict.parent_otel_span
|
||||
)
|
||||
|
||||
data["model_file_id_mapping"] = model_file_id_mapping
|
||||
|
||||
return data
|
||||
|
||||
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Gets file ids from messages
|
||||
"""
|
||||
file_ids = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
continue
|
||||
for c in content:
|
||||
if c["type"] == "file":
|
||||
file_object = cast(ChatCompletionFileObject, c)
|
||||
file_object_file_field = file_object["file"]
|
||||
file_id = file_object_file_field.get("file_id")
|
||||
if file_id:
|
||||
file_ids.append(
|
||||
self._convert_b64_uid_to_unified_uid(file_id)
|
||||
)
|
||||
file_object_file_field[
|
||||
"file_id"
|
||||
] = self._convert_b64_uid_to_unified_uid(file_id)
|
||||
return file_ids
|
||||
|
||||
def _convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
|
||||
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
|
||||
if is_base64_unified_file_id:
|
||||
return is_base64_unified_file_id
|
||||
else:
|
||||
return b64_uid
|
||||
|
||||
def _is_base64_encoded_unified_file_id(
|
||||
self, b64_uid: str
|
||||
) -> Union[str, Literal[False]]:
|
||||
|
@ -115,12 +148,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
||||
litellm_managed_file_ids = []
|
||||
|
||||
for file_id in file_ids:
|
||||
## CHECK IF FILE ID IS MANAGED BY LITELM
|
||||
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
|
||||
|
||||
if is_base64_unified_file_id:
|
||||
litellm_managed_file_ids.append(is_base64_unified_file_id)
|
||||
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||
litellm_managed_file_ids.append(file_id)
|
||||
|
||||
if litellm_managed_file_ids:
|
||||
|
|
19
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
19
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
|
||||
|
||||
# def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
|
||||
# proxy_managed_files = _PROXY_LiteLLMManagedFiles()
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
Loading…
Add table
Add a link
Reference in a new issue