Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
fix(managed_files.py): return unified file id as b64 str

allows retrieve file id to work as expected

parent b460025e18
commit 4993d9aa50

3 changed files with 157 additions and 5 deletions
@@ -22,6 +22,8 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
+    ChatCompletionFileObject,
+    ChatCompletionFileObjectFile,
     ChatCompletionImageObject,
     ChatCompletionImageUrlObject,
 )
@@ -188,6 +190,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             message_content = message.get("content")
             if message_content and isinstance(message_content, list):
                 for content_item in message_content:
+                    litellm_specific_params = {"format"}
                     if content_item.get("type") == "image_url":
                         content_item = cast(ChatCompletionImageObject, content_item)
                         if isinstance(content_item["image_url"], str):
@@ -195,7 +198,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 "url": content_item["image_url"],
                             }
                         elif isinstance(content_item["image_url"], dict):
-                            litellm_specific_params = {"format"}
                             new_image_url_obj = ChatCompletionImageUrlObject(
                                 **{  # type: ignore
                                     k: v
@@ -204,6 +206,17 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 }
                             )
                             content_item["image_url"] = new_image_url_obj
+                    elif content_item.get("type") == "file":
+                        content_item = cast(ChatCompletionFileObject, content_item)
+                        file_obj = content_item["file"]
+                        new_file_obj = ChatCompletionFileObjectFile(
+                            **{  # type: ignore
+                                k: v
+                                for k, v in file_obj.items()
+                                if k not in litellm_specific_params
+                            }
+                        )
+                        content_item["file"] = new_file_obj
         return messages
 
     def transform_request(
@@ -403,4 +416,4 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
                 choices=chunk["choices"],
             )
         except Exception as e:
-            raise e
+            raise e
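For readers skimming the transformation change above: a minimal standalone sketch of what the new "file" branch does, namely copying the file object while dropping litellm-only keys (currently just "format") before the request goes upstream. It uses plain dicts instead of litellm's typed ChatCompletionFileObjectFile, and the sample file_id is made up.

# Simplified sketch of the "file" branch added above; plain dicts, made-up file_id.
litellm_specific_params = {"format"}

def strip_litellm_file_params(file_obj: dict) -> dict:
    # Keep provider-supported keys, drop litellm-only ones such as "format".
    return {k: v for k, v in file_obj.items() if k not in litellm_specific_params}

content_item = {
    "type": "file",
    "file": {"file_id": "file-abc123", "format": "application/pdf"},
}
content_item["file"] = strip_litellm_file_params(content_item["file"])
print(content_item["file"])  # {'file_id': 'file-abc123'}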
@@ -1,6 +1,7 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
+import base64
 import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
@@ -19,6 +20,7 @@ from litellm.types.llms.openai import (
     OpenAIFilesPurpose,
 )
 from litellm.types.utils import SpecialEnums
+from litellm.utils import is_base64_encoded
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -62,6 +64,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             if messages:
                 file_ids = get_file_ids_from_messages(messages)
                 if file_ids:
+                    file_ids = [
+                        self.convert_b64_uid_to_unified_uid(file_id)
+                        for file_id in file_ids
+                    ]
                     model_file_id_mapping = await self.get_model_file_id_mapping(
                         file_ids, user_api_key_dict.parent_otel_span
                     )
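Roughly, the pre-call hook above now scans message content for file parts, converts any base64-encoded unified ids back to their raw form, and only then builds the per-model file id mapping. The stand-in below illustrates that scan; litellm's real helper is get_file_ids_from_messages, and this sketch only assumes OpenAI-style content parts of the form {"type": "file", "file": {"file_id": ...}}.

from typing import List

def get_file_ids_from_messages_sketch(messages: List[dict]) -> List[str]:
    """Collect file_ids from file-type content parts (simplified stand-in)."""
    file_ids: List[str] = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "file":
                    file_id = part.get("file", {}).get("file_id")
                    if file_id:
                        file_ids.append(file_id)
    return file_ids

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "summarize this"},
            {"type": "file", "file": {"file_id": "file-abc123"}},
        ],
    }
]
print(get_file_ids_from_messages_sketch(messages))  # ['file-abc123']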
@@ -69,6 +75,28 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
 
         return data
 
+    def _is_base64_encoded_unified_file_id(
+        self, b64_uid: str
+    ) -> Union[str, Literal[False]]:
+        # Add padding back if needed
+        padded = b64_uid + "=" * (-len(b64_uid) % 4)
+        # Decode from base64
+        try:
+            decoded = base64.urlsafe_b64decode(padded).decode()
+            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                return decoded
+            else:
+                return False
+        except Exception:
+            return False
+
+    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
+        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
     async def get_model_file_id_mapping(
         self, file_ids: List[str], litellm_parent_otel_span: Span
     ) -> dict:
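A minimal sketch of the decode check added above, written as a free function. Because the proxy strips "=" padding when it hands out base64 ids, the check re-pads before decoding and only accepts strings that start with the managed-file prefix. The prefix literal below is illustrative; the real value comes from SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.

import base64
from typing import Literal, Union

MANAGED_FILE_ID_PREFIX = "litellm_proxy"  # illustrative; litellm reads this from SpecialEnums

def is_b64_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
    padded = b64_uid + "=" * (-len(b64_uid) % 4)  # restore the stripped "=" padding
    try:
        decoded = base64.urlsafe_b64decode(padded).decode()
    except Exception:
        return False
    return decoded if decoded.startswith(MANAGED_FILE_ID_PREFIX) else False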
@@ -107,6 +135,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                     )
                     if cached_values:
                         file_id_mapping[file_id] = cached_values
+
         return file_id_mapping
 
     @staticmethod
@@ -126,14 +155,19 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_type, str(uuid.uuid4())
         )
 
+        # Convert to URL-safe base64 and strip padding
+        base64_unified_file_id = (
+            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
+        )
+
-        ## CREATE RESPONSE OBJECT
+        ## CREATE RESPONSE OBJECT
 
         response = OpenAIFileObject(
-            id=unified_file_id,
+            id=base64_unified_file_id,
             object="file",
             purpose=cast(OpenAIFilesPurpose, purpose),
             created_at=file_objects[0].created_at,
-            bytes=1234,
+            bytes=file_objects[0].bytes,
             filename=str(datetime.now().timestamp()),
             status="uploaded",
        )
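Why this makes file retrieval work: the id returned at create time is now the padding-stripped, URL-safe base64 form, and the hook re-pads and decodes it when that id comes back on a later request, recovering the original unified id exactly. A hedged round-trip sketch follows; the unified id string is a made-up example, not litellm's real format.

import base64

unified_file_id = "litellm_proxy:unified_id,1234;type,application/pdf"  # made-up example

# Create path: URL-safe base64 with the trailing "=" padding stripped.
b64_id = base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")

# Retrieve path: re-pad, decode, and recover the original unified id exactly.
padded = b64_id + "=" * (-len(b64_id) % 4)
assert base64.urlsafe_b64decode(padded).decode() == unified_file_id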
@@ -304,3 +304,108 @@ def test_create_file_and_call_chat_completion_e2e(
     finally:
         # Stop the mock
         mock.stop()
+
+
+def test_create_file_for_each_model(
+    mocker: MockerFixture, monkeypatch, llm_router: Router
+):
+    """
+    Test that create_file_for_each_model creates files for each target model and returns a unified file ID
+    """
+    import asyncio
+
+    from litellm import CreateFileRequest
+    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+    from litellm.proxy.openai_files_endpoints.files_endpoints import (
+        create_file_for_each_model,
+    )
+    from litellm.proxy.utils import ProxyLogging
+    from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
+
+    # Setup proxy logging
+    proxy_logging_obj = ProxyLogging(
+        user_api_key_cache=DualCache(default_in_memory_ttl=1)
+    )
+    proxy_logging_obj._add_proxy_hooks(llm_router)
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj
+    )
+
+    # Mock user API key dict
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test-key",
+        user_role=LitellmUserRoles.INTERNAL_USER,
+        user_id="test-user",
+        team_id="test-team",
+        team_alias="test-team-alias",
+        parent_otel_span=None,
+    )
+
+    # Create test file request
+    test_file_content = b"test file content"
+    test_file = ("test.txt", test_file_content, "text/plain")
+    _create_file_request = CreateFileRequest(file=test_file, purpose="user_data")
+
+    # Mock the router's acreate_file method
+    mock_file_response = OpenAIFileObject(
+        id="test-file-id",
+        object="file",
+        bytes=123,
+        created_at=1234567890,
+        filename="test.txt",
+        purpose="user_data",
+        status="uploaded",
+    )
+    mock_file_response._hidden_params = {"model_id": "test-model-id"}
+    mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
+
+    # Call the function
+    target_model_names_list = ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"]
+    response = asyncio.run(
+        create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=_create_file_request,
+            target_model_names_list=target_model_names_list,
+            purpose="user_data",
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+        )
+    )
+
+    # Verify the response
+    assert isinstance(response, OpenAIFileObject)
+    assert response.id is not None
+    assert response.purpose == "user_data"
+    assert response.filename == "test.txt"
+
+    # Verify acreate_file was called for each model
+    assert llm_router.acreate_file.call_count == len(target_model_names_list)
+
+    # Get all calls made to acreate_file
+    calls = llm_router.acreate_file.call_args_list
+
+    # Verify Azure call
+    azure_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "azure-gpt-3-5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            azure_call_found = True
+            break
+    assert azure_call_found, "Azure call not found with expected parameters"
+
+    # Verify OpenAI call
+    openai_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "gpt-3.5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            openai_call_found = True
+            break
+    assert openai_call_found, "OpenAI call not found with expected parameters"