mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
LiteLLM Minor Fixes & Improvements (04/02/2025) (#9725)
* Add date picker to usage tab + Add reasoning_content token tracking across all providers on streaming (#9722) * feat(new_usage.tsx): add date picker for new usage tab allow user to look back on their usage data * feat(anthropic/chat/transformation.py): report reasoning tokens in completion token details allows usage tracking on how many reasoning tokens are actually being used * feat(streaming_chunk_builder.py): return reasoning_tokens in anthropic/openai streaming response allows tracking reasoning_token usage across providers * Fix update team metadata + fix bulk adding models on Ui (#9721) * fix(handle_add_model_submit.tsx): fix bulk adding models * fix(team_info.tsx): fix team metadata update Fixes https://github.com/BerriAI/litellm/issues/9689 * (v0) Unified file id - allow calling multiple providers with same file id (#9718) * feat(files_endpoints.py): initial commit adding 'target_model_names' support allow developer to specify all the models they want to call with the file * feat(files_endpoints.py): return unified files endpoint * test(test_files_endpoints.py): add validation test - if invalid purpose submitted * feat: more updates * feat: initial working commit of unified file id translation * fix: additional fixes * fix(router.py): remove model replace logic in jsonl on acreate_file enables file upload to work for chat completion requests as well * fix(files_endpoints.py): remove whitespace around model name * fix(azure/handler.py): return acreate_file with correct response type * fix: fix linting errors * test: fix mock test to run on github actions * fix: fix ruff errors * fix: fix file too large error * fix(utils.py): remove redundant var * test: modify test to work on github actions * test: update tests * test: more debug logs to understand ci/cd issue * test: fix test for respx * test: skip mock respx test fails on ci/cd - not clear why * fix: fix ruff check * fix: fix test * fix(model_connection_test.tsx): fix linting error * test: update unit tests
This commit is contained in:
parent
5a18eebdb6
commit
6dda1ba6dd
27 changed files with 889 additions and 96 deletions
|
@ -63,16 +63,17 @@ async def acreate_file(
|
|||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate_file"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
create_file,
|
||||
file,
|
||||
purpose,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
call_args = {
|
||||
"file": file,
|
||||
"purpose": purpose,
|
||||
"custom_llm_provider": custom_llm_provider,
|
||||
"extra_headers": extra_headers,
|
||||
"extra_body": extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
}
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(create_file, **call_args)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
|
@ -92,7 +93,7 @@ async def acreate_file(
|
|||
def create_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -101,6 +102,8 @@ def create_file(
|
|||
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
||||
|
||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
|
||||
Specify either provider_list or custom_llm_provider.
|
||||
"""
|
||||
try:
|
||||
_is_async = kwargs.pop("acreate_file", False) is True
|
||||
|
@ -120,7 +123,7 @@ def create_file(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
|
|
@ -457,8 +457,12 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
non_default_params: dict,
|
||||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
prompt_management_logger: Optional[CustomLogger] = None,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
custom_logger = self.get_custom_logger_for_prompt_management(model)
|
||||
custom_logger = (
|
||||
prompt_management_logger
|
||||
or self.get_custom_logger_for_prompt_management(model)
|
||||
)
|
||||
if custom_logger:
|
||||
(
|
||||
model,
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict, List, Literal, Optional, Union, cast
|
|||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionFileObject,
|
||||
ChatCompletionUserMessage,
|
||||
)
|
||||
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
||||
|
@ -292,3 +293,58 @@ def get_completion_messages(
|
|||
messages, assistant_continue_message, ensure_alternating_roles
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
|
||||
"""
|
||||
Gets file ids from messages
|
||||
"""
|
||||
file_ids = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
continue
|
||||
for c in content:
|
||||
if c["type"] == "file":
|
||||
file_object = cast(ChatCompletionFileObject, c)
|
||||
file_object_file_field = file_object["file"]
|
||||
file_id = file_object_file_field.get("file_id")
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
return file_ids
|
||||
|
||||
|
||||
def update_messages_with_model_file_ids(
|
||||
messages: List[AllMessageValues],
|
||||
model_id: str,
|
||||
model_file_id_mapping: Dict[str, Dict[str, str]],
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Updates messages with model file ids.
|
||||
|
||||
model_file_id_mapping: Dict[str, Dict[str, str]] = {
|
||||
"litellm_proxy/file_id": {
|
||||
"model_id": "provider_file_id"
|
||||
}
|
||||
}
|
||||
"""
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
continue
|
||||
for c in content:
|
||||
if c["type"] == "file":
|
||||
file_object = cast(ChatCompletionFileObject, c)
|
||||
file_object_file_field = file_object["file"]
|
||||
file_id = file_object_file_field.get("file_id")
|
||||
if file_id:
|
||||
provider_file_id = (
|
||||
model_file_id_mapping.get(file_id, {}).get(model_id)
|
||||
or file_id
|
||||
)
|
||||
file_object_file_field["file_id"] = provider_file_id
|
||||
return messages
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import base64
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionAssistantContentValue,
|
||||
|
@ -9,7 +9,9 @@ from litellm.types.llms.openai import (
|
|||
from litellm.types.utils import (
|
||||
ChatCompletionAudioResponse,
|
||||
ChatCompletionMessageToolCall,
|
||||
Choices,
|
||||
CompletionTokensDetails,
|
||||
CompletionTokensDetailsWrapper,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelResponse,
|
||||
|
@ -203,14 +205,14 @@ class ChunkProcessor:
|
|||
)
|
||||
|
||||
def get_combined_content(
|
||||
self, chunks: List[Dict[str, Any]]
|
||||
self, chunks: List[Dict[str, Any]], delta_key: str = "content"
|
||||
) -> ChatCompletionAssistantContentValue:
|
||||
content_list: List[str] = []
|
||||
for chunk in chunks:
|
||||
choices = chunk["choices"]
|
||||
for choice in choices:
|
||||
delta = choice.get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
content = delta.get(delta_key, "")
|
||||
if content is None:
|
||||
continue # openai v1.0.0 sets content = None for chunks
|
||||
content_list.append(content)
|
||||
|
@ -221,6 +223,11 @@ class ChunkProcessor:
|
|||
# Update the "content" field within the response dictionary
|
||||
return combined_content
|
||||
|
||||
def get_combined_reasoning_content(
|
||||
self, chunks: List[Dict[str, Any]]
|
||||
) -> ChatCompletionAssistantContentValue:
|
||||
return self.get_combined_content(chunks, delta_key="reasoning_content")
|
||||
|
||||
def get_combined_audio_content(
|
||||
self, chunks: List[Dict[str, Any]]
|
||||
) -> ChatCompletionAudioResponse:
|
||||
|
@ -296,12 +303,27 @@ class ChunkProcessor:
|
|||
"prompt_tokens_details": prompt_tokens_details,
|
||||
}
|
||||
|
||||
def count_reasoning_tokens(self, response: ModelResponse) -> int:
|
||||
reasoning_tokens = 0
|
||||
for choice in response.choices:
|
||||
if (
|
||||
hasattr(cast(Choices, choice).message, "reasoning_content")
|
||||
and cast(Choices, choice).message.reasoning_content is not None
|
||||
):
|
||||
reasoning_tokens += token_counter(
|
||||
text=cast(Choices, choice).message.reasoning_content,
|
||||
count_response_tokens=True,
|
||||
)
|
||||
|
||||
return reasoning_tokens
|
||||
|
||||
def calculate_usage(
|
||||
self,
|
||||
chunks: List[Union[Dict[str, Any], ModelResponse]],
|
||||
model: str,
|
||||
completion_output: str,
|
||||
messages: Optional[List] = None,
|
||||
reasoning_tokens: Optional[int] = None,
|
||||
) -> Usage:
|
||||
"""
|
||||
Calculate usage for the given chunks.
|
||||
|
@ -382,6 +404,19 @@ class ChunkProcessor:
|
|||
) # for anthropic
|
||||
if completion_tokens_details is not None:
|
||||
returned_usage.completion_tokens_details = completion_tokens_details
|
||||
|
||||
if reasoning_tokens is not None:
|
||||
if returned_usage.completion_tokens_details is None:
|
||||
returned_usage.completion_tokens_details = (
|
||||
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
|
||||
)
|
||||
elif (
|
||||
returned_usage.completion_tokens_details is not None
|
||||
and returned_usage.completion_tokens_details.reasoning_tokens is None
|
||||
):
|
||||
returned_usage.completion_tokens_details.reasoning_tokens = (
|
||||
reasoning_tokens
|
||||
)
|
||||
if prompt_tokens_details is not None:
|
||||
returned_usage.prompt_tokens_details = prompt_tokens_details
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicChatCompletionUsageBlock,
|
||||
ContentBlockDelta,
|
||||
ContentBlockStart,
|
||||
ContentBlockStop,
|
||||
|
@ -32,13 +31,13 @@ from litellm.types.llms.anthropic import (
|
|||
from litellm.types.llms.openai import (
|
||||
ChatCompletionThinkingBlock,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
GenericStreamingChunk,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
||||
|
||||
|
@ -487,10 +486,8 @@ class ModelResponseIterator:
|
|||
return True
|
||||
return False
|
||||
|
||||
def _handle_usage(
|
||||
self, anthropic_usage_chunk: Union[dict, UsageDelta]
|
||||
) -> AnthropicChatCompletionUsageBlock:
|
||||
usage_block = AnthropicChatCompletionUsageBlock(
|
||||
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||
usage_block = Usage(
|
||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
||||
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
||||
|
@ -581,7 +578,7 @@ class ModelResponseIterator:
|
|||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
usage: Optional[Usage] = None
|
||||
provider_specific_fields: Dict[str, Any] = {}
|
||||
reasoning_content: Optional[str] = None
|
||||
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
|
||||
|
|
|
@ -33,9 +33,16 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.utils import CompletionTokensDetailsWrapper
|
||||
from litellm.types.utils import Message as LitellmMessage
|
||||
from litellm.types.utils import PromptTokensDetailsWrapper
|
||||
from litellm.utils import ModelResponse, Usage, add_dummy_tool, has_tool_call_blocks
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
add_dummy_tool,
|
||||
has_tool_call_blocks,
|
||||
token_counter,
|
||||
)
|
||||
|
||||
from ..common_utils import AnthropicError, process_anthropic_headers
|
||||
|
||||
|
@ -772,6 +779,15 @@ class AnthropicConfig(BaseConfig):
|
|||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
@ -780,6 +796,7 @@ class AnthropicConfig(BaseConfig):
|
|||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
|
|
@ -28,11 +28,11 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
self,
|
||||
create_file_data: CreateFileRequest,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileObject:
|
||||
) -> OpenAIFileObject:
|
||||
verbose_logger.debug("create_file_data=%s", create_file_data)
|
||||
response = await openai_client.files.create(**create_file_data)
|
||||
verbose_logger.debug("create_file_response=%s", response)
|
||||
return response
|
||||
return OpenAIFileObject(**response.model_dump())
|
||||
|
||||
def create_file(
|
||||
self,
|
||||
|
@ -66,7 +66,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.acreate_file( # type: ignore
|
||||
return self.acreate_file(
|
||||
create_file_data=create_file_data, openai_client=openai_client
|
||||
)
|
||||
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
|
||||
|
|
|
@ -110,7 +110,10 @@ from .litellm_core_utils.fallback_utils import (
|
|||
async_completion_with_fallbacks,
|
||||
completion_with_fallbacks,
|
||||
)
|
||||
from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
|
||||
from .litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_completion_messages,
|
||||
update_messages_with_model_file_ids,
|
||||
)
|
||||
from .litellm_core_utils.prompt_templates.factory import (
|
||||
custom_prompt,
|
||||
function_call_prompt,
|
||||
|
@ -953,7 +956,6 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
non_default_params = get_non_default_completion_params(kwargs=kwargs)
|
||||
litellm_params = {} # used to prevent unbound var errors
|
||||
## PROMPT MANAGEMENT HOOKS ##
|
||||
|
||||
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
|
||||
(
|
||||
model,
|
||||
|
@ -1068,6 +1070,15 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if eos_token:
|
||||
custom_prompt_dict[model]["eos_token"] = eos_token
|
||||
|
||||
if kwargs.get("model_file_id_mapping"):
|
||||
messages = update_messages_with_model_file_ids(
|
||||
messages=messages,
|
||||
model_id=kwargs.get("model_info", {}).get("id", None),
|
||||
model_file_id_mapping=cast(
|
||||
Dict[str, Dict[str, str]], kwargs.get("model_file_id_mapping")
|
||||
),
|
||||
)
|
||||
|
||||
provider_config: Optional[BaseConfig] = None
|
||||
if custom_llm_provider is not None and custom_llm_provider in [
|
||||
provider.value for provider in LlmProviders
|
||||
|
@ -5799,6 +5810,19 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
"content"
|
||||
] = processor.get_combined_content(content_chunks)
|
||||
|
||||
reasoning_chunks = [
|
||||
chunk
|
||||
for chunk in chunks
|
||||
if len(chunk["choices"]) > 0
|
||||
and "reasoning_content" in chunk["choices"][0]["delta"]
|
||||
and chunk["choices"][0]["delta"]["reasoning_content"] is not None
|
||||
]
|
||||
|
||||
if len(reasoning_chunks) > 0:
|
||||
response["choices"][0]["message"][
|
||||
"reasoning_content"
|
||||
] = processor.get_combined_reasoning_content(reasoning_chunks)
|
||||
|
||||
audio_chunks = [
|
||||
chunk
|
||||
for chunk in chunks
|
||||
|
@ -5813,11 +5837,14 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
|
||||
completion_output = get_content_from_model_response(response)
|
||||
|
||||
reasoning_tokens = processor.count_reasoning_tokens(response)
|
||||
|
||||
usage = processor.calculate_usage(
|
||||
chunks=chunks,
|
||||
model=model,
|
||||
completion_output=completion_output,
|
||||
messages=messages,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
|
||||
setattr(response, "usage", usage)
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
model_list:
|
||||
- model_name: "gpt-4o"
|
||||
- model_name: "gpt-4o-azure"
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
model: azure/gpt-4o
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: http://0.0.0.0:8090
|
||||
rpm: 3
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
- model_name: "gpt-4o-mini-openai"
|
||||
litellm_params:
|
||||
model: gpt-4o-mini
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: "openai/*"
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: "bedrock-nova"
|
||||
litellm_params:
|
||||
model: us.amazon.nova-pro-v1:0
|
||||
|
|
|
@ -2688,6 +2688,10 @@ class PrismaCompatibleUpdateDBModel(TypedDict, total=False):
|
|||
updated_by: str
|
||||
|
||||
|
||||
class SpecialEnums(enum.Enum):
|
||||
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy/"
|
||||
|
||||
|
||||
class SpecialManagementEndpointEnums(enum.Enum):
|
||||
DEFAULT_ORGANIZATION = "default_organization"
|
||||
|
||||
|
|
|
@ -1 +1,36 @@
|
|||
from typing import Literal, Union
|
||||
|
||||
from . import *
|
||||
from .cache_control_check import _PROXY_CacheControlCheck
|
||||
from .managed_files import _PROXY_LiteLLMManagedFiles
|
||||
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
|
||||
|
||||
# List of all available hooks that can be enabled
|
||||
PROXY_HOOKS = {
|
||||
"max_budget_limiter": _PROXY_MaxBudgetLimiter,
|
||||
"managed_files": _PROXY_LiteLLMManagedFiles,
|
||||
"parallel_request_limiter": _PROXY_MaxParallelRequestsHandler,
|
||||
"cache_control_check": _PROXY_CacheControlCheck,
|
||||
}
|
||||
|
||||
|
||||
def get_proxy_hook(
|
||||
hook_name: Union[
|
||||
Literal[
|
||||
"max_budget_limiter",
|
||||
"managed_files",
|
||||
"parallel_request_limiter",
|
||||
"cache_control_check",
|
||||
],
|
||||
str,
|
||||
]
|
||||
):
|
||||
"""
|
||||
Factory method to get a proxy hook instance by name
|
||||
"""
|
||||
if hook_name not in PROXY_HOOKS:
|
||||
raise ValueError(
|
||||
f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
|
||||
)
|
||||
return PROXY_HOOKS[hook_name]
|
||||
|
|
145
litellm/proxy/hooks/managed_files.py
Normal file
145
litellm/proxy/hooks/managed_files.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
# What is this?
|
||||
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_file_ids_from_messages,
|
||||
)
|
||||
from litellm.proxy._types import CallTypes, SpecialEnums, UserAPIKeyAuth
|
||||
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
InternalUsageCache = _InternalUsageCache
|
||||
else:
|
||||
Span = Any
|
||||
InternalUsageCache = Any
|
||||
|
||||
|
||||
class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self, internal_usage_cache: InternalUsageCache):
|
||||
self.internal_usage_cache = internal_usage_cache
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: Dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
],
|
||||
) -> Union[Exception, str, Dict, None]:
|
||||
"""
|
||||
- Detect litellm_proxy/ file_id
|
||||
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
|
||||
"""
|
||||
if call_type == CallTypes.completion.value:
|
||||
messages = data.get("messages")
|
||||
if messages:
|
||||
file_ids = get_file_ids_from_messages(messages)
|
||||
if file_ids:
|
||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||
file_ids, user_api_key_dict.parent_otel_span
|
||||
)
|
||||
data["model_file_id_mapping"] = model_file_id_mapping
|
||||
|
||||
return data
|
||||
|
||||
async def get_model_file_id_mapping(
|
||||
self, file_ids: List[str], litellm_parent_otel_span: Span
|
||||
) -> dict:
|
||||
"""
|
||||
Get model-specific file IDs for a list of proxy file IDs.
|
||||
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
|
||||
|
||||
1. Get all the litellm_proxy/ file_ids from the messages
|
||||
2. For each file_id, search for cache keys matching the pattern file_id:*
|
||||
3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
|
||||
|
||||
Example:
|
||||
{
|
||||
"litellm_proxy/file_id": {
|
||||
"model_id": "model_file_id"
|
||||
}
|
||||
}
|
||||
"""
|
||||
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
||||
litellm_managed_file_ids = []
|
||||
|
||||
for file_id in file_ids:
|
||||
## CHECK IF FILE ID IS MANAGED BY LITELM
|
||||
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||
litellm_managed_file_ids.append(file_id)
|
||||
|
||||
if litellm_managed_file_ids:
|
||||
# Get all cache keys matching the pattern file_id:*
|
||||
for file_id in litellm_managed_file_ids:
|
||||
# Search for any cache key starting with this file_id
|
||||
cached_values = cast(
|
||||
Dict[str, str],
|
||||
await self.internal_usage_cache.async_get_cache(
|
||||
key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
|
||||
),
|
||||
)
|
||||
if cached_values:
|
||||
file_id_mapping[file_id] = cached_values
|
||||
return file_id_mapping
|
||||
|
||||
@staticmethod
|
||||
async def return_unified_file_id(
|
||||
file_objects: List[OpenAIFileObject],
|
||||
purpose: OpenAIFilesPurpose,
|
||||
internal_usage_cache: InternalUsageCache,
|
||||
litellm_parent_otel_span: Span,
|
||||
) -> OpenAIFileObject:
|
||||
unified_file_id = SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value + str(
|
||||
uuid.uuid4()
|
||||
)
|
||||
|
||||
## CREATE RESPONSE OBJECT
|
||||
response = OpenAIFileObject(
|
||||
id=unified_file_id,
|
||||
object="file",
|
||||
purpose=cast(OpenAIFilesPurpose, purpose),
|
||||
created_at=file_objects[0].created_at,
|
||||
bytes=1234,
|
||||
filename=str(datetime.now().timestamp()),
|
||||
status="uploaded",
|
||||
)
|
||||
|
||||
## STORE RESPONSE IN DB + CACHE
|
||||
stored_values: Dict[str, str] = {}
|
||||
for file_object in file_objects:
|
||||
model_id = file_object._hidden_params.get("model_id")
|
||||
if model_id is None:
|
||||
verbose_logger.warning(
|
||||
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
|
||||
)
|
||||
continue
|
||||
file_id = file_object.id
|
||||
stored_values[model_id] = file_id
|
||||
await internal_usage_cache.async_set_cache(
|
||||
key=unified_file_id,
|
||||
value=stored_values,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
|
||||
return response
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from typing import Optional, cast, get_args
|
||||
|
||||
import httpx
|
||||
from fastapi import (
|
||||
|
@ -31,7 +31,10 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
|
|||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
)
|
||||
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.router import Router
|
||||
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -104,6 +107,53 @@ def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
|
|||
return is_in_list
|
||||
|
||||
|
||||
async def _deprecated_loadbalanced_create_file(
|
||||
llm_router: Optional[Router],
|
||||
router_model: str,
|
||||
_create_file_request: CreateFileRequest,
|
||||
) -> OpenAIFileObject:
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.acreate_file(model=router_model, **_create_file_request)
|
||||
return response
|
||||
|
||||
|
||||
async def create_file_for_each_model(
|
||||
llm_router: Optional[Router],
|
||||
_create_file_request: CreateFileRequest,
|
||||
target_model_names_list: List[str],
|
||||
purpose: OpenAIFilesPurpose,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> OpenAIFileObject:
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
responses = []
|
||||
for model in target_model_names_list:
|
||||
individual_response = await llm_router.acreate_file(
|
||||
model=model, **_create_file_request
|
||||
)
|
||||
responses.append(individual_response)
|
||||
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
|
||||
file_objects=responses,
|
||||
purpose=purpose,
|
||||
internal_usage_cache=proxy_logging_obj.internal_usage_cache,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/v1/files",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
|
@ -123,6 +173,7 @@ async def create_file(
|
|||
request: Request,
|
||||
fastapi_response: Response,
|
||||
purpose: str = Form(...),
|
||||
target_model_names: str = Form(default=""),
|
||||
provider: Optional[str] = None,
|
||||
custom_llm_provider: str = Form(default="openai"),
|
||||
file: UploadFile = File(...),
|
||||
|
@ -162,8 +213,25 @@ async def create_file(
|
|||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
|
||||
target_model_names_list = (
|
||||
target_model_names.split(",") if target_model_names else []
|
||||
)
|
||||
target_model_names_list = [model.strip() for model in target_model_names_list]
|
||||
# Prepare the data for forwarding
|
||||
|
||||
# Replace with:
|
||||
valid_purposes = get_args(OpenAIFilesPurpose)
|
||||
if purpose not in valid_purposes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Invalid purpose: {purpose}. Must be one of: {valid_purposes}",
|
||||
},
|
||||
)
|
||||
# Cast purpose to OpenAIFilesPurpose type
|
||||
purpose = cast(OpenAIFilesPurpose, purpose)
|
||||
|
||||
data = {"purpose": purpose}
|
||||
|
||||
# Include original request and headers in the data
|
||||
|
@ -192,21 +260,25 @@ async def create_file(
|
|||
|
||||
_create_file_request = CreateFileRequest(file=file_data, **data)
|
||||
|
||||
response: Optional[OpenAIFileObject] = None
|
||||
if (
|
||||
litellm.enable_loadbalancing_on_batch_endpoints is True
|
||||
and is_router_model
|
||||
and router_model is not None
|
||||
):
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.acreate_file(
|
||||
model=router_model, **_create_file_request
|
||||
response = await _deprecated_loadbalanced_create_file(
|
||||
llm_router=llm_router,
|
||||
router_model=router_model,
|
||||
_create_file_request=_create_file_request,
|
||||
)
|
||||
elif target_model_names_list:
|
||||
response = await create_file_for_each_model(
|
||||
llm_router=llm_router,
|
||||
_create_file_request=_create_file_request,
|
||||
target_model_names_list=target_model_names_list,
|
||||
purpose=purpose,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
else:
|
||||
# get configs for custom_llm_provider
|
||||
|
@ -220,6 +292,11 @@ async def create_file(
|
|||
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
||||
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
|
||||
|
||||
if response is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Failed to create file. Please try again."},
|
||||
)
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
|
@ -248,12 +325,11 @@ async def create_file(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
|
|
|
@ -76,6 +76,7 @@ from litellm.proxy.db.create_views import (
|
|||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
||||
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
|
||||
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||
|
@ -352,10 +353,19 @@ class ProxyLogging:
|
|||
self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
|
||||
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
|
||||
|
||||
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
|
||||
for hook in PROXY_HOOKS:
|
||||
proxy_hook = get_proxy_hook(hook)
|
||||
import inspect
|
||||
|
||||
expected_args = inspect.getfullargspec(proxy_hook).args
|
||||
if "internal_usage_cache" in expected_args:
|
||||
litellm.logging_callback_manager.add_litellm_callback(proxy_hook(self.internal_usage_cache)) # type: ignore
|
||||
else:
|
||||
litellm.logging_callback_manager.add_litellm_callback(proxy_hook()) # type: ignore
|
||||
|
||||
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
|
||||
self._add_proxy_hooks(llm_router)
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
|
|
|
@ -68,10 +68,7 @@ from litellm.router_utils.add_retry_fallback_headers import (
|
|||
add_fallback_headers_to_response,
|
||||
add_retry_headers_to_response,
|
||||
)
|
||||
from litellm.router_utils.batch_utils import (
|
||||
_get_router_metadata_variable_name,
|
||||
replace_model_in_jsonl,
|
||||
)
|
||||
from litellm.router_utils.batch_utils import _get_router_metadata_variable_name
|
||||
from litellm.router_utils.client_initalization_utils import InitalizeCachedClient
|
||||
from litellm.router_utils.clientside_credential_handler import (
|
||||
get_dynamic_litellm_params,
|
||||
|
@ -105,7 +102,12 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
|
|||
increment_deployment_successes_for_current_minute,
|
||||
)
|
||||
from litellm.scheduler import FlowItem, Scheduler
|
||||
from litellm.types.llms.openai import AllMessageValues, Batch, FileObject, FileTypes
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
Batch,
|
||||
FileTypes,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.router import (
|
||||
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
|
||||
VALID_LITELLM_ENVIRONMENTS,
|
||||
|
@ -2703,7 +2705,7 @@ class Router:
|
|||
self,
|
||||
model: str,
|
||||
**kwargs,
|
||||
) -> FileObject:
|
||||
) -> OpenAIFileObject:
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["original_function"] = self._acreate_file
|
||||
|
@ -2727,7 +2729,7 @@ class Router:
|
|||
self,
|
||||
model: str,
|
||||
**kwargs,
|
||||
) -> FileObject:
|
||||
) -> OpenAIFileObject:
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
||||
|
@ -2754,9 +2756,9 @@ class Router:
|
|||
stripped_model, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=data["model"]
|
||||
)
|
||||
kwargs["file"] = replace_model_in_jsonl(
|
||||
file_content=kwargs["file"], new_model_name=stripped_model
|
||||
)
|
||||
# kwargs["file"] = replace_model_in_jsonl(
|
||||
# file_content=kwargs["file"], new_model_name=stripped_model
|
||||
# )
|
||||
|
||||
response = litellm.acreate_file(
|
||||
**{
|
||||
|
@ -2796,6 +2798,7 @@ class Router:
|
|||
verbose_router_logger.info(
|
||||
f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
|
||||
)
|
||||
|
||||
return response # type: ignore
|
||||
except Exception as e:
|
||||
verbose_router_logger.exception(
|
||||
|
|
14
litellm/router_utils/common_utils.py
Normal file
14
litellm/router_utils/common_utils.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import hashlib
|
||||
import json
|
||||
|
||||
from litellm.types.router import CredentialLiteLLMParams
|
||||
|
||||
|
||||
def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
|
||||
"""
|
||||
Hash of the credential params, used for mapping the file id to the right model
|
||||
"""
|
||||
sensitive_params = CredentialLiteLLMParams(**litellm_params)
|
||||
return hashlib.sha256(
|
||||
json.dumps(sensitive_params.model_dump()).encode()
|
||||
).hexdigest()
|
|
@ -234,7 +234,18 @@ class Thread(BaseModel):
|
|||
"""The object type, which is always `thread`."""
|
||||
|
||||
|
||||
OpenAICreateFileRequestOptionalParams = Literal["purpose",]
|
||||
OpenAICreateFileRequestOptionalParams = Literal["purpose"]
|
||||
|
||||
OpenAIFilesPurpose = Literal[
|
||||
"assistants",
|
||||
"assistants_output",
|
||||
"batch",
|
||||
"batch_output",
|
||||
"fine-tune",
|
||||
"fine-tune-results",
|
||||
"vision",
|
||||
"user_data",
|
||||
]
|
||||
|
||||
|
||||
class OpenAIFileObject(BaseModel):
|
||||
|
@ -253,16 +264,7 @@ class OpenAIFileObject(BaseModel):
|
|||
object: Literal["file"]
|
||||
"""The object type, which is always `file`."""
|
||||
|
||||
purpose: Literal[
|
||||
"assistants",
|
||||
"assistants_output",
|
||||
"batch",
|
||||
"batch_output",
|
||||
"fine-tune",
|
||||
"fine-tune-results",
|
||||
"vision",
|
||||
"user_data",
|
||||
]
|
||||
purpose: OpenAIFilesPurpose
|
||||
"""The intended purpose of the file.
|
||||
|
||||
Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`,
|
||||
|
@ -286,6 +288,8 @@ class OpenAIFileObject(BaseModel):
|
|||
`error` field on `fine_tuning.job`.
|
||||
"""
|
||||
|
||||
_hidden_params: dict = {}
|
||||
|
||||
|
||||
# OpenAI Files Types
|
||||
class CreateFileRequest(TypedDict, total=False):
|
||||
|
|
|
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|||
from ..exceptions import RateLimitError
|
||||
from .completion import CompletionRequest
|
||||
from .embedding import EmbeddingRequest
|
||||
from .llms.openai import OpenAIFileObject
|
||||
from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
from .utils import ModelResponse, ProviderSpecificModelInfo
|
||||
|
||||
|
@ -703,3 +704,12 @@ class GenericBudgetWindowDetails(BaseModel):
|
|||
|
||||
|
||||
OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
|
||||
|
||||
|
||||
class LiteLLM_RouterFileObject(TypedDict, total=False):
|
||||
"""
|
||||
Tracking the litellm params hash, used for mapping the file id to the right model
|
||||
"""
|
||||
|
||||
litellm_params_sensitive_credential_hash: str
|
||||
file_object: OpenAIFileObject
|
||||
|
|
|
@ -1886,6 +1886,7 @@ all_litellm_params = [
|
|||
"logger_fn",
|
||||
"verbose",
|
||||
"custom_llm_provider",
|
||||
"model_file_id_mapping",
|
||||
"litellm_logging_obj",
|
||||
"litellm_call_id",
|
||||
"use_client",
|
||||
|
|
|
@ -17,6 +17,7 @@ import litellm
|
|||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.types.utils import (
|
||||
CompletionTokensDetailsWrapper,
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
PromptTokensDetailsWrapper,
|
||||
|
@ -430,11 +431,18 @@ async def test_streaming_handler_with_usage(
|
|||
completion_tokens=392,
|
||||
prompt_tokens=1799,
|
||||
total_tokens=2191,
|
||||
completion_tokens_details=None,
|
||||
completion_tokens_details=CompletionTokensDetailsWrapper( # <-- This has a value
|
||||
accepted_prediction_tokens=None,
|
||||
audio_tokens=None,
|
||||
reasoning_tokens=0,
|
||||
rejected_prediction_tokens=None,
|
||||
text_tokens=None,
|
||||
),
|
||||
prompt_tokens_details=PromptTokensDetailsWrapper(
|
||||
audio_tokens=None, cached_tokens=1796, text_tokens=None, image_tokens=None
|
||||
),
|
||||
)
|
||||
|
||||
final_chunk = ModelResponseStream(
|
||||
id="chatcmpl-87291500-d8c5-428e-b187-36fe5a4c97ab",
|
||||
created=1742056047,
|
||||
|
@ -510,7 +518,13 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
|
|||
completion_tokens=392,
|
||||
prompt_tokens=1799,
|
||||
total_tokens=2191,
|
||||
completion_tokens_details=None,
|
||||
completion_tokens_details=CompletionTokensDetailsWrapper(
|
||||
accepted_prediction_tokens=None,
|
||||
audio_tokens=None,
|
||||
reasoning_tokens=0,
|
||||
rejected_prediction_tokens=None,
|
||||
text_tokens=None,
|
||||
),
|
||||
prompt_tokens_details=PromptTokensDetailsWrapper(
|
||||
audio_tokens=None,
|
||||
cached_tokens=1796,
|
||||
|
|
306
tests/litellm/proxy/openai_files_endpoint/test_files_endpoint.py
Normal file
306
tests/litellm/proxy/openai_files_endpoint/test_files_endpoint.py
Normal file
|
@ -0,0 +1,306 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.proxy._types import LiteLLM_UserTableFiltered, UserAPIKeyAuth
|
||||
from litellm.proxy.hooks import get_proxy_hook
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import ui_view_users
|
||||
from litellm.proxy.proxy_server import app
|
||||
|
||||
client = TestClient(app)
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_router() -> Router:
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "azure-gpt-3-5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "azure_api_key",
|
||||
"api_base": "azure_api_base",
|
||||
"api_version": "azure_api_version",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "azure-gpt-3-5-turbo-id",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-3.5-turbo",
|
||||
"api_key": "openai_api_key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "gpt-3.5-turbo-id",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-2.0-flash",
|
||||
"litellm_params": {
|
||||
"model": "gemini/gemini-2.0-flash",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "gemini-2.0-flash-id",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
return llm_router
|
||||
|
||||
|
||||
def setup_proxy_logging_object(monkeypatch, llm_router: Router) -> ProxyLogging:
|
||||
proxy_logging_object = ProxyLogging(
|
||||
user_api_key_cache=DualCache(default_in_memory_ttl=1)
|
||||
)
|
||||
proxy_logging_object._add_proxy_hooks(llm_router)
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_object
|
||||
)
|
||||
return proxy_logging_object
|
||||
|
||||
|
||||
def test_invalid_purpose(mocker: MockerFixture, monkeypatch, llm_router: Router):
|
||||
"""
|
||||
Asserts 'create_file' is called with the correct arguments
|
||||
"""
|
||||
# Create a simple test file content
|
||||
test_file_content = b"test audio content"
|
||||
test_file = ("test.wav", test_file_content, "audio/wav")
|
||||
|
||||
response = client.post(
|
||||
"/v1/files",
|
||||
files={"file": test_file},
|
||||
data={
|
||||
"purpose": "my-bad-purpose",
|
||||
"target_model_names": ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"],
|
||||
},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
print(f"response: {response.json()}")
|
||||
assert "Invalid purpose: my-bad-purpose" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_mock_create_audio_file(mocker: MockerFixture, monkeypatch, llm_router: Router):
|
||||
"""
|
||||
Asserts 'create_file' is called with the correct arguments
|
||||
"""
|
||||
from litellm import Router
|
||||
|
||||
mock_create_file = mocker.patch("litellm.files.main.create_file")
|
||||
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)
|
||||
|
||||
# Create a simple test file content
|
||||
test_file_content = b"test audio content"
|
||||
test_file = ("test.wav", test_file_content, "audio/wav")
|
||||
|
||||
response = client.post(
|
||||
"/v1/files",
|
||||
files={"file": test_file},
|
||||
data={
|
||||
"purpose": "user_data",
|
||||
"target_model_names": "azure-gpt-3-5-turbo, gpt-3.5-turbo",
|
||||
},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
print(f"response: {response.text}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get all calls made to create_file
|
||||
calls = mock_create_file.call_args_list
|
||||
|
||||
# Check for Azure call
|
||||
azure_call_found = False
|
||||
for call in calls:
|
||||
kwargs = call.kwargs
|
||||
if (
|
||||
kwargs.get("custom_llm_provider") == "azure"
|
||||
and kwargs.get("model") == "azure/chatgpt-v-2"
|
||||
and kwargs.get("api_key") == "azure_api_key"
|
||||
):
|
||||
azure_call_found = True
|
||||
break
|
||||
assert (
|
||||
azure_call_found
|
||||
), f"Azure call not found with expected parameters. Calls: {calls}"
|
||||
|
||||
# Check for OpenAI call
|
||||
openai_call_found = False
|
||||
for call in calls:
|
||||
kwargs = call.kwargs
|
||||
if (
|
||||
kwargs.get("custom_llm_provider") == "openai"
|
||||
and kwargs.get("model") == "openai/gpt-3.5-turbo"
|
||||
and kwargs.get("api_key") == "openai_api_key"
|
||||
):
|
||||
openai_call_found = True
|
||||
break
|
||||
assert openai_call_found, "OpenAI call not found with expected parameters"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="mock respx fails on ci/cd - unclear why")
|
||||
def test_create_file_and_call_chat_completion_e2e(
|
||||
mocker: MockerFixture, monkeypatch, llm_router: Router
|
||||
):
|
||||
"""
|
||||
1. Create a file
|
||||
2. Call a chat completion with the file
|
||||
3. Assert the file is used in the chat completion
|
||||
"""
|
||||
# Create and enable respx mock instance
|
||||
mock = respx.mock()
|
||||
mock.start()
|
||||
try:
|
||||
from litellm.types.llms.openai import OpenAIFileObject
|
||||
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)
|
||||
proxy_logging_object = setup_proxy_logging_object(monkeypatch, llm_router)
|
||||
|
||||
# Create a simple test file content
|
||||
test_file_content = b"test audio content"
|
||||
test_file = ("test.wav", test_file_content, "audio/wav")
|
||||
|
||||
# Mock the file creation response
|
||||
mock_file_response = OpenAIFileObject(
|
||||
id="test-file-id",
|
||||
object="file",
|
||||
bytes=123,
|
||||
created_at=1234567890,
|
||||
filename="test.wav",
|
||||
purpose="user_data",
|
||||
status="uploaded",
|
||||
)
|
||||
mock_file_response._hidden_params = {"model_id": "gemini-2.0-flash-id"}
|
||||
mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
|
||||
|
||||
# Mock the Gemini API call using respx
|
||||
mock_gemini_response = {
|
||||
"candidates": [
|
||||
{"content": {"parts": [{"text": "This is a test audio file"}]}}
|
||||
]
|
||||
}
|
||||
|
||||
# Mock the Gemini API endpoint with a more flexible pattern
|
||||
gemini_route = mock.post(
|
||||
url__regex=r".*generativelanguage\.googleapis\.com.*"
|
||||
).mock(
|
||||
return_value=respx.MockResponse(status_code=200, json=mock_gemini_response),
|
||||
)
|
||||
|
||||
# Print updated mock setup
|
||||
print("\nAfter Adding Gemini Route:")
|
||||
print("==========================")
|
||||
print(f"Number of mocked routes: {len(mock.routes)}")
|
||||
for route in mock.routes:
|
||||
print(f"Mocked Route: {route}")
|
||||
print(f"Pattern: {route.pattern}")
|
||||
|
||||
## CREATE FILE
|
||||
file = client.post(
|
||||
"/v1/files",
|
||||
files={"file": test_file},
|
||||
data={
|
||||
"purpose": "user_data",
|
||||
"target_model_names": "gemini-2.0-flash, gpt-3.5-turbo",
|
||||
},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
print("\nAfter File Creation:")
|
||||
print("====================")
|
||||
print(f"File creation status: {file.status_code}")
|
||||
print(f"Recorded calls so far: {len(mock.calls)}")
|
||||
for call in mock.calls:
|
||||
print(f"Call made to: {call.request.method} {call.request.url}")
|
||||
|
||||
assert file.status_code == 200
|
||||
assert file.json()["id"] != "test-file-id" # unified file id used
|
||||
|
||||
## USE FILE IN CHAT COMPLETION
|
||||
try:
|
||||
completion = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "gemini-2.0-flash",
|
||||
"modalities": ["text", "audio"],
|
||||
"audio": {"voice": "alloy", "format": "wav"},
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is in this recording?"},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"file_id": file.json()["id"],
|
||||
"filename": "my-test-name",
|
||||
"format": "audio/wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"drop_params": True,
|
||||
},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error: {e}")
|
||||
|
||||
print("\nError occurred during chat completion:")
|
||||
print("=====================================")
|
||||
print("\nFinal Mock State:")
|
||||
print("=================")
|
||||
print(f"Total mocked routes: {len(mock.routes)}")
|
||||
for route in mock.routes:
|
||||
print(f"\nMocked Route: {route}")
|
||||
print(f" Called: {route.called}")
|
||||
|
||||
print("\nActual Requests Made:")
|
||||
print("=====================")
|
||||
print(f"Total calls recorded: {len(mock.calls)}")
|
||||
for idx, call in enumerate(mock.calls):
|
||||
print(f"\nCall {idx + 1}:")
|
||||
print(f" Method: {call.request.method}")
|
||||
print(f" URL: {call.request.url}")
|
||||
print(f" Headers: {dict(call.request.headers)}")
|
||||
try:
|
||||
print(f" Body: {call.request.content.decode()}")
|
||||
except:
|
||||
print(" Body: <could not decode>")
|
||||
|
||||
# Verify Gemini API was called
|
||||
assert gemini_route.called, "Gemini API was not called"
|
||||
|
||||
# Print the call details
|
||||
print("\nGemini API Call Details:")
|
||||
print(f"URL: {gemini_route.calls.last.request.url}")
|
||||
print(f"Method: {gemini_route.calls.last.request.method}")
|
||||
print(f"Headers: {dict(gemini_route.calls.last.request.headers)}")
|
||||
print(f"Content: {gemini_route.calls.last.request.content.decode()}")
|
||||
print(f"Response: {gemini_route.calls.last.response.content.decode()}")
|
||||
|
||||
assert "test-file-id" in gemini_route.calls.last.request.content.decode()
|
||||
finally:
|
||||
# Stop the mock
|
||||
mock.stop()
|
|
@ -39,7 +39,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
|
|||
with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
|
||||
"litellm.proxy.proxy_server.store_model_in_db", False
|
||||
): # set store_model_in_db to False
|
||||
|
||||
# Test when store_model_in_db is False
|
||||
await ProxyStartupEvent.initialize_scheduled_background_jobs(
|
||||
general_settings={},
|
||||
|
@ -57,7 +56,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
|
|||
with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
|
||||
"litellm.proxy.proxy_server.store_model_in_db", True
|
||||
), patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True):
|
||||
|
||||
await ProxyStartupEvent.initialize_scheduled_background_jobs(
|
||||
general_settings={},
|
||||
prisma_client=mock_prisma_client,
|
||||
|
|
|
@ -1116,3 +1116,6 @@ def test_anthropic_thinking_in_assistant_message(model):
|
|||
response = litellm.completion(**params)
|
||||
|
||||
assert response is not None
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ export const prepareModelAddRequest = async (
|
|||
}
|
||||
|
||||
// Create a deployment for each mapping
|
||||
const deployments = [];
|
||||
for (const mapping of modelMappings) {
|
||||
const litellmParamsObj: Record<string, any> = {};
|
||||
const modelInfoObj: Record<string, any> = {};
|
||||
|
@ -142,8 +143,10 @@ export const prepareModelAddRequest = async (
|
|||
}
|
||||
}
|
||||
|
||||
return { litellmParamsObj, modelInfoObj, modelName };
|
||||
deployments.push({ litellmParamsObj, modelInfoObj, modelName });
|
||||
}
|
||||
|
||||
return deployments;
|
||||
} catch (error) {
|
||||
message.error("Failed to create model: " + error, 10);
|
||||
}
|
||||
|
@ -156,22 +159,25 @@ export const handleAddModelSubmit = async (
|
|||
callback?: () => void,
|
||||
) => {
|
||||
try {
|
||||
const result = await prepareModelAddRequest(values, accessToken, form);
|
||||
const deployments = await prepareModelAddRequest(values, accessToken, form);
|
||||
|
||||
if (!result) {
|
||||
return; // Exit if preparation failed
|
||||
if (!deployments || deployments.length === 0) {
|
||||
return; // Exit if preparation failed or no deployments
|
||||
}
|
||||
|
||||
const { litellmParamsObj, modelInfoObj, modelName } = result;
|
||||
|
||||
const new_model: Model = {
|
||||
model_name: modelName,
|
||||
litellm_params: litellmParamsObj,
|
||||
model_info: modelInfoObj,
|
||||
};
|
||||
|
||||
const response: any = await modelCreateCall(accessToken, new_model);
|
||||
console.log(`response for model create call: ${response["data"]}`);
|
||||
// Create each deployment
|
||||
for (const deployment of deployments) {
|
||||
const { litellmParamsObj, modelInfoObj, modelName } = deployment;
|
||||
|
||||
const new_model: Model = {
|
||||
model_name: modelName,
|
||||
litellm_params: litellmParamsObj,
|
||||
model_info: modelInfoObj,
|
||||
};
|
||||
|
||||
const response: any = await modelCreateCall(accessToken, new_model);
|
||||
console.log(`response for model create call: ${response["data"]}`);
|
||||
}
|
||||
|
||||
callback && callback();
|
||||
form.resetFields();
|
||||
|
|
|
@ -55,7 +55,7 @@ const ModelConnectionTest: React.FC<ModelConnectionTestProps> = ({
|
|||
|
||||
console.log("Result from prepareModelAddRequest:", result);
|
||||
|
||||
const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result;
|
||||
const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result[0];
|
||||
|
||||
const response = await testConnectionRequest(accessToken, litellmParamsObj, modelInfoObj?.mode);
|
||||
if (response.status === "success") {
|
||||
|
|
|
@ -13,7 +13,7 @@ import {
|
|||
TabPanel, TabPanels, DonutChart,
|
||||
Table, TableHead, TableRow,
|
||||
TableHeaderCell, TableBody, TableCell,
|
||||
Subtitle
|
||||
Subtitle, DateRangePicker, DateRangePickerValue
|
||||
} from "@tremor/react";
|
||||
import { AreaChart } from "@tremor/react";
|
||||
|
||||
|
@ -41,6 +41,12 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
|
|||
metadata: any;
|
||||
}>({ results: [], metadata: {} });
|
||||
|
||||
// Add date range state
|
||||
const [dateValue, setDateValue] = useState<DateRangePickerValue>({
|
||||
from: new Date(Date.now() - 28 * 24 * 60 * 60 * 1000),
|
||||
to: new Date(),
|
||||
});
|
||||
|
||||
// Derived states from userSpendData
|
||||
const totalSpend = userSpendData.metadata?.total_spend || 0;
|
||||
|
||||
|
@ -168,22 +174,34 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
|
|||
};
|
||||
|
||||
const fetchUserSpendData = async () => {
|
||||
if (!accessToken) return;
|
||||
const startTime = new Date(Date.now() - 28 * 24 * 60 * 60 * 1000);
|
||||
const endTime = new Date();
|
||||
if (!accessToken || !dateValue.from || !dateValue.to) return;
|
||||
const startTime = dateValue.from;
|
||||
const endTime = dateValue.to;
|
||||
const data = await userDailyActivityCall(accessToken, startTime, endTime);
|
||||
setUserSpendData(data);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchUserSpendData();
|
||||
}, [accessToken]);
|
||||
}, [accessToken, dateValue]);
|
||||
|
||||
const modelMetrics = processActivityData(userSpendData);
|
||||
|
||||
return (
|
||||
<div style={{ width: "100%" }} className="p-8">
|
||||
<Text>Experimental Usage page, using new `/user/daily/activity` endpoint.</Text>
|
||||
<Grid numItems={2} className="gap-2 w-full mb-4">
|
||||
<Col>
|
||||
<Text>Select Time Range</Text>
|
||||
<DateRangePicker
|
||||
enableSelect={true}
|
||||
value={dateValue}
|
||||
onValueChange={(value) => {
|
||||
setDateValue(value);
|
||||
}}
|
||||
/>
|
||||
</Col>
|
||||
</Grid>
|
||||
<TabGroup>
|
||||
<TabList variant="solid" className="mt-1">
|
||||
<Tab>Cost</Tab>
|
||||
|
|
|
@ -175,6 +175,14 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
|||
try {
|
||||
if (!accessToken) return;
|
||||
|
||||
let parsedMetadata = {};
|
||||
try {
|
||||
parsedMetadata = values.metadata ? JSON.parse(values.metadata) : {};
|
||||
} catch (e) {
|
||||
message.error("Invalid JSON in metadata field");
|
||||
return;
|
||||
}
|
||||
|
||||
const updateData = {
|
||||
team_id: teamId,
|
||||
team_alias: values.team_alias,
|
||||
|
@ -184,7 +192,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
|||
max_budget: values.max_budget,
|
||||
budget_duration: values.budget_duration,
|
||||
metadata: {
|
||||
...values.metadata,
|
||||
...parsedMetadata,
|
||||
guardrails: values.guardrails || []
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue