Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
LiteLLM Minor Fixes & Improvements (04/02/2025) (#9725)
* Add date picker to usage tab + Add reasoning_content token tracking across all providers on streaming (#9722)
  * feat(new_usage.tsx): add date picker for new usage tab, allowing users to look back on their usage data
  * feat(anthropic/chat/transformation.py): report reasoning tokens in completion token details, allowing usage tracking of how many reasoning tokens are actually being used
  * feat(streaming_chunk_builder.py): return reasoning_tokens in anthropic/openai streaming response, allowing tracking of reasoning_token usage across providers
* Fix update team metadata + fix bulk adding models on UI (#9721)
  * fix(handle_add_model_submit.tsx): fix bulk adding models
  * fix(team_info.tsx): fix team metadata update. Fixes https://github.com/BerriAI/litellm/issues/9689
* (v0) Unified file id - allow calling multiple providers with same file id (#9718)
  * feat(files_endpoints.py): initial commit adding 'target_model_names' support, allowing a developer to specify all the models they want to call with the file
  * feat(files_endpoints.py): return unified files endpoint
  * test(test_files_endpoints.py): add validation test for invalid purpose submissions
  * feat: more updates
  * feat: initial working commit of unified file id translation
  * fix: additional fixes
  * fix(router.py): remove model replace logic in jsonl on acreate_file; enables file upload to work for chat completion requests as well
  * fix(files_endpoints.py): remove whitespace around model name
  * fix(azure/handler.py): return acreate_file with correct response type
  * fix: fix linting errors
  * test: fix mock test to run on GitHub Actions
  * fix: fix ruff errors
  * fix: fix file too large error
  * fix(utils.py): remove redundant var
  * test: modify test to work on GitHub Actions
  * test: update tests
  * test: more debug logs to understand CI/CD issue
  * test: fix test for respx
  * test: skip mock respx test (fails on CI/CD, not clear why)
  * fix: fix ruff check
  * fix: fix test
  * fix(model_connection_test.tsx): fix linting error
  * test: update unit tests
This commit is contained in:
parent ad57b7b331
commit 0ce878e804

27 changed files with 889 additions and 96 deletions
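
The diffs below wire up a "unified file id": one upload against the proxy's /v1/files endpoint fans out to every deployment named in target_model_names, and the returned litellm_proxy/ id can be reused in chat completions against any of those deployments. A minimal usage sketch, not an official example: the proxy URL, virtual key, model aliases, and file name below are assumptions (the aliases are borrowed from the test config in this commit).

# Sketch of the unified file id flow. Assumes a LiteLLM proxy at localhost:4000
# with key "sk-1234" and the model aliases "gpt-4o-mini-openai" / "gemini-2.0-flash".
import requests
from openai import OpenAI

PROXY_BASE = "http://localhost:4000"
API_KEY = "sk-1234"

# 1. Upload the file once, targeting several deployments via the new form field.
with open("document.pdf", "rb") as f:  # placeholder file name
    created = requests.post(
        f"{PROXY_BASE}/v1/files",
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": f},
        data={
            "purpose": "user_data",
            "target_model_names": "gpt-4o-mini-openai, gemini-2.0-flash",
        },
    ).json()

unified_file_id = created["id"]  # e.g. "litellm_proxy/<uuid>"

# 2. Reference the same unified id in a chat completion; the proxy's managed-files
#    hook swaps it for the provider-specific file id of whichever deployment is used.
client = OpenAI(base_url=f"{PROXY_BASE}/v1", api_key=API_KEY)
response = client.chat.completions.create(
    model="gpt-4o-mini-openai",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize this document."},
                {"type": "file", "file": {"file_id": unified_file_id}},
            ],
        }
    ],
)
print(response.choices[0].message.content)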
@@ -63,16 +63,17 @@ async def acreate_file(
     loop = asyncio.get_event_loop()
     kwargs["acreate_file"] = True
 
-    # Use a partial function to pass your keyword arguments
-    func = partial(
-        create_file,
-        file,
-        purpose,
-        custom_llm_provider,
-        extra_headers,
-        extra_body,
+    call_args = {
+        "file": file,
+        "purpose": purpose,
+        "custom_llm_provider": custom_llm_provider,
+        "extra_headers": extra_headers,
+        "extra_body": extra_body,
         **kwargs,
-    )
+    }
+
+    # Use a partial function to pass your keyword arguments
+    func = partial(create_file, **call_args)
 
     # Add the context to the function
     ctx = contextvars.copy_context()
@@ -92,7 +92,7 @@ async def acreate_file(
 def create_file(
     file: FileTypes,
     purpose: Literal["assistants", "batch", "fine-tune"],
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
@@ -101,6 +102,8 @@ def create_file(
     Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
 
     LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
+
+    Specify either provider_list or custom_llm_provider.
     """
     try:
         _is_async = kwargs.pop("acreate_file", False) is True
@@ -120,7 +123,7 @@ def create_file(
         if (
             timeout is not None
             and isinstance(timeout, httpx.Timeout)
-            and supports_httpx_timeout(custom_llm_provider) is False
+            and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
         ):
             read_timeout = timeout.read or 600
             timeout = read_timeout  # default 10 min timeout
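For reference, the SDK entrypoint being refactored above can be called directly. A sketch, with a placeholder file name and OPENAI_API_KEY assumed to be set in the environment:

import asyncio
import litellm

async def upload_batch_input() -> None:
    # acreate_file is the async wrapper shown in the diff above; it forwards
    # file/purpose/custom_llm_provider through create_file via functools.partial.
    with open("batch_input.jsonl", "rb") as f:  # placeholder path
        file_obj = await litellm.acreate_file(
            file=f,
            purpose="batch",
            custom_llm_provider="openai",
        )
    print(file_obj.id)

asyncio.run(upload_batch_input())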
@@ -457,8 +457,12 @@ class Logging(LiteLLMLoggingBaseClass):
         non_default_params: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
+        prompt_management_logger: Optional[CustomLogger] = None,
     ) -> Tuple[str, List[AllMessageValues], dict]:
-        custom_logger = self.get_custom_logger_for_prompt_management(model)
+        custom_logger = (
+            prompt_management_logger
+            or self.get_custom_logger_for_prompt_management(model)
+        )
         if custom_logger:
             (
                 model,
@@ -7,6 +7,7 @@ from typing import Dict, List, Literal, Optional, Union, cast
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionAssistantMessage,
+    ChatCompletionFileObject,
     ChatCompletionUserMessage,
 )
 from litellm.types.utils import Choices, ModelResponse, StreamingChoices
@@ -292,3 +293,58 @@ def get_completion_messages(
         messages, assistant_continue_message, ensure_alternating_roles
     )
     return messages
+
+
+def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
+    """
+    Gets file ids from messages
+    """
+    file_ids = []
+    for message in messages:
+        if message.get("role") == "user":
+            content = message.get("content")
+            if content:
+                if isinstance(content, str):
+                    continue
+                for c in content:
+                    if c["type"] == "file":
+                        file_object = cast(ChatCompletionFileObject, c)
+                        file_object_file_field = file_object["file"]
+                        file_id = file_object_file_field.get("file_id")
+                        if file_id:
+                            file_ids.append(file_id)
+    return file_ids
+
+
+def update_messages_with_model_file_ids(
+    messages: List[AllMessageValues],
+    model_id: str,
+    model_file_id_mapping: Dict[str, Dict[str, str]],
+) -> List[AllMessageValues]:
+    """
+    Updates messages with model file ids.
+
+    model_file_id_mapping: Dict[str, Dict[str, str]] = {
+        "litellm_proxy/file_id": {
+            "model_id": "provider_file_id"
+        }
+    }
+    """
+    for message in messages:
+        if message.get("role") == "user":
+            content = message.get("content")
+            if content:
+                if isinstance(content, str):
+                    continue
+                for c in content:
+                    if c["type"] == "file":
+                        file_object = cast(ChatCompletionFileObject, c)
+                        file_object_file_field = file_object["file"]
+                        file_id = file_object_file_field.get("file_id")
+                        if file_id:
+                            provider_file_id = (
+                                model_file_id_mapping.get(file_id, {}).get(model_id)
+                                or file_id
+                            )
+                            file_object_file_field["file_id"] = provider_file_id
+    return messages
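A small, self-contained illustration of the two helpers added above; the file id and model id are invented for the example:

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_file_ids_from_messages,
    update_messages_with_model_file_ids,
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Summarize this file."},
            {"type": "file", "file": {"file_id": "litellm_proxy/123"}},  # hypothetical unified id
        ],
    }
]

print(get_file_ids_from_messages(messages))  # ["litellm_proxy/123"]

updated = update_messages_with_model_file_ids(
    messages=messages,
    model_id="azure-gpt-4o-id",  # hypothetical deployment id
    model_file_id_mapping={"litellm_proxy/123": {"azure-gpt-4o-id": "file-abc123"}},
)
# The file content part now carries the provider-specific id "file-abc123".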
@@ -1,6 +1,6 @@
 import base64
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 
 from litellm.types.llms.openai import (
     ChatCompletionAssistantContentValue,
@@ -9,7 +9,9 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import (
     ChatCompletionAudioResponse,
     ChatCompletionMessageToolCall,
+    Choices,
     CompletionTokensDetails,
+    CompletionTokensDetailsWrapper,
     Function,
     FunctionCall,
     ModelResponse,
@@ -203,14 +205,14 @@ class ChunkProcessor:
         )
 
     def get_combined_content(
-        self, chunks: List[Dict[str, Any]]
+        self, chunks: List[Dict[str, Any]], delta_key: str = "content"
     ) -> ChatCompletionAssistantContentValue:
         content_list: List[str] = []
         for chunk in chunks:
             choices = chunk["choices"]
             for choice in choices:
                 delta = choice.get("delta", {})
-                content = delta.get("content", "")
+                content = delta.get(delta_key, "")
                 if content is None:
                     continue  # openai v1.0.0 sets content = None for chunks
                 content_list.append(content)
@@ -221,6 +223,11 @@ class ChunkProcessor:
         # Update the "content" field within the response dictionary
         return combined_content
 
+    def get_combined_reasoning_content(
+        self, chunks: List[Dict[str, Any]]
+    ) -> ChatCompletionAssistantContentValue:
+        return self.get_combined_content(chunks, delta_key="reasoning_content")
+
     def get_combined_audio_content(
         self, chunks: List[Dict[str, Any]]
     ) -> ChatCompletionAudioResponse:
@@ -296,12 +303,27 @@ class ChunkProcessor:
             "prompt_tokens_details": prompt_tokens_details,
         }
 
+    def count_reasoning_tokens(self, response: ModelResponse) -> int:
+        reasoning_tokens = 0
+        for choice in response.choices:
+            if (
+                hasattr(cast(Choices, choice).message, "reasoning_content")
+                and cast(Choices, choice).message.reasoning_content is not None
+            ):
+                reasoning_tokens += token_counter(
+                    text=cast(Choices, choice).message.reasoning_content,
+                    count_response_tokens=True,
+                )
+
+        return reasoning_tokens
+
     def calculate_usage(
         self,
         chunks: List[Union[Dict[str, Any], ModelResponse]],
         model: str,
         completion_output: str,
         messages: Optional[List] = None,
+        reasoning_tokens: Optional[int] = None,
     ) -> Usage:
         """
         Calculate usage for the given chunks.
@@ -382,6 +404,19 @@ class ChunkProcessor:
         )  # for anthropic
         if completion_tokens_details is not None:
             returned_usage.completion_tokens_details = completion_tokens_details
+
+        if reasoning_tokens is not None:
+            if returned_usage.completion_tokens_details is None:
+                returned_usage.completion_tokens_details = (
+                    CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
+                )
+            elif (
+                returned_usage.completion_tokens_details is not None
+                and returned_usage.completion_tokens_details.reasoning_tokens is None
+            ):
+                returned_usage.completion_tokens_details.reasoning_tokens = (
+                    reasoning_tokens
+                )
         if prompt_tokens_details is not None:
             returned_usage.prompt_tokens_details = prompt_tokens_details
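The new delta_key/count_reasoning_tokens path boils down to collecting the reasoning_content fragments out of each streamed delta before counting them. A standalone sketch of that aggregation using plain dicts (not the library's ChunkProcessor API):

from typing import Any, Dict, List

def combine_delta_field(chunks: List[Dict[str, Any]], delta_key: str) -> str:
    # Mirrors the aggregation logic: walk every choice delta and join the pieces.
    pieces: List[str] = []
    for chunk in chunks:
        for choice in chunk.get("choices", []):
            value = choice.get("delta", {}).get(delta_key)
            if value:
                pieces.append(value)
    return "".join(pieces)

chunks = [
    {"choices": [{"delta": {"reasoning_content": "Let me think"}}]},
    {"choices": [{"delta": {"reasoning_content": " step by step."}}]},
    {"choices": [{"delta": {"content": "The answer is 4."}}]},
]
print(combine_delta_field(chunks, "reasoning_content"))  # "Let me think step by step."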
@@ -21,7 +21,6 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.types.llms.anthropic import (
-    AnthropicChatCompletionUsageBlock,
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,
@@ -32,13 +31,13 @@ from litellm.types.llms.anthropic import (
 from litellm.types.llms.openai import (
     ChatCompletionThinkingBlock,
     ChatCompletionToolCallChunk,
-    ChatCompletionUsageBlock,
 )
 from litellm.types.utils import (
     Delta,
     GenericStreamingChunk,
     ModelResponseStream,
     StreamingChoices,
+    Usage,
 )
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
 
@@ -487,10 +486,8 @@ class ModelResponseIterator:
             return True
         return False
 
-    def _handle_usage(
-        self, anthropic_usage_chunk: Union[dict, UsageDelta]
-    ) -> AnthropicChatCompletionUsageBlock:
-        usage_block = AnthropicChatCompletionUsageBlock(
+    def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
+        usage_block = Usage(
             prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
             completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
             total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
@@ -581,7 +578,7 @@ class ModelResponseIterator:
         text = ""
         tool_use: Optional[ChatCompletionToolCallChunk] = None
         finish_reason = ""
-        usage: Optional[ChatCompletionUsageBlock] = None
+        usage: Optional[Usage] = None
         provider_specific_fields: Dict[str, Any] = {}
         reasoning_content: Optional[str] = None
         thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
@@ -33,9 +33,16 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionToolParam,
 )
+from litellm.types.utils import CompletionTokensDetailsWrapper
 from litellm.types.utils import Message as LitellmMessage
 from litellm.types.utils import PromptTokensDetailsWrapper
-from litellm.utils import ModelResponse, Usage, add_dummy_tool, has_tool_call_blocks
+from litellm.utils import (
+    ModelResponse,
+    Usage,
+    add_dummy_tool,
+    has_tool_call_blocks,
+    token_counter,
+)
 
 from ..common_utils import AnthropicError, process_anthropic_headers
 
@@ -772,6 +779,15 @@ class AnthropicConfig(BaseConfig):
         prompt_tokens_details = PromptTokensDetailsWrapper(
             cached_tokens=cache_read_input_tokens
         )
+        completion_token_details = (
+            CompletionTokensDetailsWrapper(
+                reasoning_tokens=token_counter(
+                    text=reasoning_content, count_response_tokens=True
+                )
+            )
+            if reasoning_content
+            else None
+        )
         total_tokens = prompt_tokens + completion_tokens
         usage = Usage(
             prompt_tokens=prompt_tokens,
@@ -780,6 +796,7 @@ class AnthropicConfig(BaseConfig):
             prompt_tokens_details=prompt_tokens_details,
             cache_creation_input_tokens=cache_creation_input_tokens,
             cache_read_input_tokens=cache_read_input_tokens,
+            completion_tokens_details=completion_token_details,
         )
 
         setattr(model_response, "usage", usage)  # type: ignore
@@ -28,11 +28,11 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
         self,
         create_file_data: CreateFileRequest,
         openai_client: AsyncAzureOpenAI,
-    ) -> FileObject:
+    ) -> OpenAIFileObject:
         verbose_logger.debug("create_file_data=%s", create_file_data)
         response = await openai_client.files.create(**create_file_data)
         verbose_logger.debug("create_file_response=%s", response)
-        return response
+        return OpenAIFileObject(**response.model_dump())
 
     def create_file(
         self,
@@ -66,7 +66,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
                 raise ValueError(
                     "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                 )
-            return self.acreate_file(  # type: ignore
+            return self.acreate_file(
                 create_file_data=create_file_data, openai_client=openai_client
             )
         response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
@@ -110,7 +110,10 @@ from .litellm_core_utils.fallback_utils import (
     async_completion_with_fallbacks,
     completion_with_fallbacks,
 )
-from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
+from .litellm_core_utils.prompt_templates.common_utils import (
+    get_completion_messages,
+    update_messages_with_model_file_ids,
+)
 from .litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
@@ -953,7 +956,6 @@ def completion(  # type: ignore # noqa: PLR0915
     non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
     ## PROMPT MANAGEMENT HOOKS ##
-
     if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
         (
             model,
@@ -1068,6 +1070,15 @@ def completion(  # type: ignore # noqa: PLR0915
         if eos_token:
             custom_prompt_dict[model]["eos_token"] = eos_token
 
+    if kwargs.get("model_file_id_mapping"):
+        messages = update_messages_with_model_file_ids(
+            messages=messages,
+            model_id=kwargs.get("model_info", {}).get("id", None),
+            model_file_id_mapping=cast(
+                Dict[str, Dict[str, str]], kwargs.get("model_file_id_mapping")
+            ),
+        )
+
     provider_config: Optional[BaseConfig] = None
     if custom_llm_provider is not None and custom_llm_provider in [
         provider.value for provider in LlmProviders
@@ -5799,6 +5810,19 @@ def stream_chunk_builder(  # noqa: PLR0915
                 "content"
             ] = processor.get_combined_content(content_chunks)
 
+        reasoning_chunks = [
+            chunk
+            for chunk in chunks
+            if len(chunk["choices"]) > 0
+            and "reasoning_content" in chunk["choices"][0]["delta"]
+            and chunk["choices"][0]["delta"]["reasoning_content"] is not None
+        ]
+
+        if len(reasoning_chunks) > 0:
+            response["choices"][0]["message"][
+                "reasoning_content"
+            ] = processor.get_combined_reasoning_content(reasoning_chunks)
+
     audio_chunks = [
         chunk
         for chunk in chunks
@@ -5813,11 +5837,14 @@ def stream_chunk_builder(  # noqa: PLR0915
 
     completion_output = get_content_from_model_response(response)
 
+    reasoning_tokens = processor.count_reasoning_tokens(response)
+
     usage = processor.calculate_usage(
         chunks=chunks,
        model=model,
        completion_output=completion_output,
        messages=messages,
+        reasoning_tokens=reasoning_tokens,
    )

    setattr(response, "usage", usage)
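Putting the stream_chunk_builder changes together: rebuild a response from collected chunks and read the new reasoning token count. A hedged sketch; the model name, the reasoning_effort parameter, and the presence of reasoning_content in its stream are assumptions, not claims about this commit:

import litellm

messages = [{"role": "user", "content": "What is 2 + 2?"}]

# Collect the raw streamed chunks (assumes ANTHROPIC_API_KEY is set and that the
# chosen model emits reasoning_content in its deltas).
chunks = list(
    litellm.completion(
        model="anthropic/claude-3-7-sonnet-20250219",
        messages=messages,
        stream=True,
        reasoning_effort="low",
    )
)

# stream_chunk_builder is the function patched above; with this commit it also
# attaches reasoning_tokens to usage.completion_tokens_details.
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
details = rebuilt.usage.completion_tokens_details
print(details.reasoning_tokens if details else None)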
@@ -1,18 +1,17 @@
 model_list:
-  - model_name: "gpt-4o"
+  - model_name: "gpt-4o-azure"
     litellm_params:
-      model: azure/chatgpt-v-2
+      model: azure/gpt-4o
       api_key: os.environ/AZURE_API_KEY
-      api_base: http://0.0.0.0:8090
-      rpm: 3
+      api_base: os.environ/AZURE_API_BASE
   - model_name: "gpt-4o-mini-openai"
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "openai/*"
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
@@ -2688,6 +2688,10 @@ class PrismaCompatibleUpdateDBModel(TypedDict, total=False):
     updated_by: str
 
 
+class SpecialEnums(enum.Enum):
+    LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy/"
+
+
 class SpecialManagementEndpointEnums(enum.Enum):
     DEFAULT_ORGANIZATION = "default_organization"
 
@@ -1 +1,36 @@
+from typing import Literal, Union
+
 from . import *
+from .cache_control_check import _PROXY_CacheControlCheck
+from .managed_files import _PROXY_LiteLLMManagedFiles
+from .max_budget_limiter import _PROXY_MaxBudgetLimiter
+from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
+
+# List of all available hooks that can be enabled
+PROXY_HOOKS = {
+    "max_budget_limiter": _PROXY_MaxBudgetLimiter,
+    "managed_files": _PROXY_LiteLLMManagedFiles,
+    "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler,
+    "cache_control_check": _PROXY_CacheControlCheck,
+}
+
+
+def get_proxy_hook(
+    hook_name: Union[
+        Literal[
+            "max_budget_limiter",
+            "managed_files",
+            "parallel_request_limiter",
+            "cache_control_check",
+        ],
+        str,
+    ]
+):
+    """
+    Factory method to get a proxy hook instance by name
+    """
+    if hook_name not in PROXY_HOOKS:
+        raise ValueError(
+            f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
+        )
+    return PROXY_HOOKS[hook_name]
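Usage sketch for the new hook registry (this mirrors what ProxyLogging._add_proxy_hooks does later in this commit; the cache wiring is only described in comments):

from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook

print(list(PROXY_HOOKS.keys()))
# ['max_budget_limiter', 'managed_files', 'parallel_request_limiter', 'cache_control_check']

hook_cls = get_proxy_hook("managed_files")
# managed_files expects the proxy's InternalUsageCache at construction time;
# hooks whose __init__ does not take internal_usage_cache are constructed bare.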
litellm/proxy/hooks/managed_files.py (new file, 145 additions)
@@ -0,0 +1,145 @@
+# What is this?
+## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
+
+import uuid
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
+
+from litellm import verbose_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_file_ids_from_messages,
+)
+from litellm.proxy._types import CallTypes, SpecialEnums, UserAPIKeyAuth
+from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+
+    Span = Union[_Span, Any]
+    InternalUsageCache = _InternalUsageCache
+else:
+    Span = Any
+    InternalUsageCache = Any
+
+
+class _PROXY_LiteLLMManagedFiles(CustomLogger):
+    # Class variables or attributes
+    def __init__(self, internal_usage_cache: InternalUsageCache):
+        self.internal_usage_cache = internal_usage_cache
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: Dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Union[Exception, str, Dict, None]:
+        """
+        - Detect litellm_proxy/ file_id
+        - add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
+        """
+        if call_type == CallTypes.completion.value:
+            messages = data.get("messages")
+            if messages:
+                file_ids = get_file_ids_from_messages(messages)
+                if file_ids:
+                    model_file_id_mapping = await self.get_model_file_id_mapping(
+                        file_ids, user_api_key_dict.parent_otel_span
+                    )
+                    data["model_file_id_mapping"] = model_file_id_mapping
+
+        return data
+
+    async def get_model_file_id_mapping(
+        self, file_ids: List[str], litellm_parent_otel_span: Span
+    ) -> dict:
+        """
+        Get model-specific file IDs for a list of proxy file IDs.
+        Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
+
+        1. Get all the litellm_proxy/ file_ids from the messages
+        2. For each file_id, search for cache keys matching the pattern file_id:*
+        3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
+
+        Example:
+        {
+            "litellm_proxy/file_id": {
+                "model_id": "model_file_id"
+            }
+        }
+        """
+        file_id_mapping: Dict[str, Dict[str, str]] = {}
+        litellm_managed_file_ids = []
+
+        for file_id in file_ids:
+            ## CHECK IF FILE ID IS MANAGED BY LITELM
+            if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+                litellm_managed_file_ids.append(file_id)
+
+        if litellm_managed_file_ids:
+            # Get all cache keys matching the pattern file_id:*
+            for file_id in litellm_managed_file_ids:
+                # Search for any cache key starting with this file_id
+                cached_values = cast(
+                    Dict[str, str],
+                    await self.internal_usage_cache.async_get_cache(
+                        key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
+                    ),
+                )
+                if cached_values:
+                    file_id_mapping[file_id] = cached_values
+        return file_id_mapping
+
+    @staticmethod
+    async def return_unified_file_id(
+        file_objects: List[OpenAIFileObject],
+        purpose: OpenAIFilesPurpose,
+        internal_usage_cache: InternalUsageCache,
+        litellm_parent_otel_span: Span,
+    ) -> OpenAIFileObject:
+        unified_file_id = SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value + str(
+            uuid.uuid4()
+        )
+
+        ## CREATE RESPONSE OBJECT
+        response = OpenAIFileObject(
+            id=unified_file_id,
+            object="file",
+            purpose=cast(OpenAIFilesPurpose, purpose),
+            created_at=file_objects[0].created_at,
+            bytes=1234,
+            filename=str(datetime.now().timestamp()),
+            status="uploaded",
+        )
+
+        ## STORE RESPONSE IN DB + CACHE
+        stored_values: Dict[str, str] = {}
+        for file_object in file_objects:
+            model_id = file_object._hidden_params.get("model_id")
+            if model_id is None:
+                verbose_logger.warning(
+                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
+                )
+                continue
+            file_id = file_object.id
+            stored_values[model_id] = file_id
+        await internal_usage_cache.async_set_cache(
+            key=unified_file_id,
+            value=stored_values,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+        )
+
+        return response
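For orientation, the cache layout this hook relies on (all ids below are invented for illustration): return_unified_file_id stores a per-deployment mapping under the unified id, and get_model_file_id_mapping reads it back into the shape consumed by update_messages_with_model_file_ids.

# Cache entry written by return_unified_file_id:
#   key   = the unified id handed back to the client, e.g. "litellm_proxy/<uuid>"
#   value = {model_id: provider_file_id} for each deployment the file was created on
cached_value = {
    "azure-gpt-3-5-turbo-id": "file-az-123",
    "gpt-3.5-turbo-id": "file-oai-456",
}

# get_model_file_id_mapping then assembles the mapping the pre-call hook attaches
# to the request as data["model_file_id_mapping"]:
model_file_id_mapping = {"litellm_proxy/example-uuid": cached_value}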
@@ -7,7 +7,7 @@
 
 import asyncio
 import traceback
-from typing import Optional
+from typing import Optional, cast, get_args
 
 import httpx
 from fastapi import (
@@ -31,7 +31,10 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
 from litellm.proxy.common_utils.openai_endpoint_utils import (
     get_custom_llm_provider_from_request_body,
 )
+from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
+from litellm.proxy.utils import ProxyLogging
 from litellm.router import Router
+from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
 
 router = APIRouter()
 
@@ -104,6 +107,53 @@ def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
     return is_in_list
 
 
+async def _deprecated_loadbalanced_create_file(
+    llm_router: Optional[Router],
+    router_model: str,
+    _create_file_request: CreateFileRequest,
+) -> OpenAIFileObject:
+    if llm_router is None:
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "LLM Router not initialized. Ensure models added to proxy."
+            },
+        )
+
+    response = await llm_router.acreate_file(model=router_model, **_create_file_request)
+    return response
+
+
+async def create_file_for_each_model(
+    llm_router: Optional[Router],
+    _create_file_request: CreateFileRequest,
+    target_model_names_list: List[str],
+    purpose: OpenAIFilesPurpose,
+    proxy_logging_obj: ProxyLogging,
+    user_api_key_dict: UserAPIKeyAuth,
+) -> OpenAIFileObject:
+    if llm_router is None:
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "LLM Router not initialized. Ensure models added to proxy."
+            },
+        )
+    responses = []
+    for model in target_model_names_list:
+        individual_response = await llm_router.acreate_file(
+            model=model, **_create_file_request
+        )
+        responses.append(individual_response)
+    response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
+        file_objects=responses,
+        purpose=purpose,
+        internal_usage_cache=proxy_logging_obj.internal_usage_cache,
+        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+    )
+    return response
+
+
 @router.post(
     "/{provider}/v1/files",
     dependencies=[Depends(user_api_key_auth)],
@@ -123,6 +173,7 @@ async def create_file(
     request: Request,
     fastapi_response: Response,
     purpose: str = Form(...),
+    target_model_names: str = Form(default=""),
     provider: Optional[str] = None,
     custom_llm_provider: str = Form(default="openai"),
     file: UploadFile = File(...),
@@ -162,8 +213,25 @@ async def create_file(
             or await get_custom_llm_provider_from_request_body(request=request)
             or "openai"
         )
+
+        target_model_names_list = (
+            target_model_names.split(",") if target_model_names else []
+        )
+        target_model_names_list = [model.strip() for model in target_model_names_list]
         # Prepare the data for forwarding
 
+        # Replace with:
+        valid_purposes = get_args(OpenAIFilesPurpose)
+        if purpose not in valid_purposes:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": f"Invalid purpose: {purpose}. Must be one of: {valid_purposes}",
+                },
+            )
+        # Cast purpose to OpenAIFilesPurpose type
+        purpose = cast(OpenAIFilesPurpose, purpose)
+
         data = {"purpose": purpose}
 
         # Include original request and headers in the data
@@ -192,21 +260,25 @@ async def create_file(
 
         _create_file_request = CreateFileRequest(file=file_data, **data)
 
+        response: Optional[OpenAIFileObject] = None
         if (
             litellm.enable_loadbalancing_on_batch_endpoints is True
             and is_router_model
             and router_model is not None
         ):
-            if llm_router is None:
-                raise HTTPException(
-                    status_code=500,
-                    detail={
-                        "error": "LLM Router not initialized. Ensure models added to proxy."
-                    },
-                )
-
-            response = await llm_router.acreate_file(
-                model=router_model, **_create_file_request
+            response = await _deprecated_loadbalanced_create_file(
+                llm_router=llm_router,
+                router_model=router_model,
+                _create_file_request=_create_file_request,
+            )
+        elif target_model_names_list:
+            response = await create_file_for_each_model(
+                llm_router=llm_router,
+                _create_file_request=_create_file_request,
+                target_model_names_list=target_model_names_list,
+                purpose=purpose,
+                proxy_logging_obj=proxy_logging_obj,
+                user_api_key_dict=user_api_key_dict,
             )
         else:
             # get configs for custom_llm_provider
@@ -220,6 +292,11 @@ async def create_file(
             # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
             response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider)  # type: ignore
 
+        if response is None:
+            raise HTTPException(
+                status_code=500,
+                detail={"error": "Failed to create file. Please try again."},
+            )
         ### ALERTING ###
         asyncio.create_task(
             proxy_logging_obj.update_request_status(
@@ -248,12 +325,11 @@ async def create_file(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
@@ -76,6 +76,7 @@ from litellm.proxy.db.create_views import (
 from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
 from litellm.proxy.db.log_db_metrics import log_db_metrics
 from litellm.proxy.db.prisma_client import PrismaWrapper
+from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.parallel_request_limiter import (
@@ -352,10 +353,19 @@ class ProxyLogging:
             self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
             self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
 
+    def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
+        for hook in PROXY_HOOKS:
+            proxy_hook = get_proxy_hook(hook)
+            import inspect
+
+            expected_args = inspect.getfullargspec(proxy_hook).args
+            if "internal_usage_cache" in expected_args:
+                litellm.logging_callback_manager.add_litellm_callback(proxy_hook(self.internal_usage_cache))  # type: ignore
+            else:
+                litellm.logging_callback_manager.add_litellm_callback(proxy_hook())  # type: ignore
+
     def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
-        litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter)  # type: ignore
-        litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter)  # type: ignore
-        litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check)  # type: ignore
+        self._add_proxy_hooks(llm_router)
         litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj)  # type: ignore
         for callback in litellm.callbacks:
             if isinstance(callback, str):
@@ -68,10 +68,7 @@ from litellm.router_utils.add_retry_fallback_headers import (
     add_fallback_headers_to_response,
     add_retry_headers_to_response,
 )
-from litellm.router_utils.batch_utils import (
-    _get_router_metadata_variable_name,
-    replace_model_in_jsonl,
-)
+from litellm.router_utils.batch_utils import _get_router_metadata_variable_name
 from litellm.router_utils.client_initalization_utils import InitalizeCachedClient
 from litellm.router_utils.clientside_credential_handler import (
     get_dynamic_litellm_params,
@@ -105,7 +102,12 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
     increment_deployment_successes_for_current_minute,
 )
 from litellm.scheduler import FlowItem, Scheduler
-from litellm.types.llms.openai import AllMessageValues, Batch, FileObject, FileTypes
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    Batch,
+    FileTypes,
+    OpenAIFileObject,
+)
 from litellm.types.router import (
     CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
     VALID_LITELLM_ENVIRONMENTS,
@@ -2703,7 +2705,7 @@ class Router:
         self,
         model: str,
         **kwargs,
-    ) -> FileObject:
+    ) -> OpenAIFileObject:
         try:
             kwargs["model"] = model
             kwargs["original_function"] = self._acreate_file
@@ -2727,7 +2729,7 @@ class Router:
         self,
         model: str,
         **kwargs,
-    ) -> FileObject:
+    ) -> OpenAIFileObject:
         try:
             verbose_router_logger.debug(
                 f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
@@ -2754,9 +2756,9 @@ class Router:
             stripped_model, custom_llm_provider, _, _ = get_llm_provider(
                 model=data["model"]
             )
-            kwargs["file"] = replace_model_in_jsonl(
-                file_content=kwargs["file"], new_model_name=stripped_model
-            )
+            # kwargs["file"] = replace_model_in_jsonl(
+            #     file_content=kwargs["file"], new_model_name=stripped_model
+            # )
 
             response = litellm.acreate_file(
                 **{
@@ -2796,6 +2798,7 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
             )
+
             return response  # type: ignore
         except Exception as e:
             verbose_router_logger.exception(
litellm/router_utils/common_utils.py (new file, 14 additions)
@@ -0,0 +1,14 @@
+import hashlib
+import json
+
+from litellm.types.router import CredentialLiteLLMParams
+
+
+def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
+    """
+    Hash of the credential params, used for mapping the file id to the right model
+    """
+    sensitive_params = CredentialLiteLLMParams(**litellm_params)
+    return hashlib.sha256(
+        json.dumps(sensitive_params.model_dump()).encode()
+    ).hexdigest()
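A hedged usage sketch of the new helper; only credential-bearing keys are passed here, since how CredentialLiteLLMParams treats unrelated litellm_params keys is not shown in this diff:

from litellm.router_utils.common_utils import (
    get_litellm_params_sensitive_credential_hash,
)

# Two deployments sharing the same credentials produce the same digest, which is
# what the docstring describes: mapping a file id to the right model.
h1 = get_litellm_params_sensitive_credential_hash(
    {"api_key": "sk-test", "api_base": "https://example.openai.azure.com"}
)
h2 = get_litellm_params_sensitive_credential_hash(
    {"api_key": "sk-test", "api_base": "https://example.openai.azure.com"}
)
assert h1 == h2  # same credentials -> same sha256 hex digest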
@@ -234,7 +234,18 @@ class Thread(BaseModel):
     """The object type, which is always `thread`."""
 
 
-OpenAICreateFileRequestOptionalParams = Literal["purpose",]
+OpenAICreateFileRequestOptionalParams = Literal["purpose"]
+
+OpenAIFilesPurpose = Literal[
+    "assistants",
+    "assistants_output",
+    "batch",
+    "batch_output",
+    "fine-tune",
+    "fine-tune-results",
+    "vision",
+    "user_data",
+]
 
 
 class OpenAIFileObject(BaseModel):
@@ -253,16 +264,7 @@ class OpenAIFileObject(BaseModel):
     object: Literal["file"]
     """The object type, which is always `file`."""
 
-    purpose: Literal[
-        "assistants",
-        "assistants_output",
-        "batch",
-        "batch_output",
-        "fine-tune",
-        "fine-tune-results",
-        "vision",
-        "user_data",
-    ]
+    purpose: OpenAIFilesPurpose
     """The intended purpose of the file.
 
     Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`,
@@ -286,6 +288,8 @@ class OpenAIFileObject(BaseModel):
     `error` field on `fine_tuning.job`.
     """
 
+    _hidden_params: dict = {}
+
 
 # OpenAI Files Types
 class CreateFileRequest(TypedDict, total=False):
@@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from ..exceptions import RateLimitError
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
+from .llms.openai import OpenAIFileObject
 from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 from .utils import ModelResponse, ProviderSpecificModelInfo
 
@@ -703,3 +704,12 @@ class GenericBudgetWindowDetails(BaseModel):
 
 
 OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
+
+
+class LiteLLM_RouterFileObject(TypedDict, total=False):
+    """
+    Tracking the litellm params hash, used for mapping the file id to the right model
+    """
+
+    litellm_params_sensitive_credential_hash: str
+    file_object: OpenAIFileObject
@@ -1886,6 +1886,7 @@ all_litellm_params = [
     "logger_fn",
     "verbose",
     "custom_llm_provider",
+    "model_file_id_mapping",
     "litellm_logging_obj",
     "litellm_call_id",
     "use_client",
@@ -17,6 +17,7 @@ import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.types.utils import (
+    CompletionTokensDetailsWrapper,
     Delta,
     ModelResponseStream,
     PromptTokensDetailsWrapper,
@@ -430,11 +431,18 @@ async def test_streaming_handler_with_usage(
         completion_tokens=392,
         prompt_tokens=1799,
         total_tokens=2191,
-        completion_tokens_details=None,
+        completion_tokens_details=CompletionTokensDetailsWrapper(  # <-- This has a value
+            accepted_prediction_tokens=None,
+            audio_tokens=None,
+            reasoning_tokens=0,
+            rejected_prediction_tokens=None,
+            text_tokens=None,
+        ),
         prompt_tokens_details=PromptTokensDetailsWrapper(
             audio_tokens=None, cached_tokens=1796, text_tokens=None, image_tokens=None
         ),
     )
 
     final_chunk = ModelResponseStream(
         id="chatcmpl-87291500-d8c5-428e-b187-36fe5a4c97ab",
         created=1742056047,
@@ -510,7 +518,13 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
         completion_tokens=392,
         prompt_tokens=1799,
         total_tokens=2191,
-        completion_tokens_details=None,
+        completion_tokens_details=CompletionTokensDetailsWrapper(
+            accepted_prediction_tokens=None,
+            audio_tokens=None,
+            reasoning_tokens=0,
+            rejected_prediction_tokens=None,
+            text_tokens=None,
+        ),
         prompt_tokens_details=PromptTokensDetailsWrapper(
             audio_tokens=None,
             cached_tokens=1796,
tests/litellm/proxy/openai_files_endpoint/test_files_endpoint.py (new file, 306 lines added)
@@ -0,0 +1,306 @@
import json
import os
import sys
from unittest.mock import ANY

import pytest
import respx
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm import Router
from litellm.proxy._types import LiteLLM_UserTableFiltered, UserAPIKeyAuth
from litellm.proxy.hooks import get_proxy_hook
from litellm.proxy.management_endpoints.internal_user_endpoints import ui_view_users
from litellm.proxy.proxy_server import app

client = TestClient(app)
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_server import hash_token
from litellm.proxy.utils import ProxyLogging


@pytest.fixture
def llm_router() -> Router:
    llm_router = Router(
        model_list=[
            {
                "model_name": "azure-gpt-3-5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": "azure_api_key",
                    "api_base": "azure_api_base",
                    "api_version": "azure_api_version",
                },
                "model_info": {
                    "id": "azure-gpt-3-5-turbo-id",
                },
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": "openai_api_key",
                },
                "model_info": {
                    "id": "gpt-3.5-turbo-id",
                },
            },
            {
                "model_name": "gemini-2.0-flash",
                "litellm_params": {
                    "model": "gemini/gemini-2.0-flash",
                },
                "model_info": {
                    "id": "gemini-2.0-flash-id",
                },
            },
        ]
    )
    return llm_router


def setup_proxy_logging_object(monkeypatch, llm_router: Router) -> ProxyLogging:
    proxy_logging_object = ProxyLogging(
        user_api_key_cache=DualCache(default_in_memory_ttl=1)
    )
    proxy_logging_object._add_proxy_hooks(llm_router)
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_object
    )
    return proxy_logging_object


def test_invalid_purpose(mocker: MockerFixture, monkeypatch, llm_router: Router):
    """
    Asserts 'create_file' is called with the correct arguments
    """
    # Create a simple test file content
    test_file_content = b"test audio content"
    test_file = ("test.wav", test_file_content, "audio/wav")

    response = client.post(
        "/v1/files",
        files={"file": test_file},
        data={
            "purpose": "my-bad-purpose",
            "target_model_names": ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"],
        },
        headers={"Authorization": "Bearer test-key"},
    )

    assert response.status_code == 400
    print(f"response: {response.json()}")
    assert "Invalid purpose: my-bad-purpose" in response.json()["error"]["message"]


def test_mock_create_audio_file(mocker: MockerFixture, monkeypatch, llm_router: Router):
    """
    Asserts 'create_file' is called with the correct arguments
    """
    from litellm import Router

    mock_create_file = mocker.patch("litellm.files.main.create_file")

    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)

    # Create a simple test file content
    test_file_content = b"test audio content"
    test_file = ("test.wav", test_file_content, "audio/wav")

    response = client.post(
        "/v1/files",
        files={"file": test_file},
        data={
            "purpose": "user_data",
            "target_model_names": "azure-gpt-3-5-turbo, gpt-3.5-turbo",
        },
        headers={"Authorization": "Bearer test-key"},
    )

    print(f"response: {response.text}")
    assert response.status_code == 200

    # Get all calls made to create_file
    calls = mock_create_file.call_args_list

    # Check for Azure call
    azure_call_found = False
    for call in calls:
        kwargs = call.kwargs
        if (
            kwargs.get("custom_llm_provider") == "azure"
            and kwargs.get("model") == "azure/chatgpt-v-2"
            and kwargs.get("api_key") == "azure_api_key"
        ):
            azure_call_found = True
            break
    assert (
        azure_call_found
    ), f"Azure call not found with expected parameters. Calls: {calls}"

    # Check for OpenAI call
    openai_call_found = False
    for call in calls:
        kwargs = call.kwargs
        if (
            kwargs.get("custom_llm_provider") == "openai"
            and kwargs.get("model") == "openai/gpt-3.5-turbo"
            and kwargs.get("api_key") == "openai_api_key"
        ):
            openai_call_found = True
            break
    assert openai_call_found, "OpenAI call not found with expected parameters"


@pytest.mark.skip(reason="mock respx fails on ci/cd - unclear why")
def test_create_file_and_call_chat_completion_e2e(
    mocker: MockerFixture, monkeypatch, llm_router: Router
):
    """
    1. Create a file
    2. Call a chat completion with the file
    3. Assert the file is used in the chat completion
    """
    # Create and enable respx mock instance
    mock = respx.mock()
    mock.start()
    try:
        from litellm.types.llms.openai import OpenAIFileObject

        monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)
        proxy_logging_object = setup_proxy_logging_object(monkeypatch, llm_router)

        # Create a simple test file content
        test_file_content = b"test audio content"
        test_file = ("test.wav", test_file_content, "audio/wav")

        # Mock the file creation response
        mock_file_response = OpenAIFileObject(
            id="test-file-id",
            object="file",
            bytes=123,
            created_at=1234567890,
            filename="test.wav",
            purpose="user_data",
            status="uploaded",
        )
        mock_file_response._hidden_params = {"model_id": "gemini-2.0-flash-id"}
        mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)

        # Mock the Gemini API call using respx
        mock_gemini_response = {
            "candidates": [
                {"content": {"parts": [{"text": "This is a test audio file"}]}}
            ]
        }

        # Mock the Gemini API endpoint with a more flexible pattern
        gemini_route = mock.post(
            url__regex=r".*generativelanguage\.googleapis\.com.*"
        ).mock(
            return_value=respx.MockResponse(status_code=200, json=mock_gemini_response),
        )

        # Print updated mock setup
        print("\nAfter Adding Gemini Route:")
        print("==========================")
        print(f"Number of mocked routes: {len(mock.routes)}")
        for route in mock.routes:
            print(f"Mocked Route: {route}")
            print(f"Pattern: {route.pattern}")

        ## CREATE FILE
        file = client.post(
            "/v1/files",
            files={"file": test_file},
            data={
                "purpose": "user_data",
                "target_model_names": "gemini-2.0-flash, gpt-3.5-turbo",
            },
            headers={"Authorization": "Bearer test-key"},
        )

        print("\nAfter File Creation:")
        print("====================")
        print(f"File creation status: {file.status_code}")
        print(f"Recorded calls so far: {len(mock.calls)}")
        for call in mock.calls:
            print(f"Call made to: {call.request.method} {call.request.url}")

        assert file.status_code == 200
        assert file.json()["id"] != "test-file-id"  # unified file id used

        ## USE FILE IN CHAT COMPLETION
        try:
            completion = client.post(
                "/v1/chat/completions",
                json={
                    "model": "gemini-2.0-flash",
                    "modalities": ["text", "audio"],
                    "audio": {"voice": "alloy", "format": "wav"},
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "What is in this recording?"},
                                {
                                    "type": "file",
                                    "file": {
                                        "file_id": file.json()["id"],
                                        "filename": "my-test-name",
                                        "format": "audio/wav",
                                    },
                                },
                            ],
                        },
                    ],
                    "drop_params": True,
                },
                headers={"Authorization": "Bearer test-key"},
            )
        except Exception as e:
            print(f"error: {e}")

        print("\nError occurred during chat completion:")
        print("=====================================")
        print("\nFinal Mock State:")
        print("=================")
        print(f"Total mocked routes: {len(mock.routes)}")
        for route in mock.routes:
            print(f"\nMocked Route: {route}")
            print(f"  Called: {route.called}")

        print("\nActual Requests Made:")
        print("=====================")
        print(f"Total calls recorded: {len(mock.calls)}")
        for idx, call in enumerate(mock.calls):
            print(f"\nCall {idx + 1}:")
            print(f"  Method: {call.request.method}")
            print(f"  URL: {call.request.url}")
            print(f"  Headers: {dict(call.request.headers)}")
            try:
                print(f"  Body: {call.request.content.decode()}")
            except:
                print("  Body: <could not decode>")

        # Verify Gemini API was called
        assert gemini_route.called, "Gemini API was not called"

        # Print the call details
        print("\nGemini API Call Details:")
        print(f"URL: {gemini_route.calls.last.request.url}")
        print(f"Method: {gemini_route.calls.last.request.method}")
        print(f"Headers: {dict(gemini_route.calls.last.request.headers)}")
        print(f"Content: {gemini_route.calls.last.request.content.decode()}")
        print(f"Response: {gemini_route.calls.last.response.content.decode()}")

        assert "test-file-id" in gemini_route.calls.last.request.content.decode()
    finally:
        # Stop the mock
        mock.stop()
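The new test file above exercises the unified file id flow end to end through the proxy's TestClient. As a rough client-side sketch of the same flow against a running proxy, the following uses the OpenAI SDK's generic extra_body passthrough to send target_model_names; the proxy URL, API key, and model names are assumptions for illustration, not values defined by this change:

    from openai import OpenAI

    # Hypothetical proxy URL and key - adjust for your deployment.
    client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

    # Upload once, targeting multiple deployments (comma-separated model names).
    created_file = client.files.create(
        file=open("test.wav", "rb"),
        purpose="user_data",
        extra_body={"target_model_names": "gemini-2.0-flash, gpt-3.5-turbo"},
    )

    # Reference the unified file id in a chat completion routed through the proxy,
    # mirroring the message shape used in the e2e test above.
    completion = client.chat.completions.create(
        model="gemini-2.0-flash",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this recording?"},
                    {"type": "file", "file": {"file_id": created_file.id, "format": "audio/wav"}},
                ],
            }
        ],
    )
    print(completion.choices[0].message.content)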
@@ -39,7 +39,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
     with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
         "litellm.proxy.proxy_server.store_model_in_db", False
     ):  # set store_model_in_db to False

         # Test when store_model_in_db is False
         await ProxyStartupEvent.initialize_scheduled_background_jobs(
             general_settings={},
@@ -57,7 +56,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
     with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
         "litellm.proxy.proxy_server.store_model_in_db", True
     ), patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True):

         await ProxyStartupEvent.initialize_scheduled_background_jobs(
             general_settings={},
             prisma_client=mock_prisma_client,
@@ -1116,3 +1116,6 @@ def test_anthropic_thinking_in_assistant_message(model):
     response = litellm.completion(**params)

     assert response is not None
@@ -34,6 +34,7 @@ export const prepareModelAddRequest = async (
     }

     // Create a deployment for each mapping
+    const deployments = [];
     for (const mapping of modelMappings) {
       const litellmParamsObj: Record<string, any> = {};
       const modelInfoObj: Record<string, any> = {};
@@ -142,8 +143,10 @@ export const prepareModelAddRequest = async (
         }
       }

-      return { litellmParamsObj, modelInfoObj, modelName };
+      deployments.push({ litellmParamsObj, modelInfoObj, modelName });
     }
+
+    return deployments;
   } catch (error) {
     message.error("Failed to create model: " + error, 10);
   }
@@ -156,22 +159,25 @@ export const handleAddModelSubmit = async (
   callback?: () => void,
 ) => {
   try {
-    const result = await prepareModelAddRequest(values, accessToken, form);
+    const deployments = await prepareModelAddRequest(values, accessToken, form);

-    if (!result) {
-      return; // Exit if preparation failed
+    if (!deployments || deployments.length === 0) {
+      return; // Exit if preparation failed or no deployments
     }

-    const { litellmParamsObj, modelInfoObj, modelName } = result;
+    // Create each deployment
+    for (const deployment of deployments) {
+      const { litellmParamsObj, modelInfoObj, modelName } = deployment;

       const new_model: Model = {
         model_name: modelName,
         litellm_params: litellmParamsObj,
         model_info: modelInfoObj,
       };

       const response: any = await modelCreateCall(accessToken, new_model);
       console.log(`response for model create call: ${response["data"]}`);
+    }

     callback && callback();
     form.resetFields();
@@ -55,7 +55,7 @@ const ModelConnectionTest: React.FC<ModelConnectionTestProps> = ({

       console.log("Result from prepareModelAddRequest:", result);

-      const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result;
+      const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result[0];

       const response = await testConnectionRequest(accessToken, litellmParamsObj, modelInfoObj?.mode);
       if (response.status === "success") {
@@ -13,7 +13,7 @@ import {
   TabPanel, TabPanels, DonutChart,
   Table, TableHead, TableRow,
   TableHeaderCell, TableBody, TableCell,
-  Subtitle
+  Subtitle, DateRangePicker, DateRangePickerValue
 } from "@tremor/react";
 import { AreaChart } from "@tremor/react";

@@ -41,6 +41,12 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
     metadata: any;
   }>({ results: [], metadata: {} });

+  // Add date range state
+  const [dateValue, setDateValue] = useState<DateRangePickerValue>({
+    from: new Date(Date.now() - 28 * 24 * 60 * 60 * 1000),
+    to: new Date(),
+  });
+
   // Derived states from userSpendData
   const totalSpend = userSpendData.metadata?.total_spend || 0;

@@ -168,22 +174,34 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
   };

   const fetchUserSpendData = async () => {
-    if (!accessToken) return;
-    const startTime = new Date(Date.now() - 28 * 24 * 60 * 60 * 1000);
-    const endTime = new Date();
+    if (!accessToken || !dateValue.from || !dateValue.to) return;
+    const startTime = dateValue.from;
+    const endTime = dateValue.to;
     const data = await userDailyActivityCall(accessToken, startTime, endTime);
     setUserSpendData(data);
   };

   useEffect(() => {
     fetchUserSpendData();
-  }, [accessToken]);
+  }, [accessToken, dateValue]);

   const modelMetrics = processActivityData(userSpendData);

   return (
     <div style={{ width: "100%" }} className="p-8">
       <Text>Experimental Usage page, using new `/user/daily/activity` endpoint.</Text>
+      <Grid numItems={2} className="gap-2 w-full mb-4">
+        <Col>
+          <Text>Select Time Range</Text>
+          <DateRangePicker
+            enableSelect={true}
+            value={dateValue}
+            onValueChange={(value) => {
+              setDateValue(value);
+            }}
+          />
+        </Col>
+      </Grid>
       <TabGroup>
         <TabList variant="solid" className="mt-1">
           <Tab>Cost</Tab>
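The hunks above only wire a user-selected date range into userDailyActivityCall, which hits the `/user/daily/activity` endpoint named in the page text. As a rough sketch of an equivalent direct call, the snippet below queries that endpoint over HTTP; the proxy URL, key, and the `start_date`/`end_date` parameter names and date format are assumptions for illustration, not taken from this diff:

    import datetime
    import requests

    # Hypothetical proxy URL and key - adjust for your deployment.
    PROXY_BASE = "http://localhost:4000"
    headers = {"Authorization": "Bearer sk-1234"}

    end = datetime.date.today()
    start = end - datetime.timedelta(days=28)  # same 28-day default as the UI

    resp = requests.get(
        f"{PROXY_BASE}/user/daily/activity",
        params={"start_date": start.isoformat(), "end_date": end.isoformat()},
        headers=headers,
    )
    data = resp.json()
    # Mirrors the UI's totalSpend derivation from userSpendData.metadata.total_spend.
    print(data.get("metadata", {}).get("total_spend"))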
@@ -175,6 +175,14 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
     try {
       if (!accessToken) return;

+      let parsedMetadata = {};
+      try {
+        parsedMetadata = values.metadata ? JSON.parse(values.metadata) : {};
+      } catch (e) {
+        message.error("Invalid JSON in metadata field");
+        return;
+      }
+
       const updateData = {
         team_id: teamId,
         team_alias: values.team_alias,
@@ -184,7 +192,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
         max_budget: values.max_budget,
         budget_duration: values.budget_duration,
         metadata: {
-          ...values.metadata,
+          ...parsedMetadata,
           guardrails: values.guardrails || []
         }
       };