feat(managed_files.py): support reading / writing files in DB

Krrish Dholakia 2025-04-11 18:23:05 -07:00
parent b59e54d835
commit 49fbe6d3d2
6 changed files with 146 additions and 96 deletions

View file

@@ -16,7 +16,7 @@ from pydantic import (
 from typing_extensions import Required, TypedDict
 from litellm.types.integrations.slack_alerting import AlertType
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject
 from litellm.types.router import RouterErrors, UpdateRouterConfig
 from litellm.types.utils import (
     CallTypes,
@@ -144,6 +144,7 @@ class LitellmTableNames(str, enum.Enum):
     USER_TABLE_NAME = "LiteLLM_UserTable"
     KEY_TABLE_NAME = "LiteLLM_VerificationToken"
     PROXY_MODEL_TABLE_NAME = "LiteLLM_ProxyModelTable"
+    MANAGED_FILE_TABLE_NAME = "LiteLLM_ManagedFileTable"


 class Litellm_EntityType(enum.Enum):
@@ -2762,3 +2763,9 @@ class SpendUpdateQueueItem(TypedDict, total=False):
     entity_type: Litellm_EntityType
     entity_id: str
     response_cost: Optional[float]
+
+
+class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
+    unified_file_id: str
+    file_object: OpenAIFileObject
+    model_mappings: Dict[str, str]
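
Not part of the diff: a minimal sketch of how the new LiteLLM_ManagedFileTable pydantic model could round-trip through the JSON-encoded columns used by the hook changes below. The id, filename, purpose, and mapping values are placeholders, not taken from the commit.

# Illustrative only: serialize for the DB write and rebuild on read.
import json

from litellm.proxy._types import LiteLLM_ManagedFileTable
from litellm.types.llms.openai import OpenAIFileObject

# A unified file object as the proxy would return it (placeholder values).
file_object = OpenAIFileObject(
    id="bGl0ZWxsbV9wcm94eS9leGFtcGxl",  # made-up base64 unified file id
    object="file",
    purpose="batch",
    created_at=1712800000,
    bytes=1024,
    filename="train.jsonl",
    status="uploaded",
)

# Write side: mirrors the json-encoded columns used in store_unified_file_id.
row = {
    "unified_file_id": file_object.id,
    "file_object": file_object.model_dump_json(),
    "model_mappings": json.dumps({"model-uuid-1": "file-abc123"}),
}

# Read side: rebuild the pydantic object from the stored JSON.
record = LiteLLM_ManagedFileTable(
    unified_file_id=row["unified_file_id"],
    file_object=OpenAIFileObject(**json.loads(row["file_object"])),
    model_mappings=json.loads(row["model_mappings"]),
)
assert record.file_object.filename == "train.jsonl"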

View file

@ -1,8 +1,9 @@
# What is this? # What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import asyncio
import base64 import base64
import json
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
@@ -11,7 +12,7 @@ from litellm import Router, verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
-from litellm.proxy._types import CallTypes, UserAPIKeyAuth
+from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionFileObject,
@@ -71,25 +72,43 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         file_id: str,
         file_object: OpenAIFileObject,
         litellm_parent_otel_span: Optional[Span],
+        model_mappings: Dict[str, str],
     ) -> None:
-        key = f"litellm_proxy/{file_id}"
         verbose_logger.info(
             f"Storing LiteLLM Managed File object with id={file_id} in cache"
         )
         await self.internal_usage_cache.async_set_cache(
-            key=key,
+            key=file_id,
             value=file_object,
             litellm_parent_otel_span=litellm_parent_otel_span,
         )
+        await self.prisma_client.db.litellm_managedfiletable.create(
+            data={
+                "unified_file_id": file_id,
+                "file_object": file_object.model_dump_json(),
+                "model_mappings": json.dumps(model_mappings),
+            }
+        )

     async def get_unified_file_id(
         self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
-    ) -> Optional[OpenAIFileObject]:
-        key = f"litellm_proxy/{file_id}"
-        return await self.internal_usage_cache.async_get_cache(
-            key=key,
+    ) -> Optional[LiteLLM_ManagedFileTable]:
+        ## CHECK CACHE
+        result = await self.internal_usage_cache.async_get_cache(
+            key=file_id,
             litellm_parent_otel_span=litellm_parent_otel_span,
         )
+        if result:
+            return result
+
+        ## CHECK DB
+        db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
+            where={"unified_file_id": file_id}
+        )
+        if db_object:
+            return db_object
+        return None

     async def delete_unified_file_id(
         self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
@@ -166,11 +185,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                             file_object_file_field = file_object["file"]
                             file_id = file_object_file_field.get("file_id")
                             if file_id:
-                                file_ids.append(
-                                    _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
-                                        file_id
-                                    )
-                                )
+                                file_ids.append(file_id)
                                 file_object_file_field[
                                     "file_id"
                                 ] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
@@ -244,37 +259,76 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         # Get all cache keys matching the pattern file_id:*
         for file_id in litellm_managed_file_ids:
             # Search for any cache key starting with this file_id
-            cached_values = cast(
-                Dict[str, str],
-                await self.internal_usage_cache.async_get_cache(
-                    key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
-                ),
+            unified_file_object = await self.get_unified_file_id(
+                file_id, litellm_parent_otel_span
             )
-            if cached_values:
-                file_id_mapping[file_id] = cached_values
+            if unified_file_object:
+                file_id_mapping[file_id] = unified_file_object.model_mappings

         return file_id_mapping

-    async def async_post_call_success_hook(
+    async def create_file_for_each_model(
         self,
-        data: Dict,
-        user_api_key_dict: UserAPIKeyAuth,
-        response: LLMResponseTypes,
-    ) -> Any:
-        if isinstance(response, OpenAIFileObject):
-            asyncio.create_task(
-                self.store_unified_file_id(
-                    response.id, response, user_api_key_dict.parent_otel_span
-                )
-            )
-        return None
+        llm_router: Optional[Router],
+        _create_file_request: CreateFileRequest,
+        target_model_names_list: List[str],
+        litellm_parent_otel_span: Span,
+    ) -> List[OpenAIFileObject]:
+        if llm_router is None:
+            raise Exception("LLM Router not initialized. Ensure models added to proxy.")
+        responses = []
+        for model in target_model_names_list:
+            individual_response = await llm_router.acreate_file(
+                model=model, **_create_file_request
+            )
+            responses.append(individual_response)
+        return responses

+    async def acreate_file(
+        self,
+        create_file_request: CreateFileRequest,
+        llm_router: Router,
+        target_model_names_list: List[str],
+        litellm_parent_otel_span: Span,
+    ) -> OpenAIFileObject:
+        responses = await self.create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=create_file_request,
+            target_model_names_list=target_model_names_list,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+        response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
+            file_objects=responses,
+            create_file_request=create_file_request,
+            internal_usage_cache=self.internal_usage_cache,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+
+        ## STORE MODEL MAPPINGS IN DB
+        model_mappings: Dict[str, str] = {}
+        for file_object in responses:
+            model_id = file_object._hidden_params.get("model_id")
+            if model_id is None:
+                verbose_logger.warning(
+                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
+                )
+                continue
+            file_id = file_object.id
+            model_mappings[model_id] = file_id
+
+        await self.store_unified_file_id(
+            file_id=response.id,
+            file_object=response,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            model_mappings=model_mappings,
+        )
+        return response

     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],
         create_file_request: CreateFileRequest,
-        purpose: OpenAIFilesPurpose,
         internal_usage_cache: InternalUsageCache,
         litellm_parent_otel_span: Span,
     ) -> OpenAIFileObject:
@@ -297,30 +351,13 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         response = OpenAIFileObject(
             id=base64_unified_file_id,
             object="file",
-            purpose=cast(OpenAIFilesPurpose, purpose),
+            purpose=create_file_request["purpose"],
             created_at=file_objects[0].created_at,
             bytes=file_objects[0].bytes,
             filename=file_objects[0].filename,
             status="uploaded",
         )

-        ## STORE RESPONSE IN DB + CACHE
-        stored_values: Dict[str, str] = {}
-        for file_object in file_objects:
-            model_id = file_object._hidden_params.get("model_id")
-            if model_id is None:
-                verbose_logger.warning(
-                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
-                )
-                continue
-            file_id = file_object.id
-            stored_values[model_id] = file_id
-        await internal_usage_cache.async_set_cache(
-            key=unified_file_id,
-            value=stored_values,
-            litellm_parent_otel_span=litellm_parent_otel_span,
-        )
-
         return response

     async def afile_retrieve(
@@ -330,7 +367,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_id, litellm_parent_otel_span
         )
         if stored_file_object:
-            return stored_file_object
+            return stored_file_object.file_object
         else:
             raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
@@ -376,12 +413,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         """
         Get the content of a file from first model that has it
         """
-        initial_file_id = file_id
-        unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
         model_file_id_mapping = await self.get_model_file_id_mapping(
-            [unified_file_id], litellm_parent_otel_span
+            [file_id], litellm_parent_otel_span
         )
-        specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
+        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
         if specific_model_file_id_mapping:
             exception_dict = {}
             for model_id, file_id in specific_model_file_id_mapping.items():
@@ -390,9 +425,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                 except Exception as e:
                     exception_dict[model_id] = str(e)
             raise Exception(
-                f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
+                f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
             )
         else:
-            raise Exception(
-                f"LiteLLM Managed File object with id={initial_file_id} not found"
-            )
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

View file

@@ -128,37 +128,6 @@ async def _deprecated_loadbalanced_create_file(
     return response


-async def create_file_for_each_model(
-    llm_router: Optional[Router],
-    _create_file_request: CreateFileRequest,
-    target_model_names_list: List[str],
-    purpose: OpenAIFilesPurpose,
-    proxy_logging_obj: ProxyLogging,
-    user_api_key_dict: UserAPIKeyAuth,
-) -> OpenAIFileObject:
-    if llm_router is None:
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "error": "LLM Router not initialized. Ensure models added to proxy."
-            },
-        )
-    responses = []
-    for model in target_model_names_list:
-        individual_response = await llm_router.acreate_file(
-            model=model, **_create_file_request
-        )
-        responses.append(individual_response)
-
-    response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
-        file_objects=responses,
-        create_file_request=_create_file_request,
-        purpose=purpose,
-        internal_usage_cache=proxy_logging_obj.internal_usage_cache,
-        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-    )
-    return response
-
-
 async def route_create_file(
     llm_router: Optional[Router],
     _create_file_request: CreateFileRequest,
@@ -181,13 +150,29 @@ async def route_create_file(
             _create_file_request=_create_file_request,
         )
     elif target_model_names_list:
-        response = await create_file_for_each_model(
+        managed_files_obj = cast(
+            Optional[_PROXY_LiteLLMManagedFiles],
+            proxy_logging_obj.get_proxy_hook("managed_files"),
+        )
+        if managed_files_obj is None:
+            raise ProxyException(
+                message="Managed files hook not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        if llm_router is None:
+            raise ProxyException(
+                message="LLM Router not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        response = await managed_files_obj.acreate_file(
             llm_router=llm_router,
-            _create_file_request=_create_file_request,
+            create_file_request=_create_file_request,
             target_model_names_list=target_model_names_list,
-            purpose=purpose,
-            proxy_logging_obj=proxy_logging_obj,
-            user_api_key_dict=user_api_key_dict,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
         )
     else:
         # get configs for custom_llm_provider
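
Not part of the diff: a hypothetical end-to-end sketch of exercising this create path through a running proxy with the OpenAI SDK. The base URL, API key, and the target_model_names extra_body field (assumed to feed target_model_names_list above) are illustrative assumptions, not confirmed by this commit.

# Illustrative only: upload a file through the LiteLLM proxy so the managed-files
# hook fans it out to each target model and persists the unified id in the DB.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder creds

created = client.files.create(
    file=open("train.jsonl", "rb"),
    purpose="batch",
    # Assumed request field backing target_model_names_list in route_create_file
    extra_body={"target_model_names": "gpt-4o-mini, azure-gpt-4o-mini"},
)

# The proxy returns a unified (base64-encoded) file id; with this commit the
# per-model file-id mappings are also written to LiteLLM_ManagedFileTable.
print(created.id)

# Retrieval now reads the stored OpenAIFileObject back (cache first, then DB).
retrieved = client.files.retrieve(created.id)
print(retrieved.filename, retrieved.status)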

View file

@@ -354,3 +354,14 @@ enum JobStatus {
   INACTIVE
 }
+
+model LiteLLM_ManagedFileTable {
+  id              String   @id @default(uuid())
+  unified_file_id String   @unique // The base64 encoded unified file ID
+  file_object     Json     // Stores the OpenAIFileObject
+  model_mappings  Json     // Stores the mapping of model_id -> provider_file_id
+  created_at      DateTime @default(now())
+  updated_at      DateTime @updatedAt
+
+  @@index([unified_file_id])
+}

View file

@@ -0,0 +1,3 @@
+These are file-specific types for the proxy.
+
+For Types you expect to be used across the proxy, put them in `litellm/proxy/_types.py`

View file

@@ -354,3 +354,14 @@ enum JobStatus {
   INACTIVE
 }
+
+model LiteLLM_ManagedFileTable {
+  id              String   @id @default(uuid())
+  unified_file_id String   @unique // The base64 encoded unified file ID
+  file_object     Json     // Stores the OpenAIFileObject
+  model_mappings  Json     // Stores the mapping of model_id -> provider_file_id
+  created_at      DateTime @default(now())
+  updated_at      DateTime @updatedAt
+
+  @@index([unified_file_id])
+}