mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(managed_files.py): apply decoded file id transformation
This commit is contained in:
parent
4993d9aa50
commit
59fdb7f59a
3 changed files with 68 additions and 32 deletions
|
@ -306,27 +306,6 @@ def get_completion_messages(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
    """
    Return the ``file_id`` of every file part found in user messages.

    Only messages whose ``role`` is ``"user"`` are inspected. Messages with
    empty content or plain-string content are skipped. For list content,
    each part whose ``"type"`` is ``"file"`` contributes the ``file_id``
    stored under its ``"file"`` entry, provided that id is truthy.

    Ids are returned in order of appearance; duplicates are kept.
    """
    collected = []
    for msg in messages:
        if msg.get("role") != "user":
            continue
        parts = msg.get("content")
        # Empty content and plain-string content cannot hold file parts.
        if not parts or isinstance(parts, str):
            continue
        for part in parts:
            if part["type"] != "file":
                continue
            file_part = cast(ChatCompletionFileObject, part)
            referenced_id = file_part["file"].get("file_id")
            if referenced_id:
                collected.append(referenced_id)
    return collected
|
|
||||||
|
|
||||||
|
|
||||||
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
|
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Gets format from file id
|
Gets format from file id
|
||||||
|
|
|
@ -9,18 +9,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
|
||||||
from litellm import verbose_logger
|
from litellm import verbose_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||||
extract_file_data,
|
|
||||||
get_file_ids_from_messages,
|
|
||||||
)
|
|
||||||
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
|
ChatCompletionFileObject,
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
OpenAIFileObject,
|
OpenAIFileObject,
|
||||||
OpenAIFilesPurpose,
|
OpenAIFilesPurpose,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import SpecialEnums
|
from litellm.types.utils import SpecialEnums
|
||||||
from litellm.utils import is_base64_encoded
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -62,19 +60,54 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
if call_type == CallTypes.completion.value:
|
if call_type == CallTypes.completion.value:
|
||||||
messages = data.get("messages")
|
messages = data.get("messages")
|
||||||
if messages:
|
if messages:
|
||||||
file_ids = get_file_ids_from_messages(messages)
|
file_ids = (
|
||||||
|
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||||
|
messages
|
||||||
|
)
|
||||||
|
)
|
||||||
if file_ids:
|
if file_ids:
|
||||||
file_ids = [
|
|
||||||
self.convert_b64_uid_to_unified_uid(file_id)
|
|
||||||
for file_id in file_ids
|
|
||||||
]
|
|
||||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||||
file_ids, user_api_key_dict.parent_otel_span
|
file_ids, user_api_key_dict.parent_otel_span
|
||||||
)
|
)
|
||||||
|
|
||||||
data["model_file_id_mapping"] = model_file_id_mapping
|
data["model_file_id_mapping"] = model_file_id_mapping
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
    self, messages: List[AllMessageValues]
) -> List[str]:
    """
    Collect file ids from user messages, converting each one via
    ``_convert_b64_uid_to_unified_uid``.

    Only messages whose ``role`` is ``"user"`` and whose content is a
    non-empty, non-string sequence of parts are inspected. For every part
    whose ``"type"`` is ``"file"`` and whose ``"file"`` entry has a truthy
    ``file_id``, the id is converted and the converted value is:

    * appended to the returned list, and
    * written back to ``part["file"]["file_id"]``.

    NOTE: this mutates ``messages`` in place - callers see the converted
    ids in their own message objects after this call.

    Args:
        messages: chat messages to scan (and update in place).

    Returns:
        The converted file ids, in order of appearance; duplicates are kept.
    """
    file_ids: List[str] = []
    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content")
        # Empty content and plain-string content cannot hold file parts.
        if not content or isinstance(content, str):
            continue
        for c in content:
            if c["type"] != "file":
                continue
            file_object = cast(ChatCompletionFileObject, c)
            file_object_file_field = file_object["file"]
            file_id = file_object_file_field.get("file_id")
            if not file_id:
                continue
            # Convert once and reuse the result, so the returned list and
            # the rewritten message always hold the same value.
            unified_file_id = self._convert_b64_uid_to_unified_uid(file_id)
            file_ids.append(unified_file_id)
            file_object_file_field["file_id"] = unified_file_id
    return file_ids
|
||||||
|
|
||||||
|
def _convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
    """
    Return the value produced by ``_is_base64_encoded_unified_file_id`` for
    ``b64_uid`` when that value is truthy; otherwise return ``b64_uid``
    unchanged.
    """
    decoded_uid = self._is_base64_encoded_unified_file_id(b64_uid)
    # A falsy result means there is nothing to substitute for this id.
    return decoded_uid or b64_uid
|
||||||
|
|
||||||
def _is_base64_encoded_unified_file_id(
|
def _is_base64_encoded_unified_file_id(
|
||||||
self, b64_uid: str
|
self, b64_uid: str
|
||||||
) -> Union[str, Literal[False]]:
|
) -> Union[str, Literal[False]]:
|
||||||
|
@ -115,12 +148,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
||||||
litellm_managed_file_ids = []
|
litellm_managed_file_ids = []
|
||||||
|
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
## CHECK IF FILE ID IS MANAGED BY LITELM
|
## CHECK IF FILE ID IS MANAGED BY LITELM
|
||||||
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
|
||||||
|
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
litellm_managed_file_ids.append(is_base64_unified_file_id)
|
||||||
|
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||||
litellm_managed_file_ids.append(file_id)
|
litellm_managed_file_ids.append(file_id)
|
||||||
|
|
||||||
if litellm_managed_file_ids:
|
if litellm_managed_file_ids:
|
||||||
|
|
19
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
19
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
|
||||||
|
|
||||||
|
# def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles()
|
||||||
|
# messages = [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
Loading…
Add table
Add a link
Reference in a new issue