fix(managed_files.py): return unified file id as b64 str

allows retrieving a file by id to work as expected
Krrish Dholakia 2025-04-11 11:40:49 -07:00
parent b460025e18
commit 4993d9aa50
3 changed files with 157 additions and 5 deletions
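The fix round-trips the unified file id through URL-safe base64: encode and strip padding when the file is created, restore padding and decode when the id comes back in a request. A minimal sketch of that round trip (the unified id string below is a hypothetical stand-in for the value litellm actually builds):

import base64

unified_file_id = "litellm_proxy;unified_file_id,1234"  # hypothetical example value

# Encode: URL-safe base64 with "=" padding stripped, as done on file creation
b64_uid = base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")

# Decode: restore padding to a multiple of 4, then decode, as done on retrieval
padded = b64_uid + "=" * (-len(b64_uid) % 4)
assert base64.urlsafe_b64decode(padded).decode() == unified_file_id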


@@ -22,6 +22,8 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
+    ChatCompletionFileObject,
+    ChatCompletionFileObjectFile,
     ChatCompletionImageObject,
     ChatCompletionImageUrlObject,
 )
@@ -188,6 +190,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             message_content = message.get("content")
             if message_content and isinstance(message_content, list):
                 for content_item in message_content:
+                    litellm_specific_params = {"format"}
                     if content_item.get("type") == "image_url":
                         content_item = cast(ChatCompletionImageObject, content_item)
                         if isinstance(content_item["image_url"], str):
@@ -195,7 +198,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 "url": content_item["image_url"],
                             }
                         elif isinstance(content_item["image_url"], dict):
-                            litellm_specific_params = {"format"}
                             new_image_url_obj = ChatCompletionImageUrlObject(
                                 **{  # type: ignore
                                     k: v
@@ -204,6 +206,17 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 }
                             )
                             content_item["image_url"] = new_image_url_obj
+                    elif content_item.get("type") == "file":
+                        content_item = cast(ChatCompletionFileObject, content_item)
+                        file_obj = content_item["file"]
+                        new_file_obj = ChatCompletionFileObjectFile(
+                            **{  # type: ignore
+                                k: v
+                                for k, v in file_obj.items()
+                                if k not in litellm_specific_params
+                            }
+                        )
+                        content_item["file"] = new_file_obj
         return messages

     def transform_request(
@@ -403,4 +416,4 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
                 choices=chunk["choices"],
             )
         except Exception as e:
-            raise e
\ No newline at end of file
+            raise e
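The new "file" branch above mirrors the existing image_url handling: LiteLLM-specific keys (currently just "format") are filtered out of the file object before the request is sent upstream. A standalone sketch of that filtering, with hypothetical input values:

litellm_specific_params = {"format"}
file_obj = {"file_id": "file-abc123", "format": "application/pdf"}  # hypothetical input

# Keep only keys the upstream API understands; drop LiteLLM-only params
new_file_obj = {k: v for k, v in file_obj.items() if k not in litellm_specific_params}
assert new_file_obj == {"file_id": "file-abc123"}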


@@ -1,6 +1,7 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
+import base64
 import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
@@ -19,6 +20,7 @@ from litellm.types.llms.openai import (
     OpenAIFilesPurpose,
 )
 from litellm.types.utils import SpecialEnums
+from litellm.utils import is_base64_encoded

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -62,6 +64,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         if messages:
             file_ids = get_file_ids_from_messages(messages)
             if file_ids:
+                file_ids = [
+                    self.convert_b64_uid_to_unified_uid(file_id)
+                    for file_id in file_ids
+                ]
                 model_file_id_mapping = await self.get_model_file_id_mapping(
                     file_ids, user_api_key_dict.parent_otel_span
                 )
@@ -69,6 +75,28 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):

         return data

+    def _is_base64_encoded_unified_file_id(
+        self, b64_uid: str
+    ) -> Union[str, Literal[False]]:
+        # Add padding back if needed
+        padded = b64_uid + "=" * (-len(b64_uid) % 4)
+        # Decode from base64
+        try:
+            decoded = base64.urlsafe_b64decode(padded).decode()
+            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                return decoded
+            else:
+                return False
+        except Exception:
+            return False
+
+    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
+        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
     async def get_model_file_id_mapping(
         self, file_ids: List[str], litellm_parent_otel_span: Span
     ) -> dict:
@@ -107,6 +135,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             )
             if cached_values:
                 file_id_mapping[file_id] = cached_values
+
         return file_id_mapping

     @staticmethod
@@ -126,14 +155,19 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_type, str(uuid.uuid4())
         )

+        # Convert to URL-safe base64 and strip padding
+        base64_unified_file_id = (
+            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
+        )
+
-        ## CREATE RESPONSE OBJECT
+        ## CREATE RESPONSE OBJECT
         response = OpenAIFileObject(
-            id=unified_file_id,
+            id=base64_unified_file_id,
             object="file",
             purpose=cast(OpenAIFilesPurpose, purpose),
             created_at=file_objects[0].created_at,
-            bytes=1234,
+            bytes=file_objects[0].bytes,
             filename=str(datetime.now().timestamp()),
             status="uploaded",
         )
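Together, the two new hook methods mean only ids that decode to the managed-file prefix are rewritten; anything else passes through unchanged. A condensed sketch of the decode path (MANAGED_PREFIX is an assumed placeholder for SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):

import base64

MANAGED_PREFIX = "litellm_proxy"  # assumption: stands in for the real enum value

def convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
    padded = b64_uid + "=" * (-len(b64_uid) % 4)  # restore stripped "=" padding
    try:
        decoded = base64.urlsafe_b64decode(padded).decode()
    except Exception:
        return b64_uid  # not base64-encoded text: leave the id untouched
    return decoded if decoded.startswith(MANAGED_PREFIX) else b64_uid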


@@ -304,3 +304,108 @@ def test_create_file_and_call_chat_completion_e2e(
     finally:
         # Stop the mock
         mock.stop()
+
+
+def test_create_file_for_each_model(
+    mocker: MockerFixture, monkeypatch, llm_router: Router
+):
+    """
+    Test that create_file_for_each_model creates files for each target model and returns a unified file ID
+    """
+    import asyncio
+
+    from litellm import CreateFileRequest
+    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+    from litellm.proxy.openai_files_endpoints.files_endpoints import (
+        create_file_for_each_model,
+    )
+    from litellm.proxy.utils import ProxyLogging
+    from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
+
+    # Setup proxy logging
+    proxy_logging_obj = ProxyLogging(
+        user_api_key_cache=DualCache(default_in_memory_ttl=1)
+    )
+    proxy_logging_obj._add_proxy_hooks(llm_router)
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj
+    )
+
+    # Mock user API key dict
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test-key",
+        user_role=LitellmUserRoles.INTERNAL_USER,
+        user_id="test-user",
+        team_id="test-team",
+        team_alias="test-team-alias",
+        parent_otel_span=None,
+    )
+
+    # Create test file request
+    test_file_content = b"test file content"
+    test_file = ("test.txt", test_file_content, "text/plain")
+    _create_file_request = CreateFileRequest(file=test_file, purpose="user_data")
+
+    # Mock the router's acreate_file method
+    mock_file_response = OpenAIFileObject(
+        id="test-file-id",
+        object="file",
+        bytes=123,
+        created_at=1234567890,
+        filename="test.txt",
+        purpose="user_data",
+        status="uploaded",
+    )
+    mock_file_response._hidden_params = {"model_id": "test-model-id"}
+    mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
+
+    # Call the function
+    target_model_names_list = ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"]
+    response = asyncio.run(
+        create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=_create_file_request,
+            target_model_names_list=target_model_names_list,
+            purpose="user_data",
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+        )
+    )
+
+    # Verify the response
+    assert isinstance(response, OpenAIFileObject)
+    assert response.id is not None
+    assert response.purpose == "user_data"
+    assert response.filename == "test.txt"
+
+    # Verify acreate_file was called for each model
+    assert llm_router.acreate_file.call_count == len(target_model_names_list)
+
+    # Get all calls made to acreate_file
+    calls = llm_router.acreate_file.call_args_list
+
+    # Verify Azure call
+    azure_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "azure-gpt-3-5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            azure_call_found = True
+            break
+    assert azure_call_found, "Azure call not found with expected parameters"
+
+    # Verify OpenAI call
+    openai_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "gpt-3.5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            openai_call_found = True
+            break
+    assert openai_call_found, "OpenAI call not found with expected parameters"
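For context, an end-to-end usage sketch tying the three files together: upload once through the proxy, then reference the returned base64 unified id in a chat request using the standard OpenAI file content part (the base URL, API key, and model name are placeholders, not values from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder proxy

created = client.files.create(file=open("test.txt", "rb"), purpose="user_data")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize this file."},
                # The managed-files hook converts created.id (base64) back to the
                # unified id, then maps it to the model-specific file id.
                {"type": "file", "file": {"file_id": created.id}},
            ],
        }
    ],
)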