Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00

Compare commits: main...v1.67.3.dev1
No commits in common. "main" and "v1.67.3.dev1" have entirely different histories.

21 changed files with 43 additions and 405 deletions
@@ -37,7 +37,6 @@ class SagemakerConfig(BaseConfig):
     """

     max_new_tokens: Optional[int] = None
-    max_completion_tokens: Optional[int] = None
     top_p: Optional[float] = None
     temperature: Optional[float] = None
     return_full_text: Optional[bool] = None
@@ -45,7 +44,6 @@ class SagemakerConfig(BaseConfig):
     def __init__(
         self,
         max_new_tokens: Optional[int] = None,
-        max_completion_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
         return_full_text: Optional[bool] = None,
@@ -67,7 +65,7 @@ class SagemakerConfig(BaseConfig):
         )

     def get_supported_openai_params(self, model: str) -> List:
-        return ["stream", "temperature", "max_tokens", "max_completion_tokens", "top_p", "stop", "n"]
+        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

     def map_openai_params(
         self,
@@ -104,8 +102,6 @@ class SagemakerConfig(BaseConfig):
                 if value == 0:
                     value = 1
                 optional_params["max_new_tokens"] = value
-            if param == "max_completion_tokens":
-                optional_params["max_new_tokens"] = value
         non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
         return optional_params

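The hunks above drop `max_completion_tokens` support from the SageMaker config in v1.67.3.dev1; on main it is mapped onto SageMaker's `max_new_tokens` and takes precedence over `max_tokens` (see the test changes further down). A minimal standalone sketch of that mapping, using a hypothetical helper name rather than litellm's actual API:

# Standalone sketch of the removed mapping; map_sagemaker_params is a made-up name.
def map_sagemaker_params(non_default_params: dict) -> dict:
    optional_params: dict = {}
    for param, value in non_default_params.items():
        if param == "temperature":
            optional_params["temperature"] = value
        if param == "max_tokens":
            # SageMaker TGI rejects max_new_tokens == 0, so clamp to 1
            optional_params["max_new_tokens"] = value if value != 0 else 1
        if param == "max_completion_tokens":
            # main-branch behavior: overrides whatever max_tokens set above
            optional_params["max_new_tokens"] = value
    return optional_params

print(map_sagemaker_params({"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}))
# -> {'temperature': 0.7, 'max_new_tokens': 256}
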
@@ -7058,17 +7058,6 @@
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
         "supports_tool_choice": true
     },
-    "command-a-03-2025": {
-        "max_tokens": 8000,
-        "max_input_tokens": 256000,
-        "max_output_tokens": 8000,
-        "input_cost_per_token": 0.0000025,
-        "output_cost_per_token": 0.00001,
-        "litellm_provider": "cohere_chat",
-        "mode": "chat",
-        "supports_function_calling": true,
-        "supports_tool_choice": true
-    },
     "command-r": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,

@@ -108,13 +108,7 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -184,13 +178,7 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional

 import orjson
 from fastapi import Request, UploadFile, status
@@ -147,10 +147,10 @@ def check_file_size_under_limit(

     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[Deployment] = (
-                llm_router.get_deployment_by_model_group_name(
-                    model_group_name=request_data["model"]
-                )
+            deployment: Optional[
+                Deployment
+            ] = llm_router.get_deployment_by_model_group_name(
+                model_group_name=request_data["model"]
             )
             if (
                 deployment
@@ -185,23 +185,3 @@ def check_file_size_under_limit(
                 )

     return True
-
-
-async def get_form_data(request: Request) -> Dict[str, Any]:
-    """
-    Read form data from request
-
-    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-    """
-    form = await request.form()
-    form_data = dict(form)
-    parsed_form_data: dict[str, Any] = {}
-    for key, value in form_data.items():
-
-        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-        if key.endswith("[]"):
-            clean_key = key[:-2]
-            parsed_form_data.setdefault(clean_key, []).append(value)
-        else:
-            parsed_form_data[key] = value
-    return parsed_form_data

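The `get_form_data` helper removed above collapses `key[]`-style multipart form fields (as sent by OpenAI SDKs) into lists. A standalone sketch of that parsing step, using a hypothetical helper name and a plain dict instead of Starlette's FormData:

from typing import Any, Dict

def parse_array_form_keys(form_data: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the loop in the removed get_form_data helper.
    parsed: Dict[str, Any] = {}
    for key, value in form_data.items():
        if key.endswith("[]"):
            # e.g. timestamp_granularities[]="word" becomes timestamp_granularities=["word"]
            parsed.setdefault(key[:-2], []).append(value)
        else:
            parsed[key] = value
    return parsed

print(parse_array_form_keys({"model": "gpt-4o-transcribe", "timestamp_granularities[]": "word"}))
# -> {'model': 'gpt-4o-transcribe', 'timestamp_granularities': ['word']}
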
@@ -1,8 +1,16 @@
 model_list:
-  - model_name: openai/*
+  - model_name: azure-computer-use-preview
     litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
+      model: azure/computer-use-preview
+      api_key: mock-api-key
+      api_version: mock-api-version
+      api_base: https://mock-endpoint.openai.azure.com
+  - model_name: azure-computer-use-preview
+    litellm_params:
+      model: azure/computer-use-preview-2
+      api_key: mock-api-key-2
+      api_version: mock-api-version-2
+      api_base: https://mock-endpoint-2.openai.azure.com

 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]

@@ -179,7 +179,6 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
-    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,
@@ -805,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1133,9 +1132,9 @@ async def update_cache(  # noqa: PLR0915
             _id = "team_id:{}".format(team_id)
             try:
                 # Fetch the existing cost for the given user
-                existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                    await user_api_key_cache.async_get_cache(key=_id)
-                )
+                existing_spend_obj: Optional[
+                    LiteLLM_TeamTable
+                ] = await user_api_key_cache.async_get_cache(key=_id)
                 if existing_spend_obj is None:
                     # do nothing if team not in api key cache
                     return
@@ -2807,9 +2806,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -4121,7 +4120,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await get_form_data(request)
+        form_data = await request.form()
         data = {key: value for key, value in form_data.items() if key != "file"}

         # Include original request and headers in the data
@@ -7759,9 +7758,9 @@ async def get_config_list(
                         hasattr(sub_field_info, "description")
                         and sub_field_info.description is not None
                     ):
-                        nested_fields[idx].field_description = (
-                            sub_field_info.description
-                        )
+                        nested_fields[
+                            idx
+                        ].field_description = sub_field_info.description
                     idx += 1

             _stored_in_db = None

@@ -106,50 +106,8 @@ async def get_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="aget_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response retrieval logic
+    pass


 @router.delete(
@@ -178,50 +136,8 @@ async def delete_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="adelete_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response deletion logic
+    pass


 @router.get(

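On main, these GET and DELETE response handlers are wired through ProxyBaseLLMRequestProcessing; in v1.67.3.dev1 they are TODO stubs. A client-side sketch of exercising them, mirroring the end-to-end test further down in this diff (the base URL and API key are placeholders for a locally running proxy):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# create, fetch, then delete a response via the proxy
response = client.responses.create(model="azure-computer-use-preview", input="ping")
fetched = client.responses.retrieve(response.id)  # served by the GET handler above
client.responses.delete(response.id)              # served by the DELETE handler above
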
@@ -47,8 +47,6 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
-        "aget_responses",
-        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):

@@ -176,16 +176,6 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )

-    @staticmethod
-    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
-        """Get the model_id from the response_id"""
-        if response_id is None:
-            return None
-        decoded_response_id = (
-            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
-        )
-        return decoded_response_id.get("model_id") or None
-

 class ResponseAPILoggingUtils:
     @staticmethod

@@ -739,12 +739,6 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
-        self.aget_responses = self.factory_function(
-            litellm.aget_responses, call_type="aget_responses"
-        )
-        self.adelete_responses = self.factory_function(
-            litellm.adelete_responses, call_type="adelete_responses"
-        )

     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3087,8 +3081,6 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
-            "aget_responses",
-            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3143,11 +3135,6 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
-            elif call_type in ("aget_responses", "adelete_responses"):
-                return await self._init_responses_api_endpoints(
-                    original_function=original_function,
-                    **kwargs,
-                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3158,28 +3145,6 @@ class Router:

         return async_wrapper

-    async def _init_responses_api_endpoints(
-        self,
-        original_function: Callable,
-        **kwargs,
-    ):
-        """
-        Initialize the Responses API endpoints on the router.
-
-        GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
-        """
-        from litellm.responses.utils import ResponsesAPIRequestUtils
-
-        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
-            kwargs.get("response_id")
-        )
-        if model_id is not None:
-            kwargs["model"] = model_id
-        return await self._ageneric_api_call_with_fallbacks(
-            original_function=original_function,
-            **kwargs,
-        )
-
     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,

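The removed _init_responses_api_endpoints decodes the deployment's model_id out of the incoming response_id and injects it as kwargs["model"] before delegating to _ageneric_api_call_with_fallbacks, so GET/DELETE calls land on the deployment that produced the response. A sketch of that kwargs-rewriting pattern in isolation (the decode step is passed in as a callable, since the response_id encoding scheme is not shown in this diff):

from typing import Any, Awaitable, Callable, Optional

async def route_by_response_id(
    call_with_fallbacks: Callable[..., Awaitable[Any]],
    decode_model_id: Callable[[Optional[str]], Optional[str]],
    **kwargs: Any,
) -> Any:
    # decode_model_id stands in for ResponsesAPIRequestUtils.get_model_id_from_response_id
    model_id = decode_model_id(kwargs.get("response_id"))
    if model_id is not None:
        # pin the request to the deployment that produced the response
        kwargs["model"] = model_id
    return await call_with_fallbacks(**kwargs)
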
@@ -7058,17 +7058,6 @@
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
         "supports_tool_choice": true
     },
-    "command-a-03-2025": {
-        "max_tokens": 8000,
-        "max_input_tokens": 256000,
-        "max_output_tokens": 8000,
-        "input_cost_per_token": 0.0000025,
-        "output_cost_per_token": 0.00001,
-        "litellm_provider": "cohere_chat",
-        "mode": "chat",
-        "supports_function_calling": true,
-        "supports_tool_choice": true
-    },
     "command-r": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,

@@ -8,7 +8,7 @@ import pytest

 sys.path.insert(0, os.path.abspath("../../../../.."))
 from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
-from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

 @pytest.mark.asyncio
 async def test_aiter_bytes_unicode_decode_error():
@@ -95,37 +95,3 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
     # Verify we got our valid chunk despite the subsequent error
     assert len(chunks) == 1
     assert chunks[0]["text"] == "hello"  # Verify the content of the valid chunk
-
-class TestSagemakerTransform:
-    def setup_method(self):
-        self.config = SagemakerConfig()
-        self.model = "test"
-        self.logging_obj = MagicMock()
-
-    def test_map_mistral_params(self):
-        """Test that parameters are correctly mapped"""
-        test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
-
-        result = self.config.map_openai_params(
-            non_default_params=test_params,
-            optional_params={},
-            model=self.model,
-            drop_params=False,
-        )
-
-        # The function should properly map max_completion_tokens to max_tokens and override max_tokens
-        assert result == {"temperature": 0.7, "max_new_tokens": 256}
-
-    def test_mistral_max_tokens_backward_compat(self):
-        """Test that parameters are correctly mapped"""
-        test_params = {"temperature": 0.7, "max_tokens": 200,}
-
-        result = self.config.map_openai_params(
-            non_default_params=test_params,
-            optional_params={},
-            model=self.model,
-            drop_params=False,
-        )
-
-        # The function should properly map max_tokens if max_completion_tokens is not provided
-        assert result == {"temperature": 0.7, "max_new_tokens": 200}

@@ -18,7 +18,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
-    get_form_data,
 )


@@ -148,53 +147,3 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
-
-
-@pytest.mark.asyncio
-async def test_get_form_data():
-    """
-    Test that get_form_data correctly handles form data with array notation.
-    Tests audio transcription parameters as a specific example.
-    """
-    # Create a mock request with transcription form data
-    mock_request = MagicMock()
-
-    # Create mock form data with array notation for timestamp_granularities
-    mock_form_data = {
-        "file": "file_object",  # In a real request this would be an UploadFile
-        "model": "gpt-4o-transcribe",
-        "include[]": "logprobs",  # Array notation
-        "language": "en",
-        "prompt": "Transcribe this audio file",
-        "response_format": "json",
-        "stream": "false",
-        "temperature": "0.2",
-        "timestamp_granularities[]": "word",  # First array item
-        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
-    }
-
-    # Mock the form method to return the test data
-    mock_request.form = AsyncMock(return_value=mock_form_data)
-
-    # Call the function being tested
-    result = await get_form_data(mock_request)
-
-    # Verify regular form fields are preserved
-    assert result["file"] == "file_object"
-    assert result["model"] == "gpt-4o-transcribe"
-    assert result["language"] == "en"
-    assert result["prompt"] == "Transcribe this audio file"
-    assert result["response_format"] == "json"
-    assert result["stream"] == "false"
-    assert result["temperature"] == "0.2"
-
-    # Verify array fields are correctly parsed
-    assert "include" in result
-    assert isinstance(result["include"], list)
-    assert "logprobs" in result["include"]
-
-    assert "timestamp_granularities" in result
-    assert isinstance(result["timestamp_granularities"], list)
-    # Note: In a real MultiDict, both values would be present
-    # But in our mock dictionary the second value overwrites the first
-    assert "segment" in result["timestamp_granularities"]

@@ -947,8 +947,6 @@ class BaseLLMChatTest(ABC):
                 second_response.choices[0].message.content is not None
                 or second_response.choices[0].message.tool_calls is not None
             )
-        except litellm.ServiceUnavailableError:
-            pytest.skip("Model is overloaded")
         except litellm.InternalServerError:
             pytest.skip("Model is overloaded")
         except litellm.RateLimitError:

@@ -77,22 +77,6 @@ def test_basic_response():
     )
     print("basic response=", response)
-
-    # get the response
-    response = client.responses.retrieve(response.id)
-    print("GET response=", response)
-
-
-    # delete the response
-    delete_response = client.responses.delete(response.id)
-    print("DELETE response=", delete_response)
-
-    # try getting the response again, we should not get it back
-    get_response = client.responses.retrieve(response.id)
-    print("GET response after delete=", get_response)
-
-    with pytest.raises(Exception):
-        get_response = client.responses.retrieve(response.id)


 def test_streaming_response():
     client = get_test_client()

@@ -6,7 +6,7 @@ from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
-from unittest.mock import AsyncMock, patch, MagicMock
+from unittest.mock import AsyncMock, patch

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -553,67 +553,9 @@ def test_initialize_router_endpoints():
     assert hasattr(router, "aanthropic_messages")
     assert hasattr(router, "aresponses")
     assert hasattr(router, "responses")
-    assert hasattr(router, "aget_responses")
-    assert hasattr(router, "adelete_responses")
     # Verify the endpoints are callable
     assert callable(router.amoderation)
     assert callable(router.aanthropic_messages)
     assert callable(router.aresponses)
     assert callable(router.responses)
-    assert callable(router.aget_responses)
-    assert callable(router.adelete_responses)
-
-
-@pytest.mark.asyncio
-async def test_init_responses_api_endpoints():
-    """
-    A simpler test for _init_responses_api_endpoints that focuses on the basic functionality
-    """
-    from litellm.responses.utils import ResponsesAPIRequestUtils
-    # Create a router with a basic model
-    router = Router(
-        model_list=[
-            {
-                "model_name": "test-model",
-                "litellm_params": {
-                    "model": "openai/test-model",
-                    "api_key": "fake-api-key",
-                },
-            }
-        ]
-    )
-
-    # Just mock the _ageneric_api_call_with_fallbacks method
-    router._ageneric_api_call_with_fallbacks = AsyncMock()
-
-    # Add a mock implementation of _get_model_id_from_response_id to the Router instance
-    ResponsesAPIRequestUtils.get_model_id_from_response_id = MagicMock(return_value=None)
-
-    # Call without a response_id (no model extraction should happen)
-    await router._init_responses_api_endpoints(
-        original_function=AsyncMock(),
-        thread_id="thread_xyz"
-    )
-
-    # Verify _ageneric_api_call_with_fallbacks was called but model wasn't changed
-    first_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
-    assert "model" not in first_call_kwargs
-    assert first_call_kwargs["thread_id"] == "thread_xyz"
-
-    # Reset the mock
-    router._ageneric_api_call_with_fallbacks.reset_mock()
-
-    # Change the return value for the second call
-    ResponsesAPIRequestUtils.get_model_id_from_response_id.return_value = "claude-3-sonnet"
-
-    # Call with a response_id
-    await router._init_responses_api_endpoints(
-        original_function=AsyncMock(),
-        response_id="resp_claude_123"
-    )
-
-    # Verify model was updated in the kwargs
-    second_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
-    assert second_call_kwargs["model"] == "claude-3-sonnet"
-    assert second_call_kwargs["response_id"] == "resp_claude_123"

@@ -1157,14 +1157,3 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
-
-
-def test_init_responses_api_endpoints(model_list):
-    """Test if the '_init_responses_api_endpoints' function is working correctly"""
-    from typing import Callable
-    router = Router(model_list=model_list)
-
-    assert router.aget_responses is not None
-    assert isinstance(router.aget_responses, Callable)
-    assert router.adelete_responses is not None
-    assert isinstance(router.adelete_responses, Callable)

@@ -151,7 +151,7 @@ export default function CreateKeyPage() {
     if (redirectToLogin) {
       window.location.href = (proxyBaseUrl || "") + "/sso/key/generate"
     }
-  }, [redirectToLogin])
+  }, [token, authLoading])

   useEffect(() => {
     if (!token) {
@@ -223,7 +223,7 @@ export default function CreateKeyPage() {
     }
   }, [accessToken, userID, userRole]);

-  if (authLoading || redirectToLogin) {
+  if (authLoading) {
     return <LoadingScreen />
   }

@@ -450,8 +450,6 @@ export function AllKeysTable({
           columns={columns.filter(col => col.id !== 'expander') as any}
           data={filteredKeys as any}
           isLoading={isLoading}
-          loadingMessage="🚅 Loading keys..."
-          noDataMessage="No keys found"
           getRowCanExpand={() => false}
           renderSubComponent={() => <></>}
         />

@@ -26,8 +26,6 @@ function DataTableWrapper({
       isLoading={isLoading}
       renderSubComponent={renderSubComponent}
       getRowCanExpand={getRowCanExpand}
-      loadingMessage="🚅 Loading tools..."
-      noDataMessage="No tools found"
     />
   );
 }

@@ -26,8 +26,6 @@ interface DataTableProps<TData, TValue> {
   expandedRequestId?: string | null;
   onRowExpand?: (requestId: string | null) => void;
   setSelectedKeyIdInfoView?: (keyId: string | null) => void;
-  loadingMessage?: string;
-  noDataMessage?: string;
 }

 export function DataTable<TData extends { request_id: string }, TValue>({
@@ -38,8 +36,6 @@ export function DataTable<TData extends { request_id: string }, TValue>({
   isLoading = false,
   expandedRequestId,
   onRowExpand,
-  loadingMessage = "🚅 Loading logs...",
-  noDataMessage = "No logs found",
 }: DataTableProps<TData, TValue>) {
   const table = useReactTable({
     data,
@@ -118,7 +114,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             <TableRow>
               <TableCell colSpan={columns.length} className="h-8 text-center">
                 <div className="text-center text-gray-500">
-                  <p>{loadingMessage}</p>
+                  <p>🚅 Loading logs...</p>
                 </div>
               </TableCell>
             </TableRow>
@@ -151,7 +147,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             : <TableRow>
               <TableCell colSpan={columns.length} className="h-8 text-center">
                 <div className="text-center text-gray-500">
-                  <p>{noDataMessage}</p>
+                  <p>No logs found</p>
                 </div>
               </TableCell>
             </TableRow>