LiteLLM: add managed files DB (#9930)

* fix(openai.py): ensure openai file object shows up on logs

* fix(managed_files.py): return unified file id as b64 str

allows the retrieve-file endpoint to work as expected with the unified id

* fix(managed_files.py): apply decoded file id transformation

* fix: add unit test for file id + decode logic
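
For context, a minimal sketch of the base64 round-trip these commits describe; the id format and helper names below are illustrative assumptions, not LiteLLM's actual implementation:

# Sketch only: unified id format and helper names are assumed for illustration.
import base64

def encode_unified_file_id(unified_file_id: str) -> str:
    # return an opaque, URL-safe b64 string as the public file id
    return base64.urlsafe_b64encode(unified_file_id.encode("utf-8")).decode("utf-8")

def decode_unified_file_id(b64_file_id: str) -> str:
    # decode the client-supplied id back to the internal unified id before cache/DB lookups
    return base64.urlsafe_b64decode(b64_file_id.encode("utf-8")).decode("utf-8")

unified_id = "litellm_proxy:unified_id,example-uuid"   # assumed format, illustration only
assert decode_unified_file_id(encode_unified_file_id(unified_id)) == unified_id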

* fix: initial commit for litellm_proxy support with CRUD endpoints

* fix(managed_files.py): support retrieve file operation

* fix(managed_files.py): support the DELETE endpoint for files

* fix(managed_files.py): retrieve file content support

supports OpenAI's retrieve file content API
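
Taken together, these commits expose OpenAI-compatible file CRUD through the proxy. A hedged usage sketch against a locally running proxy follows; the base URL, key, filename, and routing comments are placeholders and assumptions rather than documented behavior:

# Sketch only: assumes a LiteLLM proxy on localhost:4000 and a valid virtual key.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")  # placeholder proxy URL/key

created = client.files.create(file=open("batch.jsonl", "rb"), purpose="batch")
print(created.id)                                # unified (b64) file id returned by the proxy

retrieved = client.files.retrieve(created.id)    # served from the managed-files cache/DB
content = client.files.content(created.id)       # fetched from a backing deployment that has the file
client.files.delete(created.id)                  # removes the unified id from cache and DB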

* fix: fix linting error

* test: update tests

* fix: fix linting error

* feat(managed_files.py): support reading / writing files in DB

* feat(managed_files.py): support deleting file from DB on delete
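
These two DB commits persist the managed file object write-through (cache plus Prisma table) and read it back cache-first with a DB fallback, as the diff below shows. A condensed sketch of that pattern, with the cache and Prisma client replaced by untyped placeholders:

# Sketch only: `cache` and `db` stand in for LiteLLM's cache wrapper and Prisma client.
import json
from typing import Any, Dict, Optional

async def store_managed_file(cache: Any, db: Any, file_id: str, row: Dict[str, Any]) -> None:
    await cache.async_set_cache(key=file_id, value=row)          # write-through: cache first
    await db.litellm_managedfiletable.create(
        data={
            "unified_file_id": file_id,
            "file_object": json.dumps(row["file_object"]),
            "model_mappings": json.dumps(row["model_mappings"]),
        }
    )

async def get_managed_file(cache: Any, db: Any, file_id: str) -> Optional[Dict[str, Any]]:
    cached = await cache.async_get_cache(key=file_id)            # read path: cache hit wins
    if cached:
        return cached
    return await db.litellm_managedfiletable.find_first(         # otherwise fall back to the DB
        where={"unified_file_id": file_id}
    )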

* test: update testing

* fix(spend_tracking_utils.py): ensure each file create request is logged correctly

* fix(managed_files.py): fix storing / returning managed file object from cache

* fix(files/main.py): pass litellm params to azure route

* test: fix test

* build: add new prisma migration

* build: bump requirements

* test: add more testing

* refactor: cleanup post merge w/ main

* fix: fix code QA errors
Krish Dholakia 2025-04-12 08:24:46 -07:00 committed by GitHub
parent 93037ea4d3
commit 421e0a3004
19 changed files with 286 additions and 158 deletions

@@ -1,8 +1,8 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import asyncio
import base64
import json
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
@@ -11,7 +11,7 @@ from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
@@ -19,7 +19,7 @@ from litellm.types.llms.openai import (
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import LLMResponseTypes, SpecialEnums
from litellm.types.utils import SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -71,44 +71,73 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
file_id: str,
file_object: OpenAIFileObject,
litellm_parent_otel_span: Optional[Span],
model_mappings: Dict[str, str],
) -> None:
key = f"litellm_proxy/{file_id}"
verbose_logger.info(
f"Storing LiteLLM Managed File object with id={file_id} in cache"
)
litellm_managed_file_object = LiteLLM_ManagedFileTable(
unified_file_id=file_id,
file_object=file_object,
model_mappings=model_mappings,
)
await self.internal_usage_cache.async_set_cache(
key=key,
value=file_object,
key=file_id,
value=litellm_managed_file_object.model_dump(),
litellm_parent_otel_span=litellm_parent_otel_span,
)
await self.prisma_client.db.litellm_managedfiletable.create(
data={
"unified_file_id": file_id,
"file_object": file_object.model_dump_json(),
"model_mappings": json.dumps(model_mappings),
}
)
async def get_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> Optional[OpenAIFileObject]:
key = f"litellm_proxy/{file_id}"
return await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
) -> Optional[LiteLLM_ManagedFileTable]:
## CHECK CACHE
result = cast(
Optional[dict],
await self.internal_usage_cache.async_get_cache(
key=file_id,
litellm_parent_otel_span=litellm_parent_otel_span,
),
)
if result:
return LiteLLM_ManagedFileTable(**result)
## CHECK DB
db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if db_object:
return LiteLLM_ManagedFileTable(**db_object.model_dump())
return None
async def delete_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> OpenAIFileObject:
key = f"litellm_proxy/{file_id}"
## get old value
old_value = await self.internal_usage_cache.async_get_cache(
key=key,
litellm_parent_otel_span=litellm_parent_otel_span,
initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if old_value is None or not isinstance(old_value, OpenAIFileObject):
if initial_value is None:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
## delete old value
await self.internal_usage_cache.async_set_cache(
key=key,
key=file_id,
value=None,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return old_value
await self.prisma_client.db.litellm_managedfiletable.delete(
where={"unified_file_id": file_id}
)
return initial_value.file_object
async def async_pre_call_hook(
self,
@@ -133,11 +162,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = (
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
messages
)
)
file_ids = self.get_file_ids_from_messages(messages)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
@@ -147,9 +172,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
return data
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
self, messages: List[AllMessageValues]
) -> List[str]:
def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
@@ -166,16 +189,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(
_PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
file_id
)
)
file_object_file_field[
"file_id"
] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
file_id
)
file_ids.append(file_id)
return file_ids
@staticmethod
@@ -236,45 +250,82 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(is_base64_unified_file_id)
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
# Get all cache keys matching the pattern file_id:*
for file_id in litellm_managed_file_ids:
# Search for any cache key starting with this file_id
cached_values = cast(
Dict[str, str],
await self.internal_usage_cache.async_get_cache(
key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
),
unified_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if cached_values:
file_id_mapping[file_id] = cached_values
if unified_file_object:
file_id_mapping[file_id] = unified_file_object.model_mappings
return file_id_mapping
async def async_post_call_success_hook(
async def create_file_for_each_model(
self,
data: Dict,
user_api_key_dict: UserAPIKeyAuth,
response: LLMResponseTypes,
) -> Any:
if isinstance(response, OpenAIFileObject):
asyncio.create_task(
self.store_unified_file_id(
response.id, response, user_api_key_dict.parent_otel_span
)
llm_router: Optional[Router],
_create_file_request: CreateFileRequest,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
) -> List[OpenAIFileObject]:
if llm_router is None:
raise Exception("LLM Router not initialized. Ensure models added to proxy.")
responses = []
for model in target_model_names_list:
individual_response = await llm_router.acreate_file(
model=model, **_create_file_request
)
responses.append(individual_response)
return None
return responses
async def acreate_file(
self,
create_file_request: CreateFileRequest,
llm_router: Router,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
responses = await self.create_file_for_each_model(
llm_router=llm_router,
_create_file_request=create_file_request,
target_model_names_list=target_model_names_list,
litellm_parent_otel_span=litellm_parent_otel_span,
)
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
file_objects=responses,
create_file_request=create_file_request,
internal_usage_cache=self.internal_usage_cache,
litellm_parent_otel_span=litellm_parent_otel_span,
)
## STORE MODEL MAPPINGS IN DB
model_mappings: Dict[str, str] = {}
for file_object in responses:
model_id = file_object._hidden_params.get("model_id")
if model_id is None:
verbose_logger.warning(
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
)
continue
file_id = file_object.id
model_mappings[model_id] = file_id
await self.store_unified_file_id(
file_id=response.id,
file_object=response,
litellm_parent_otel_span=litellm_parent_otel_span,
model_mappings=model_mappings,
)
return response
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
create_file_request: CreateFileRequest,
purpose: OpenAIFilesPurpose,
internal_usage_cache: InternalUsageCache,
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
@@ -297,30 +348,13 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
response = OpenAIFileObject(
id=base64_unified_file_id,
object="file",
purpose=cast(OpenAIFilesPurpose, purpose),
purpose=create_file_request["purpose"],
created_at=file_objects[0].created_at,
bytes=file_objects[0].bytes,
filename=file_objects[0].filename,
status="uploaded",
)
## STORE RESPONSE IN DB + CACHE
stored_values: Dict[str, str] = {}
for file_object in file_objects:
model_id = file_object._hidden_params.get("model_id")
if model_id is None:
verbose_logger.warning(
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
)
continue
file_id = file_object.id
stored_values[model_id] = file_id
await internal_usage_cache.async_set_cache(
key=unified_file_id,
value=stored_values,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return response
async def afile_retrieve(
@@ -330,7 +364,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
return stored_file_object.file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
@@ -376,12 +410,11 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
"""
Get the content of a file from first model that has it
"""
initial_file_id = file_id
unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[unified_file_id], litellm_parent_otel_span
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
exception_dict = {}
for model_id, file_id in specific_model_file_id_mapping.items():
@@ -390,9 +423,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
except Exception as e:
exception_dict[model_id] = str(e)
raise Exception(
f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
)
else:
raise Exception(
f"LiteLLM Managed File object with id={initial_file_id} not found"
)
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")