mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(managed_files.py): support reading / writing files in DB
parent b59e54d835
commit 49fbe6d3d2
6 changed files with 146 additions and 96 deletions
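The change below moves managed-file bookkeeping from cache-only storage to cache plus database: `store_unified_file_id` now also persists a row to the new `LiteLLM_ManagedFileTable`, and `get_unified_file_id` falls back to the DB on a cache miss. Each row keeps the unified `OpenAIFileObject` together with a `model_mappings` dict of model_id -> provider file id. As a rough sketch of the stored shape (field names come from the diff below; every value is invented for illustration):

```python
# Illustrative only: the shape of a LiteLLM_ManagedFileTable row after
# acreate_file() uploads a file to each target model. All values are made up.
example_row = {
    "unified_file_id": "bGl0ZWxsbV9wcm94eS1leGFtcGxl",  # base64-encoded unified file id
    "file_object": {  # serialized OpenAIFileObject returned to the caller
        "id": "bGl0ZWxsbV9wcm94eS1leGFtcGxl",
        "object": "file",
        "purpose": "batch",
        "status": "uploaded",
    },
    "model_mappings": {  # model_id -> provider-specific file id
        "model-id-1": "file-abc123",
        "model-id-2": "file-def456",
    },
}
```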
@@ -16,7 +16,7 @@ from pydantic import (
 from typing_extensions import Required, TypedDict

 from litellm.types.integrations.slack_alerting import AlertType
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import AllMessageValues, OpenAIFileObject
 from litellm.types.router import RouterErrors, UpdateRouterConfig
 from litellm.types.utils import (
     CallTypes,

@@ -144,6 +144,7 @@ class LitellmTableNames(str, enum.Enum):
     USER_TABLE_NAME = "LiteLLM_UserTable"
     KEY_TABLE_NAME = "LiteLLM_VerificationToken"
     PROXY_MODEL_TABLE_NAME = "LiteLLM_ProxyModelTable"
+    MANAGED_FILE_TABLE_NAME = "LiteLLM_ManagedFileTable"


 class Litellm_EntityType(enum.Enum):

@@ -2762,3 +2763,9 @@ class SpendUpdateQueueItem(TypedDict, total=False):
     entity_type: Litellm_EntityType
     entity_id: str
     response_cost: Optional[float]
+
+
+class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
+    unified_file_id: str
+    file_object: OpenAIFileObject
+    model_mappings: Dict[str, str]

@@ -1,8 +1,9 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id

-import asyncio
 import base64
+import json
 import uuid
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

@@ -11,7 +12,7 @@ from litellm import Router, verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
-from litellm.proxy._types import CallTypes, UserAPIKeyAuth
+from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionFileObject,

@@ -71,25 +72,43 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         file_id: str,
         file_object: OpenAIFileObject,
         litellm_parent_otel_span: Optional[Span],
+        model_mappings: Dict[str, str],
     ) -> None:
-        key = f"litellm_proxy/{file_id}"
         verbose_logger.info(
             f"Storing LiteLLM Managed File object with id={file_id} in cache"
         )
         await self.internal_usage_cache.async_set_cache(
-            key=key,
+            key=file_id,
             value=file_object,
             litellm_parent_otel_span=litellm_parent_otel_span,
         )

+        await self.prisma_client.db.litellm_managedfiletable.create(
+            data={
+                "unified_file_id": file_id,
+                "file_object": file_object.model_dump_json(),
+                "model_mappings": json.dumps(model_mappings),
+            }
+        )
+
     async def get_unified_file_id(
         self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
-    ) -> Optional[OpenAIFileObject]:
-        key = f"litellm_proxy/{file_id}"
-        return await self.internal_usage_cache.async_get_cache(
-            key=key,
+    ) -> Optional[LiteLLM_ManagedFileTable]:
+        ## CHECK CACHE
+        result = await self.internal_usage_cache.async_get_cache(
+            key=file_id,
             litellm_parent_otel_span=litellm_parent_otel_span,
         )

+        if result:
+            return result
+
+        ## CHECK DB
+        db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
+            where={"unified_file_id": file_id}
+        )
+        if db_object:
+            return db_object
+        return None
+
     async def delete_unified_file_id(
         self, file_id: str, litellm_parent_otel_span: Optional[Span] = None

@@ -166,11 +185,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                             file_object_file_field = file_object["file"]
                             file_id = file_object_file_field.get("file_id")
                             if file_id:
-                                file_ids.append(
-                                    _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
-                                        file_id
-                                    )
-                                )
+                                file_ids.append(file_id)
                                 file_object_file_field[
                                     "file_id"
                                 ] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(

@@ -244,37 +259,76 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         # Get all cache keys matching the pattern file_id:*
         for file_id in litellm_managed_file_ids:
             # Search for any cache key starting with this file_id
-            cached_values = cast(
-                Dict[str, str],
-                await self.internal_usage_cache.async_get_cache(
-                    key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
-                ),
+            unified_file_object = await self.get_unified_file_id(
+                file_id, litellm_parent_otel_span
             )
-            if cached_values:
-                file_id_mapping[file_id] = cached_values
+            if unified_file_object:
+                file_id_mapping[file_id] = unified_file_object.model_mappings

         return file_id_mapping

-    async def async_post_call_success_hook(
+    async def create_file_for_each_model(
         self,
-        data: Dict,
-        user_api_key_dict: UserAPIKeyAuth,
-        response: LLMResponseTypes,
-    ) -> Any:
-        if isinstance(response, OpenAIFileObject):
-            asyncio.create_task(
-                self.store_unified_file_id(
-                    response.id, response, user_api_key_dict.parent_otel_span
-                )
+        llm_router: Optional[Router],
+        _create_file_request: CreateFileRequest,
+        target_model_names_list: List[str],
+        litellm_parent_otel_span: Span,
+    ) -> List[OpenAIFileObject]:
+        if llm_router is None:
+            raise Exception("LLM Router not initialized. Ensure models added to proxy.")
+        responses = []
+        for model in target_model_names_list:
+            individual_response = await llm_router.acreate_file(
+                model=model, **_create_file_request
             )
-        return None
+            responses.append(individual_response)
+
+        return responses

+    async def acreate_file(
+        self,
+        create_file_request: CreateFileRequest,
+        llm_router: Router,
+        target_model_names_list: List[str],
+        litellm_parent_otel_span: Span,
+    ) -> OpenAIFileObject:
+        responses = await self.create_file_for_each_model(
+            llm_router=llm_router,
+            _create_file_request=create_file_request,
+            target_model_names_list=target_model_names_list,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+        response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
+            file_objects=responses,
+            create_file_request=create_file_request,
+            internal_usage_cache=self.internal_usage_cache,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+
+        ## STORE MODEL MAPPINGS IN DB
+        model_mappings: Dict[str, str] = {}
+        for file_object in responses:
+            model_id = file_object._hidden_params.get("model_id")
+            if model_id is None:
+                verbose_logger.warning(
+                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
+                )
+                continue
+            file_id = file_object.id
+            model_mappings[model_id] = file_id
+
+        await self.store_unified_file_id(
+            file_id=response.id,
+            file_object=response,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            model_mappings=model_mappings,
+        )
+        return response

     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],
         create_file_request: CreateFileRequest,
-        purpose: OpenAIFilesPurpose,
         internal_usage_cache: InternalUsageCache,
         litellm_parent_otel_span: Span,
     ) -> OpenAIFileObject:

@@ -297,30 +351,13 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         response = OpenAIFileObject(
             id=base64_unified_file_id,
             object="file",
-            purpose=cast(OpenAIFilesPurpose, purpose),
+            purpose=create_file_request["purpose"],
             created_at=file_objects[0].created_at,
             bytes=file_objects[0].bytes,
             filename=file_objects[0].filename,
             status="uploaded",
         )

-        ## STORE RESPONSE IN DB + CACHE
-        stored_values: Dict[str, str] = {}
-        for file_object in file_objects:
-            model_id = file_object._hidden_params.get("model_id")
-            if model_id is None:
-                verbose_logger.warning(
-                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
-                )
-                continue
-            file_id = file_object.id
-            stored_values[model_id] = file_id
-        await internal_usage_cache.async_set_cache(
-            key=unified_file_id,
-            value=stored_values,
-            litellm_parent_otel_span=litellm_parent_otel_span,
-        )
-
         return response

     async def afile_retrieve(

@@ -330,7 +367,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
             file_id, litellm_parent_otel_span
         )
         if stored_file_object:
-            return stored_file_object
+            return stored_file_object.file_object
         else:
             raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

@@ -376,12 +413,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         """
         Get the content of a file from first model that has it
         """
-        initial_file_id = file_id
-        unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
         model_file_id_mapping = await self.get_model_file_id_mapping(
-            [unified_file_id], litellm_parent_otel_span
+            [file_id], litellm_parent_otel_span
         )
-        specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
+        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
         if specific_model_file_id_mapping:
             exception_dict = {}
             for model_id, file_id in specific_model_file_id_mapping.items():

@@ -390,9 +425,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
                 except Exception as e:
                     exception_dict[model_id] = str(e)
             raise Exception(
-                f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
+                f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
             )
         else:
-            raise Exception(
-                f"LiteLLM Managed File object with id={initial_file_id} not found"
-            )
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

@@ -128,37 +128,6 @@ async def _deprecated_loadbalanced_create_file(
     return response


-async def create_file_for_each_model(
-    llm_router: Optional[Router],
-    _create_file_request: CreateFileRequest,
-    target_model_names_list: List[str],
-    purpose: OpenAIFilesPurpose,
-    proxy_logging_obj: ProxyLogging,
-    user_api_key_dict: UserAPIKeyAuth,
-) -> OpenAIFileObject:
-    if llm_router is None:
-        raise HTTPException(
-            status_code=500,
-            detail={
-                "error": "LLM Router not initialized. Ensure models added to proxy."
-            },
-        )
-    responses = []
-    for model in target_model_names_list:
-        individual_response = await llm_router.acreate_file(
-            model=model, **_create_file_request
-        )
-        responses.append(individual_response)
-    response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
-        file_objects=responses,
-        create_file_request=_create_file_request,
-        purpose=purpose,
-        internal_usage_cache=proxy_logging_obj.internal_usage_cache,
-        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-    )
-    return response
-
-
 async def route_create_file(
     llm_router: Optional[Router],
     _create_file_request: CreateFileRequest,

@@ -181,13 +150,29 @@ async def route_create_file(
             _create_file_request=_create_file_request,
         )
     elif target_model_names_list:
-        response = await create_file_for_each_model(
+        managed_files_obj = cast(
+            Optional[_PROXY_LiteLLMManagedFiles],
+            proxy_logging_obj.get_proxy_hook("managed_files"),
+        )
+        if managed_files_obj is None:
+            raise ProxyException(
+                message="Managed files hook not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        if llm_router is None:
+            raise ProxyException(
+                message="LLM Router not found",
+                type="None",
+                param="None",
+                code=500,
+            )
+        response = await managed_files_obj.acreate_file(
             llm_router=llm_router,
-            _create_file_request=_create_file_request,
+            create_file_request=_create_file_request,
             target_model_names_list=target_model_names_list,
-            purpose=purpose,
-            proxy_logging_obj=proxy_logging_obj,
-            user_api_key_dict=user_api_key_dict,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
         )
     else:
         # get configs for custom_llm_provider

@@ -354,3 +354,14 @@ enum JobStatus {
   INACTIVE
 }

+model LiteLLM_ManagedFileTable {
+  id String @id @default(uuid())
+  unified_file_id String @unique // The base64 encoded unified file ID
+  file_object Json // Stores the OpenAIFileObject
+  model_mappings Json // Stores the mapping of model_id -> provider_file_id
+  created_at DateTime @default(now())
+  updated_at DateTime @updatedAt
+
+  @@index([unified_file_id])
+}
+
litellm/types/proxy/README.md (new file, 3 additions)
@@ -0,0 +1,3 @@
+These are file-specific types for the proxy.
+
+For Types you expect to be used across the proxy, put them in `litellm/proxy/_types.py`

@@ -354,3 +354,14 @@ enum JobStatus {
   INACTIVE
 }

+model LiteLLM_ManagedFileTable {
+  id String @id @default(uuid())
+  unified_file_id String @unique // The base64 encoded unified file ID
+  file_object Json // Stores the OpenAIFileObject
+  model_mappings Json // Stores the mapping of model_id -> provider_file_id
+  created_at DateTime @default(now())
+  updated_at DateTime @updatedAt
+
+  @@index([unified_file_id])
+}
+
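The `file_object` and `model_mappings` columns in the `LiteLLM_ManagedFileTable` model above are written as JSON strings by the hook (`model_dump_json()` / `json.dumps()`). Below is a minimal sketch of reading a row back and rehydrating those columns, assuming the async Prisma client the proxy already exposes; `load_managed_file` is a hypothetical helper, not part of this commit, and whether Json columns come back as `str` or `dict` depends on the Prisma client, so both cases are handled:

```python
# Hypothetical helper, not part of this commit: read a managed-file row back
# and rehydrate its JSON columns. Table/field names mirror the schema above.
import json
from typing import Dict, Optional, Tuple

from litellm.types.llms.openai import OpenAIFileObject


async def load_managed_file(
    prisma_client, unified_file_id: str
) -> Optional[Tuple[OpenAIFileObject, Dict[str, str]]]:
    row = await prisma_client.db.litellm_managedfiletable.find_first(
        where={"unified_file_id": unified_file_id}
    )
    if row is None:
        return None
    # file_object was stored via model_dump_json(), model_mappings via json.dumps()
    raw_file_object = row.file_object
    if isinstance(raw_file_object, str):
        raw_file_object = json.loads(raw_file_object)
    raw_mappings = row.model_mappings
    if isinstance(raw_mappings, str):
        raw_mappings = json.loads(raw_mappings)
    return OpenAIFileObject(**raw_file_object), dict(raw_mappings)
```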