Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(managed_files.py): support retrieve file operation
This commit is contained in:
  parent: 7fff83e441
  commit: cbcf028da5

3 changed files with 92 additions and 23 deletions
@@ -40,7 +40,9 @@ else:
 class BaseFileEndpoints(ABC):
     @abstractmethod
     async def afile_retrieve(
-        self, custom_llm_provider: str, file_id: str, **data: dict
+        self,
+        file_id: str,
+        litellm_parent_otel_span: Optional[Span],
     ) -> OpenAIFileObject:
         pass

@@ -66,12 +68,29 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
         self.prisma_client = prisma_client

     async def store_unified_file_id(
-        self, file_id: str, file_object: OpenAIFileObject
+        self,
+        file_id: str,
+        file_object: OpenAIFileObject,
+        litellm_parent_otel_span: Optional[Span],
     ) -> None:
-        pass
+        key = f"litellm_proxy/{file_id}"
+        verbose_logger.info(
+            f"Storing LiteLLM Managed File object with id={file_id} in cache"
+        )
+        await self.internal_usage_cache.async_set_cache(
+            key=key,
+            value=file_object,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )

-    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
-        return None
+    async def get_unified_file_id(
+        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
+    ) -> Optional[OpenAIFileObject]:
+        key = f"litellm_proxy/{file_id}"
+        return await self.internal_usage_cache.async_get_cache(
+            key=key,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )

     async def async_pre_call_hook(
         self,
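A note on the new store/get pair above: both operations go through the proxy's internal usage cache under a `litellm_proxy/{file_id}` key, and both forward `litellm_parent_otel_span` so the cache calls stay attached to the request's trace. A minimal sketch of that round-trip, using a hypothetical in-memory stand-in for `internal_usage_cache` (not the proxy's real cache class):

import asyncio
from typing import Any, Dict, Optional

class CacheStub:
    """Hypothetical stand-in for the proxy's internal_usage_cache."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> None:
        self._data[key] = value

    async def async_get_cache(self, key: str, **kwargs: Any) -> Optional[Any]:
        return self._data.get(key)

async def main() -> None:
    cache = CacheStub()
    file_id = "file-abc123"
    key = f"litellm_proxy/{file_id}"  # same namespacing as the diff
    await cache.async_set_cache(key=key, value={"id": file_id})
    assert await cache.async_get_cache(key=key) == {"id": file_id}

asyncio.run(main())

The `litellm_proxy/` prefix namespaces managed-file entries so they cannot collide with other keys in the shared cache.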
@@ -130,23 +149,29 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
             file_id = file_object_file_field.get("file_id")
             if file_id:
                 file_ids.append(
-                    self._convert_b64_uid_to_unified_uid(file_id)
+                    _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
+                        file_id
+                    )
                 )
                 file_object_file_field[
                     "file_id"
-                ] = self._convert_b64_uid_to_unified_uid(file_id)
+                ] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
+                    file_id
+                )
         return file_ids

-    def _convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
-        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
+    @staticmethod
+    def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
+        is_base64_unified_file_id = (
+            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
+        )
         if is_base64_unified_file_id:
             return is_base64_unified_file_id
         else:
             return b64_uid

-    def _is_base64_encoded_unified_file_id(
-        self, b64_uid: str
-    ) -> Union[str, Literal[False]]:
+    @staticmethod
+    def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
         # Add padding back if needed
         padded = b64_uid + "=" * (-len(b64_uid) % 4)
         # Decode from base64
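The padding line above restores the `=` characters that URL-safe base64 encoders often strip: valid base64 input has a length that is a multiple of 4, and `-len(s) % 4` is exactly the number of missing pad characters. A hedged sketch of the full check (the hunk truncates before the real decode logic, and the prefix test on the decoded value is an illustrative assumption):

import base64
from typing import Literal, Union

def is_base64_encoded_unified_file_id_sketch(b64_uid: str) -> Union[str, Literal[False]]:
    # Add padding back if needed: base64 length must be a multiple of 4.
    padded = b64_uid + "=" * (-len(b64_uid) % 4)
    try:
        decoded = base64.urlsafe_b64decode(padded).decode("utf-8")
    except Exception:
        return False
    # Assumption for illustration: a decoded unified file ID is
    # recognizable by a known prefix. The actual check may differ.
    return decoded if decoded.startswith("litellm_proxy") else False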
@@ -212,14 +237,18 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):

         return file_id_mapping

-    async def post_call_success_hook(
+    async def async_post_call_success_hook(
         self,
         data: Dict,
-        response: LLMResponseTypes,
         user_api_key_dict: UserAPIKeyAuth,
+        response: LLMResponseTypes,
     ) -> Any:
         if isinstance(response, OpenAIFileObject):
-            asyncio.create_task(self.store_unified_file_id(response.id, response))
+            asyncio.create_task(
+                self.store_unified_file_id(
+                    response.id, response, user_api_key_dict.parent_otel_span
+                )
+            )

         return None
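Worth noting: the hook hands the cache write to `asyncio.create_task` instead of awaiting it, so storing the file object never blocks returning the response. A minimal, self-contained illustration of this fire-and-forget pattern (names are placeholders):

import asyncio

async def store(file_id: str) -> None:
    await asyncio.sleep(0)  # stand-in for the real cache write
    print(f"stored {file_id}")

async def success_hook(file_id: str) -> None:
    # Schedule the write and return immediately, as the hook does.
    asyncio.create_task(store(file_id))

async def main() -> None:
    await success_hook("file-abc123")
    # Give the scheduled task a moment to finish before the loop closes;
    # the real hook relies on the long-lived server event loop instead.
    await asyncio.sleep(0.1)

asyncio.run(main())

One known caveat of this pattern: the event loop holds only a weak reference to tasks, so long-running services often keep the task reference to protect it from garbage collection.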
@@ -277,13 +306,15 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
         return response

     async def afile_retrieve(
-        self, custom_llm_provider: str, file_id: str, **data: Dict
+        self, file_id: str, litellm_parent_otel_span: Optional[Span]
     ) -> OpenAIFileObject:
-        stored_file_object = await self.get_unified_file_id(file_id)
+        stored_file_object = await self.get_unified_file_id(
+            file_id, litellm_parent_otel_span
+        )
         if stored_file_object:
             return stored_file_object
         else:
-            raise Exception(f"File object with id={file_id} not found")
+            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

     async def afile_list(
         self, custom_llm_provider: str, **data: Dict
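Taken together, the hunks complete the retrieve path: `async_post_call_success_hook` caches the `OpenAIFileObject` when the file is created, and `afile_retrieve` reads it back by ID or raises. A compact, self-contained sketch of that contract (`CacheStub` repeats the hypothetical cache from the earlier sketch):

import asyncio
from typing import Any, Dict, Optional

class CacheStub:
    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> None:
        self._data[key] = value

    async def async_get_cache(self, key: str, **kwargs: Any) -> Optional[Any]:
        return self._data.get(key)

async def afile_retrieve_sketch(cache: CacheStub, file_id: str) -> Any:
    stored = await cache.async_get_cache(key=f"litellm_proxy/{file_id}")
    if stored:
        return stored
    raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

async def main() -> None:
    cache = CacheStub()
    await cache.async_set_cache(
        key="litellm_proxy/file-abc123", value={"id": "file-abc123"}
    )
    print(await afile_retrieve_sketch(cache, "file-abc123"))  # cache hit
    # afile_retrieve_sketch(cache, "missing") would raise the not-found error

asyncio.run(main())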