Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(managed_files.py): return unified file id as b64 str
Allows retrieving a file by its id to work as expected.
This commit is contained in:
Parent: b460025e18
Commit: 4993d9aa50
3 changed files with 157 additions and 5 deletions
File 1/3 — OpenAI chat transformation (class OpenAIGPTConfig):

```diff
@@ -22,6 +22,8 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
+    ChatCompletionFileObject,
+    ChatCompletionFileObjectFile,
     ChatCompletionImageObject,
     ChatCompletionImageUrlObject,
 )
```
```diff
@@ -188,6 +190,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
             message_content = message.get("content")
             if message_content and isinstance(message_content, list):
                 for content_item in message_content:
+                    litellm_specific_params = {"format"}
                     if content_item.get("type") == "image_url":
                         content_item = cast(ChatCompletionImageObject, content_item)
                         if isinstance(content_item["image_url"], str):
```
```diff
@@ -195,7 +198,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 "url": content_item["image_url"],
                             }
                         elif isinstance(content_item["image_url"], dict):
-                            litellm_specific_params = {"format"}
                             new_image_url_obj = ChatCompletionImageUrlObject(
                                 **{  # type: ignore
                                     k: v
```
```diff
@@ -204,6 +206,17 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
                                 }
                             )
                             content_item["image_url"] = new_image_url_obj
+                    elif content_item.get("type") == "file":
+                        content_item = cast(ChatCompletionFileObject, content_item)
+                        file_obj = content_item["file"]
+                        new_file_obj = ChatCompletionFileObjectFile(
+                            **{  # type: ignore
+                                k: v
+                                for k, v in file_obj.items()
+                                if k not in litellm_specific_params
+                            }
+                        )
+                        content_item["file"] = new_file_obj
         return messages

     def transform_request(
```
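The new `file` branch mirrors the existing `image_url` handling: keys that only litellm understands (currently just `format`) are filtered out before the message reaches the provider, and hoisting `litellm_specific_params` to the top of the loop is what lets both branches share the same filter set. A minimal standalone sketch of that filtering step, using plain dicts and made-up input values (no litellm imports):

```python
# Sketch of the key-filtering step above; the file_obj values are hypothetical.
litellm_specific_params = {"format"}

file_obj = {
    "file_id": "file-abc123",     # hypothetical provider file id
    "format": "application/pdf",  # litellm-only hint that must not be forwarded
}

new_file_obj = {k: v for k, v in file_obj.items() if k not in litellm_specific_params}
print(new_file_obj)  # {'file_id': 'file-abc123'} -- the 'format' key is gone
```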
```diff
@@ -403,4 +416,4 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
                     choices=chunk["choices"],
                 )
         except Exception as e:
             raise e
```
File 2/3 — managed_files.py (_PROXY_LiteLLMManagedFiles hook):

```diff
@@ -1,6 +1,7 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id

+import base64
 import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
```
```diff
@@ -19,6 +20,7 @@ from litellm.types.llms.openai import (
     OpenAIFilesPurpose,
 )
 from litellm.types.utils import SpecialEnums
+from litellm.utils import is_base64_encoded

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
```
```diff
@@ -62,6 +64,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             if messages:
                 file_ids = get_file_ids_from_messages(messages)
                 if file_ids:
+                    file_ids = [
+                        self.convert_b64_uid_to_unified_uid(file_id)
+                        for file_id in file_ids
+                    ]
                     model_file_id_mapping = await self.get_model_file_id_mapping(
                         file_ids, user_api_key_dict.parent_otel_span
                     )
```
```diff
@@ -69,6 +75,28 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):

         return data

+    def _is_base64_encoded_unified_file_id(
+        self, b64_uid: str
+    ) -> Union[str, Literal[False]]:
+        # Add padding back if needed
+        padded = b64_uid + "=" * (-len(b64_uid) % 4)
+        # Decode from base64
+        try:
+            decoded = base64.urlsafe_b64decode(padded).decode()
+            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                return decoded
+            else:
+                return False
+        except Exception:
+            return False
+
+    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
+        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+        if is_base64_unified_file_id:
+            return is_base64_unified_file_id
+        else:
+            return b64_uid
+
     async def get_model_file_id_mapping(
         self, file_ids: List[str], litellm_parent_otel_span: Span
     ) -> dict:
```
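The padding line is the subtle part: base64 input length must be a multiple of 4, and since the encoder (further down in this commit) strips the trailing `=` characters, `-len(b64_uid) % 4` — which evaluates to 0 through 3 in Python — restores exactly the number that were removed. A quick self-contained check, using a made-up unified id (the real prefix comes from `SpecialEnums`):

```python
import base64

original = "litellm_proxy:application/pdf;unified_id,1234"  # hypothetical unified id
encoded = base64.urlsafe_b64encode(original.encode()).decode().rstrip("=")

padded = encoded + "=" * (-len(encoded) % 4)  # re-append the 0-3 stripped '=' chars
assert base64.urlsafe_b64decode(padded).decode() == original
print("round trip ok:", original)
```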
```diff
@@ -107,6 +135,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                 )
                 if cached_values:
                     file_id_mapping[file_id] = cached_values
+
         return file_id_mapping

     @staticmethod
```
```diff
@@ -126,14 +155,19 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_type, str(uuid.uuid4())
         )

+        # Convert to URL-safe base64 and strip padding
+        base64_unified_file_id = (
+            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
+        )
+
         ## CREATE RESPONSE OBJECT
         response = OpenAIFileObject(
-            id=unified_file_id,
+            id=base64_unified_file_id,
             object="file",
             purpose=cast(OpenAIFilesPurpose, purpose),
             created_at=file_objects[0].created_at,
-            bytes=1234,
+            bytes=file_objects[0].bytes,
             filename=str(datetime.now().timestamp()),
             status="uploaded",
         )
```
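Why URL-safe base64 with the padding stripped: the returned `id` is later sent back as a path or query parameter (e.g. on file retrieval), and the standard base64 alphabet's `/` and `+`, plus the `=` padding, all break or need escaping in URLs. A small demonstration of the alphabet difference, with input bytes chosen to trigger it:

```python
import base64

raw = b"???"  # 0x3F bytes happen to hit the '/' code point in base64
print(base64.b64encode(raw).decode())          # 'Pz8/' -- '/' splits a URL path segment
print(base64.urlsafe_b64encode(raw).decode())  # 'Pz8_' -- '-'/'_' alphabet is path-safe
```

Because `_is_base64_encoded_unified_file_id` returns `False` (and `convert_b64_uid_to_unified_uid` falls back to the raw input) for anything that does not decode to the managed-file prefix, ordinary provider file ids still pass through the hook untouched.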
File 3/3 — files-endpoints tests:

```diff
@@ -304,3 +304,108 @@ def test_create_file_and_call_chat_completion_e2e(
     finally:
         # Stop the mock
         mock.stop()
+
+
+def test_create_file_for_each_model(
+    mocker: MockerFixture, monkeypatch, llm_router: Router
+):
+    """
+    Test that create_file_for_each_model creates files for each target model and returns a unified file ID
+    """
+    import asyncio
+
+    from litellm import CreateFileRequest
+    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+    from litellm.proxy.openai_files_endpoints.files_endpoints import (
+        create_file_for_each_model,
+    )
+    from litellm.proxy.utils import ProxyLogging
+    from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
+
+    # Setup proxy logging
+    proxy_logging_obj = ProxyLogging(
+        user_api_key_cache=DualCache(default_in_memory_ttl=1)
+    )
+    proxy_logging_obj._add_proxy_hooks(llm_router)
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj
+    )
+
+    # Mock user API key dict
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test-key",
+        user_role=LitellmUserRoles.INTERNAL_USER,
+        user_id="test-user",
+        team_id="test-team",
+        team_alias="test-team-alias",
+        parent_otel_span=None,
+    )
+
+    # Create test file request
+    test_file_content = b"test file content"
+    test_file = ("test.txt", test_file_content, "text/plain")
+    _create_file_request = CreateFileRequest(file=test_file, purpose="user_data")
+
+    # Mock the router's acreate_file method
+    mock_file_response = OpenAIFileObject(
+        id="test-file-id",
+        object="file",
+        bytes=123,
+        created_at=1234567890,
+        filename="test.txt",
+        purpose="user_data",
+        status="uploaded",
+    )
+    mock_file_response._hidden_params = {"model_id": "test-model-id"}
+    mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
+
+    # Call the function
+    target_model_names_list = ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"]
+    response = asyncio.run(
+        create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=_create_file_request,
+            target_model_names_list=target_model_names_list,
+            purpose="user_data",
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+        )
+    )
+
+    # Verify the response
+    assert isinstance(response, OpenAIFileObject)
+    assert response.id is not None
+    assert response.purpose == "user_data"
+    assert response.filename == "test.txt"
+
+    # Verify acreate_file was called for each model
+    assert llm_router.acreate_file.call_count == len(target_model_names_list)
+
+    # Get all calls made to acreate_file
+    calls = llm_router.acreate_file.call_args_list
+
+    # Verify Azure call
+    azure_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "azure-gpt-3-5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            azure_call_found = True
+            break
+    assert azure_call_found, "Azure call not found with expected parameters"
+
+    # Verify OpenAI call
+    openai_call_found = False
+    for call in calls:
+        kwargs = call.kwargs
+        if (
+            kwargs.get("model") == "gpt-3.5-turbo"
+            and kwargs.get("file") == test_file
+            and kwargs.get("purpose") == "user_data"
+        ):
+            openai_call_found = True
+            break
+    assert openai_call_found, "OpenAI call not found with expected parameters"
```
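Taken together, the change means a client can pass the id from the create response straight back to the retrieval endpoint, which is what the commit title refers to. A hedged end-to-end sketch against a locally running proxy; the base URL, API key, model names, and the `target_model_names` extra-body field are assumptions about a typical deployment, not part of this diff:

```python
from openai import OpenAI

# Assumed local proxy settings; adjust for your deployment.
client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

created = client.files.create(
    file=("test.txt", b"test file content"),
    purpose="user_data",
    # Assumed proxy-specific field for fanning the upload out to several models.
    extra_body={"target_model_names": "gpt-3.5-turbo, azure-gpt-3-5-turbo"},
)
print(created.id)  # URL-safe base64 unified id, no '=' padding

# Retrieval works now that the id survives use as a URL path parameter.
retrieved = client.files.retrieve(file_id=created.id)
print(retrieved.filename)
```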