Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
fix: initial commit for litellm_proxy support with CRUD Endpoints
This commit is contained in:
parent 6b04b48b17 · commit 7fff83e441

6 changed files with 197 additions and 9 deletions
@@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
 from litellm.types.utils import (
     AdapterCompletionStreamWrapper,
-    EmbeddingResponse,
-    ImageResponse,
+    LLMResponseTypes,
     ModelResponse,
     ModelResponseStream,
     StandardCallbackDynamicParams,

@@ -223,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
     ) -> Any:
         pass
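The hunks above widen the response type on CustomLogger's post-call success hook from a hand-written Union to the new LLMResponseTypes alias, so callbacks can now receive file objects as well as chat, embedding, and image responses. A minimal sketch of a subclass using the new signature (the class name and the print statement are illustrative, not part of this commit):

from typing import Any

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.utils import LLMResponseTypes


class FileAwareLogger(CustomLogger):  # hypothetical subclass
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: LLMResponseTypes,
    ) -> Any:
        # LLMResponseTypes is a plain Union, so isinstance narrows it.
        if isinstance(response, OpenAIFileObject):
            print(f"managed file created: {response.id}")
        return response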
@@ -1,10 +1,12 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
+import asyncio
 import base64
 import uuid
+from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
 
 from litellm import verbose_logger
 from litellm.caching.caching import DualCache

@@ -18,24 +20,58 @@ from litellm.types.llms.openai import (
     OpenAIFileObject,
     OpenAIFilesPurpose,
 )
-from litellm.types.utils import SpecialEnums
+from litellm.types.utils import LLMResponseTypes, SpecialEnums
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+    from litellm.proxy.utils import PrismaClient as _PrismaClient
+
     Span = Union[_Span, Any]
+    InternalUsageCache = _InternalUsageCache
+    PrismaClient = _PrismaClient
 else:
     Span = Any
+    InternalUsageCache = Any
+    PrismaClient = Any
 
 
-class _PROXY_LiteLLMManagedFiles(CustomLogger):
+class BaseFileEndpoints(ABC):
+    @abstractmethod
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+    @abstractmethod
+    async def afile_list(
+        self, custom_llm_provider: str, **data: dict
+    ) -> List[OpenAIFileObject]:
+        pass
+
+    @abstractmethod
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+
+class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
     # Class variables or attributes
-    def __init__(self, internal_usage_cache: InternalUsageCache):
+    def __init__(
+        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
+    ):
         self.internal_usage_cache = internal_usage_cache
+        self.prisma_client = prisma_client
+
+    async def store_unified_file_id(
+        self, file_id: str, file_object: OpenAIFileObject
+    ) -> None:
+        pass
+
+    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
+        return None
 
     async def async_pre_call_hook(
         self,

@@ -176,6 +212,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
 
         return file_id_mapping
 
+    async def post_call_success_hook(
+        self,
+        data: Dict,
+        response: LLMResponseTypes,
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> Any:
+        if isinstance(response, OpenAIFileObject):
+            asyncio.create_task(self.store_unified_file_id(response.id, response))
+
+        return None
+
     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],

@@ -228,3 +275,22 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         )
 
         return response
+
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        stored_file_object = await self.get_unified_file_id(file_id)
+        if stored_file_object:
+            return stored_file_object
+        else:
+            raise Exception(f"File object with id={file_id} not found")
+
+    async def afile_list(
+        self, custom_llm_provider: str, **data: Dict
+    ) -> List[OpenAIFileObject]:
+        return []
+
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        raise NotImplementedError("afile_delete not implemented")
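BaseFileEndpoints pins down the CRUD surface a litellm_proxy file backend must supply: retrieve, list, and delete, each keyed by custom_llm_provider. A minimal in-memory sketch of that contract (the dict-backed store is an assumption for illustration; the in-tree hook resolves files through store_unified_file_id / get_unified_file_id instead):

from typing import Dict, List

from litellm.types.llms.openai import OpenAIFileObject


class InMemoryFileEndpoints:
    """Illustrative only; in-tree this would subclass BaseFileEndpoints."""

    def __init__(self) -> None:
        self._files: Dict[str, OpenAIFileObject] = {}

    async def afile_retrieve(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        stored = self._files.get(file_id)
        if stored is None:
            raise Exception(f"File object with id={file_id} not found")
        return stored

    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        return list(self._files.values())

    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        return self._files.pop(file_id)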
@@ -311,6 +311,13 @@ async def create_file(
             )
         )
 
+        ## POST CALL HOOKS ###
+        _response = await proxy_logging_obj.post_call_success_hook(
+            data=data, user_api_key_dict=user_api_key_dict, response=response
+        )
+        if _response is not None and isinstance(_response, OpenAIFileObject):
+            response = _response
+
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""

@@ -539,6 +546,8 @@ async def get_file(
             version=version,
+            proxy_config=proxy_config,
         )
 
+        ## check if file_id is a litellm managed file
         response = await litellm.afile_retrieve(
             custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
         )
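With the hook wired in, create_file now runs the post-call success hook (which may substitute a unified file object into the response) and get_file first checks whether the id is a litellm managed file. From a client, the routes keep the OpenAI Files API shape; a minimal sketch against a locally running proxy (base_url and api_key are placeholders, not values from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

# create_file runs the new post-call hook, which may swap in a unified id
created = client.files.create(file=("test.txt", b"hello"), purpose="assistants")

# get_file now checks first whether the id is a litellm managed file
retrieved = client.files.retrieve(created.id)
print(retrieved.id, retrieved.purpose)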
@@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
 from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
-from litellm.types.utils import CallTypes, LoggedLiteLLMParams
+from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

@@ -940,7 +940,7 @@ class ProxyLogging:
     async def post_call_success_hook(
         self,
         data: dict,
-        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
         user_api_key_dict: UserAPIKeyAuth,
     ):
         """

@@ -948,6 +948,9 @@ class ProxyLogging:
 
         Covers:
         1. /chat/completions
         2. /embeddings
         3. /image/generation
+        4. /files
         """
 
         for callback in litellm.callbacks:
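ProxyLogging.post_call_success_hook now advertises /files alongside the chat, embedding, and image routes, and passes the widened LLMResponseTypes through to every registered callback. A simplified sketch of the dispatch pattern (the in-tree loop lives inside the method and also handles string callback names and error handling):

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import LLMResponseTypes


async def _dispatch_post_call_success(
    data: dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes
) -> None:
    # Each CustomLogger callback sees the same (possibly file-typed) response.
    for callback in litellm.callbacks:
        if isinstance(callback, CustomLogger):
            await callback.async_post_call_success_hook(
                data=data, user_api_key_dict=user_api_key_dict, response=response
            )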
@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIFileObject,
     OpenAIRealtimeStreamList,
     WebSearchOptions,
 )

@@ -2226,3 +2227,8 @@ class ExtractedFileData(TypedDict):
 class SpecialEnums(Enum):
     LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
     LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
+
+
+LLMResponseTypes = Union[
+    ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
+]
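LLMResponseTypes is the one new symbol the rest of the commit leans on: a plain Union, so OpenAIFileObject responses flow through the same hooks and isinstance narrowing keeps working. A tiny illustrative helper (not part of the commit):

from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.utils import LLMResponseTypes


def describe_response(response: LLMResponseTypes) -> str:
    # A plain Union, so isinstance narrows without casts.
    if isinstance(response, OpenAIFileObject):
        return f"file:{response.id}"
    return type(response).__name__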
@@ -43,3 +43,108 @@ def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
     assert messages[0]["content"][1]["file"]["file_id"].startswith(
         SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
     )
+
+
+# def test_list_managed_files():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create some test files
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="assistants"
+#     )
+
+#     # List all files
+#     files = proxy_managed_files.list_files()
+
+#     # Verify response
+#     assert len(files) == 2
+#     assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
+#     assert any(f.filename == "test1.txt" for f in files)
+#     assert any(f.filename == "test2.pdf" for f in files)
+#     assert all(f.purpose == "assistants" for f in files)
+
+
+# def test_retrieve_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     test_content = b"test content for retrieve"
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", test_content, "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Retrieve the file
+#     retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify response
+#     assert retrieved_file.id == created_file.id
+#     assert retrieved_file.filename == "test.txt"
+#     assert retrieved_file.purpose == "assistants"
+#     assert retrieved_file.bytes == len(test_content)
+#     assert retrieved_file.status == "uploaded"
+
+
+# def test_delete_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", b"test content", "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Delete the file
+#     deleted_file = proxy_managed_files.delete_file(created_file.id)
+
+#     # Verify deletion
+#     assert deleted_file.id == created_file.id
+#     assert deleted_file.deleted == True
+
+#     # Verify file is no longer retrievable
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify file is not in list
+#     files = proxy_managed_files.list_files()
+#     assert created_file.id not in [f.id for f in files]
+
+
+# def test_retrieve_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to retrieve a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file("nonexistent-file-id")
+
+
+# def test_delete_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to delete a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.delete_file("nonexistent-file-id")
+
+
+# def test_list_files_with_purpose_filter():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create files with different purposes
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="batch"
+#     )
+
+#     # List files with purpose filter
+#     assistant_files = proxy_managed_files.list_files(purpose="assistants")
+#     batch_files = proxy_managed_files.list_files(purpose="batch")
+
+#     # Verify filtering
+#     assert len(assistant_files) == 1
+#     assert len(batch_files) == 1
+#     assert assistant_files[0].id == file1.id
+#     assert batch_files[0].id == file2.id
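The disabled tests above target a synchronous create_file/retrieve_file/list_files surface that this class does not expose; the committed methods are async and named afile_*. A hedged sketch of reviving one case against the new API, assuming pytest-asyncio and that _PROXY_LiteLLMManagedFiles is imported at the top of this module (passing DualCache for the internal usage cache mirrors the disabled tests; the in-tree type is InternalUsageCache):

import pytest

from litellm.caching.caching import DualCache


@pytest.mark.asyncio
async def test_afile_retrieve_nonexistent_file():
    proxy_managed_files = _PROXY_LiteLLMManagedFiles(
        internal_usage_cache=DualCache(),  # in-tree type: InternalUsageCache
        prisma_client=None,  # type: ignore[arg-type]
    )
    # get_unified_file_id currently returns None, so afile_retrieve raises
    with pytest.raises(Exception):
        await proxy_managed_files.afile_retrieve(
            custom_llm_provider="openai", file_id="nonexistent-file-id"
        )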