fix(managed_files.py): return unified file id as b64 str
allows retrieve file id to work as expected
parent b460025e18
commit 4993d9aa50
3 changed files with 157 additions and 5 deletions
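For context, a minimal sketch (not the commit's code) of the round trip this change enables: on file creation the proxy now returns the unified managed-file id as URL-safe base64 with the padding stripped, and on later requests it restores the padding and decodes the id back. The prefix value and helper names below are stand-ins; the real check uses SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX from litellm.types.utils.

import base64
from typing import Literal, Union

MANAGED_FILE_PREFIX = "litellm_proxy:"  # stand-in for SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value


def encode_unified_file_id(unified_file_id: str) -> str:
    # URL-safe base64 with padding stripped, mirroring what the files response now returns.
    return base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")


def decode_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
    # Restore padding, decode, and accept only ids that carry the managed-file prefix.
    padded = b64_uid + "=" * (-len(b64_uid) % 4)
    try:
        decoded = base64.urlsafe_b64decode(padded).decode()
    except Exception:
        return False
    return decoded if decoded.startswith(MANAGED_FILE_PREFIX) else False


unified_id = MANAGED_FILE_PREFIX + "application/pdf,1234"  # illustrative id format only
b64_id = encode_unified_file_id(unified_id)           # what file creation hands back to the client
assert decode_unified_file_id(b64_id) == unified_id   # what retrieve/chat hooks recover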
managed_files.py

@@ -1,6 +1,7 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
+import base64
 import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
@@ -19,6 +20,7 @@ from litellm.types.llms.openai import (
     OpenAIFilesPurpose,
 )
 from litellm.types.utils import SpecialEnums
 from litellm.utils import is_base64_encoded
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -62,6 +64,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         if messages:
             file_ids = get_file_ids_from_messages(messages)
             if file_ids:
+                file_ids = [
+                    self.convert_b64_uid_to_unified_uid(file_id)
+                    for file_id in file_ids
+                ]
                 model_file_id_mapping = await self.get_model_file_id_mapping(
                     file_ids, user_api_key_dict.parent_otel_span
                 )
@@ -69,6 +75,28 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
 
         return data
 
+    def _is_base64_encoded_unified_file_id(
+        self, b64_uid: str
+    ) -> Union[str, Literal[False]]:
+        # Add padding back if needed
+        padded = b64_uid + "=" * (-len(b64_uid) % 4)
+        # Decode from base64
+        try:
+            decoded = base64.urlsafe_b64decode(padded).decode()
+            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                return decoded
+            else:
+                return False
+        except Exception:
+            return False
+
+    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
+        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
     async def get_model_file_id_mapping(
         self, file_ids: List[str], litellm_parent_otel_span: Span
     ) -> dict:
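Aside, not part of the diff: the padding expression in the new helper works because a base64 string's length must be a multiple of 4, so "=" * (-len(b64_uid) % 4) is exactly the number of "=" characters that rstrip("=") removed at encode time. A quick self-contained check with arbitrary sample bytes:

import base64

for raw in (b"a", b"ab", b"abc", b"abcd"):
    encoded = base64.urlsafe_b64encode(raw).decode()
    stripped = encoded.rstrip("=")                   # what gets handed to clients
    padded = stripped + "=" * (-len(stripped) % 4)   # what the hook reconstructs
    assert padded == encoded                         # padding restored exactly
    assert base64.urlsafe_b64decode(padded) == raw   # so decoding round-trips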
@@ -107,6 +135,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             )
             if cached_values:
                 file_id_mapping[file_id] = cached_values
 
         return file_id_mapping
 
     @staticmethod
@@ -126,14 +155,19 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_type, str(uuid.uuid4())
         )
 
+        # Convert to URL-safe base64 and strip padding
+        base64_unified_file_id = (
+            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
+        )
+
         ## CREATE RESPONSE OBJECT
 
         response = OpenAIFileObject(
-            id=unified_file_id,
+            id=base64_unified_file_id,
             object="file",
             purpose=cast(OpenAIFilesPurpose, purpose),
             created_at=file_objects[0].created_at,
-            bytes=1234,
+            bytes=file_objects[0].bytes,
             filename=str(datetime.now().timestamp()),
             status="uploaded",
         )
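The commit message ties this to making file retrieval work as expected; one plausible reading is that the unified id is built from a file type and a uuid (see the hunks above), so it can contain characters such as "/" that do not survive being used as a path segment in GET /v1/files/{file_id}, while the URL-safe, padding-free base64 form does. A small check with an illustrative id (the prefix and id format are assumptions, not taken from the commit):

import base64
import string

unified_id = "litellm_proxy:application/pdf,1234"  # illustrative; real ids use the SpecialEnums prefix and a uuid
b64_id = base64.urlsafe_b64encode(unified_id.encode()).decode().rstrip("=")

url_safe = set(string.ascii_letters + string.digits + "-_")
print(b64_id)                   # URL-safe token, no "=" padding
assert set(b64_id) <= url_safe  # safe to embed in /v1/files/{file_id}
assert "=" not in b64_id        # padding is restored server-side before decoding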