forked from phoenix/litellm-mirror
Litellm dev 11 08 2024 (#6658)
* fix(deepseek/chat): convert content list to str Fixes https://github.com/BerriAI/litellm/issues/6642 * test(test_deepseek_completion.py): implement base llm unit tests increase robustness across providers * fix(router.py): support content policy violation fallbacks with default fallbacks * fix(opentelemetry.py): refactor to move otel imports behing flag Fixes https://github.com/BerriAI/litellm/issues/6636 * fix(opentelemtry.py): close span on success completion * fix(user_api_key_auth.py): allow user_role to default to none * fix: mark flaky test * fix(opentelemetry.py): move otelconfig.from_env to inside the init prevent otel errors raised just by importing the litellm class * fix(user_api_key_auth.py): fix auth error
This commit is contained in:
parent
1bef6457c7
commit
73531f4815
19 changed files with 287 additions and 34 deletions
|
@ -1045,6 +1045,7 @@ from .llms.AzureOpenAI.azure import (
|
|||
|
||||
from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
|
||||
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
|
||||
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
|
||||
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
|
||||
from .llms.perplexity.chat.transformation import PerplexityChatConfig
|
||||
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
|
||||
|
|
|
@ -16,6 +16,7 @@ from litellm.types.utils import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy._types import (
|
||||
|
@ -24,10 +25,12 @@ if TYPE_CHECKING:
|
|||
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
|
||||
|
||||
Span = _Span
|
||||
SpanExporter = _SpanExporter
|
||||
UserAPIKeyAuth = _UserAPIKeyAuth
|
||||
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
|
||||
else:
|
||||
Span = Any
|
||||
SpanExporter = Any
|
||||
UserAPIKeyAuth = Any
|
||||
ManagementEndpointLoggingPayload = Any
|
||||
|
||||
|
@ -44,7 +47,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
|
|||
|
||||
@dataclass
|
||||
class OpenTelemetryConfig:
|
||||
from opentelemetry.sdk.trace.export import SpanExporter
|
||||
|
||||
exporter: Union[str, SpanExporter] = "console"
|
||||
endpoint: Optional[str] = None
|
||||
|
@ -77,7 +79,7 @@ class OpenTelemetryConfig:
|
|||
class OpenTelemetry(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
|
||||
config: Optional[OpenTelemetryConfig] = None,
|
||||
callback_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -85,6 +87,9 @@ class OpenTelemetry(CustomLogger):
|
|||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
if config is None:
|
||||
config = OpenTelemetryConfig.from_env()
|
||||
|
||||
self.config = config
|
||||
self.OTEL_EXPORTER = self.config.exporter
|
||||
self.OTEL_ENDPOINT = self.config.endpoint
|
||||
|
@ -319,8 +324,8 @@ class OpenTelemetry(CustomLogger):
|
|||
|
||||
span.end(end_time=self._to_ns(end_time))
|
||||
|
||||
# if parent_otel_span is not None:
|
||||
# parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
if parent_otel_span is not None:
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
@ -700,10 +705,10 @@ class OpenTelemetry(CustomLogger):
|
|||
TraceContextTextMapPropagator,
|
||||
)
|
||||
|
||||
verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
|
||||
propagator = TraceContextTextMapPropagator()
|
||||
_parent_context = propagator.extract(carrier={"traceparent": _traceparent})
|
||||
verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
|
||||
carrier = {"traceparent": _traceparent}
|
||||
_parent_context = propagator.extract(carrier=carrier)
|
||||
|
||||
return _parent_context
|
||||
|
||||
def _get_span_context(self, kwargs):
|
||||
|
|
|
@ -3,7 +3,7 @@ Support for gpt model family
|
|||
"""
|
||||
|
||||
import types
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
|
@ -163,3 +163,8 @@ class OpenAIGPTConfig:
|
|||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
return messages
|
||||
|
|
|
@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
|
|||
return True
|
||||
return False
|
||||
|
||||
def o1_prompt_factory(self, messages: List[AllMessageValues]):
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Handles limitations of O-1 model family.
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
|
|
|
@ -15,6 +15,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import overload, override
|
||||
|
||||
import litellm
|
||||
from litellm import LlmProviders
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
@ -24,6 +25,7 @@ from litellm.utils import (
|
|||
CustomStreamWrapper,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ProviderConfigManager,
|
||||
TextCompletionResponse,
|
||||
Usage,
|
||||
convert_to_model_response_object,
|
||||
|
@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
messages=messages,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
if (
|
||||
litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
|
||||
and messages is not None
|
||||
):
|
||||
messages = litellm.openAIO1Config.o1_prompt_factory(
|
||||
messages=messages,
|
||||
if messages is not None and custom_llm_provider is not None:
|
||||
provider_config = ProviderConfigManager.get_provider_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider)
|
||||
)
|
||||
messages = provider_config._transform_messages(messages)
|
||||
|
||||
for _ in range(
|
||||
2
|
||||
|
|
41
litellm/llms/deepseek/chat/transformation.py
Normal file
41
litellm/llms/deepseek/chat/transformation.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from ...prompt_templates.common_utils import (
|
||||
handle_messages_with_content_list_to_str_conversion,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekChatConfig(OpenAIGPTConfig):
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
DeepSeek does not support content in list format.
|
||||
"""
|
||||
messages = handle_messages_with_content_list_to_str_conversion(messages)
|
||||
return super()._transform_messages(messages)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("DEEPSEEK_API_BASE")
|
||||
or "https://api.deepseek.com/beta"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
|
||||
return api_base, dynamic_api_key
|
|
@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
|
|||
)
|
||||
|
||||
|
||||
def handle_messages_with_content_list_to_str_conversion(
|
||||
messages: List[AllMessageValues],
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Handles messages with content list conversion
|
||||
"""
|
||||
for message in messages:
|
||||
texts = convert_content_list_to_str(message=message)
|
||||
if texts:
|
||||
message["content"] = texts
|
||||
return messages
|
||||
|
||||
|
||||
def convert_content_list_to_str(message: AllMessageValues) -> str:
|
||||
"""
|
||||
- handles scenario where content is list and not string
|
||||
|
|
|
@ -59,10 +59,10 @@ model_list:
|
|||
timeout: 300
|
||||
stream_timeout: 60
|
||||
|
||||
# litellm_settings:
|
||||
# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
|
||||
# callbacks: ["otel", "prometheus"]
|
||||
# default_redis_batch_cache_expiry: 10
|
||||
litellm_settings:
|
||||
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
|
||||
callbacks: ["otel", "prometheus"]
|
||||
default_redis_batch_cache_expiry: 10
|
||||
|
||||
# litellm_settings:
|
||||
# cache: True
|
||||
|
|
|
@ -703,12 +703,17 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if is_master_key_valid:
|
||||
_user_api_key_obj = UserAPIKeyAuth(
|
||||
api_key=master_key,
|
||||
_user_api_key_obj = _return_user_api_key_auth_obj(
|
||||
user_obj=None,
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_id=litellm_proxy_admin_name,
|
||||
api_key=master_key,
|
||||
parent_otel_span=parent_otel_span,
|
||||
valid_token_dict={
|
||||
**end_user_params,
|
||||
"user_id": litellm_proxy_admin_name,
|
||||
},
|
||||
route=route,
|
||||
start_time=start_time,
|
||||
)
|
||||
await _cache_key_object(
|
||||
hashed_token=hash_token(master_key),
|
||||
|
@ -1229,7 +1234,9 @@ def _return_user_api_key_auth_obj(
|
|||
valid_token_dict: dict,
|
||||
route: str,
|
||||
start_time: datetime,
|
||||
user_role: Optional[LitellmUserRoles] = None,
|
||||
) -> UserAPIKeyAuth:
|
||||
traceback.print_stack()
|
||||
end_time = datetime.now()
|
||||
user_api_key_service_logger_obj.service_success_hook(
|
||||
service=ServiceTypes.AUTH,
|
||||
|
@ -1240,7 +1247,7 @@ def _return_user_api_key_auth_obj(
|
|||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
retrieved_user_role = (
|
||||
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
|
||||
user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
|
||||
)
|
||||
|
||||
user_api_key_kwargs = {
|
||||
|
|
|
@ -3558,6 +3558,15 @@ class Router:
|
|||
# Catch all - if any exceptions default to cooling down
|
||||
return True
|
||||
|
||||
def _has_default_fallbacks(self) -> bool:
|
||||
if self.fallbacks is None:
|
||||
return False
|
||||
for fallback in self.fallbacks:
|
||||
if isinstance(fallback, dict):
|
||||
if "*" in fallback:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_raise_content_policy_error(
|
||||
self, model: str, response: ModelResponse, kwargs: dict
|
||||
) -> bool:
|
||||
|
@ -3574,6 +3583,7 @@ class Router:
|
|||
content_policy_fallbacks = kwargs.get(
|
||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||
)
|
||||
|
||||
### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
|
||||
if content_policy_fallbacks is not None:
|
||||
fallback_model_group = None
|
||||
|
@ -3584,6 +3594,8 @@ class Router:
|
|||
|
||||
if fallback_model_group is not None:
|
||||
return True
|
||||
elif self._has_default_fallbacks(): # default fallbacks set
|
||||
return True
|
||||
|
||||
verbose_router_logger.info(
|
||||
"Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
|
||||
|
|
|
@ -8252,3 +8252,22 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
|
|||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class ProviderConfigManager:
|
||||
@staticmethod
|
||||
def get_provider_config(
|
||||
model: str, provider: litellm.LlmProviders
|
||||
) -> OpenAIGPTConfig:
|
||||
"""
|
||||
Returns the provider config for a given provider.
|
||||
"""
|
||||
if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
|
||||
return litellm.OpenAIO1Config()
|
||||
elif litellm.LlmProviders.DEEPSEEK == provider:
|
||||
return litellm.DeepSeekChatConfig()
|
||||
|
||||
return OpenAIGPTConfig()
|
||||
|
|
46
tests/llm_translation/base_llm_unit_tests.py
Normal file
46
tests/llm_translation/base_llm_unit_tests.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseLLMChatTest(ABC):
|
||||
"""
|
||||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
"""Must return the base completion call args"""
|
||||
pass
|
||||
|
||||
def test_content_list_handling(self):
|
||||
"""Check if content list is supported by LLM API"""
|
||||
base_completion_call_args = self.get_base_completion_call_args()
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello, how are you?"}],
|
||||
}
|
||||
]
|
||||
response = litellm.completion(
|
||||
**base_completion_call_args,
|
||||
messages=messages,
|
||||
)
|
||||
assert response is not None
|
9
tests/llm_translation/test_deepseek_completion.py
Normal file
9
tests/llm_translation/test_deepseek_completion.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
# Test implementation
|
||||
class TestDeepSeekChatCompletion(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
}
|
|
@ -4526,6 +4526,7 @@ async def test_completion_ai21_chat():
|
|||
"stream",
|
||||
[False, True],
|
||||
)
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_completion_response_ratelimit_headers(model, stream):
|
||||
response = completion(
|
||||
model=model,
|
||||
|
|
41
tests/local_testing/test_opentelemetry_unit_tests.py
Normal file
41
tests/local_testing/test_opentelemetry_unit_tests.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
# What is this?
|
||||
## Unit tests for opentelemetry integration
|
||||
|
||||
# What is this?
|
||||
## Unit test for presidio pii masking
|
||||
import sys, os, asyncio, time, random
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_opentelemetry_integration():
|
||||
"""
|
||||
Unit test to confirm the parent otel span is ended
|
||||
"""
|
||||
|
||||
parent_otel_span = MagicMock()
|
||||
litellm.callbacks = ["otel"]
|
||||
|
||||
await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_response="Hey!",
|
||||
metadata={"litellm_parent_otel_span": parent_otel_span},
|
||||
)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
parent_otel_span.end.assert_called_once()
|
|
@ -72,6 +72,19 @@ def test_litellm_proxy_server_config_no_general_settings():
|
|||
# Check if the response is successful
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "I'm alive!"
|
||||
|
||||
# Test /chat/completions
|
||||
response = requests.post(
|
||||
"http://localhost:4000/chat/completions",
|
||||
headers={"Authorization": "Bearer 1234567890"},
|
||||
json={
|
||||
"model": "test_openai_models",
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
except ImportError:
|
||||
pytest.fail("Failed to import litellm.proxy_server")
|
||||
except requests.ConnectionError:
|
||||
|
|
|
@ -1120,9 +1120,10 @@ async def test_client_side_fallbacks_list(sync_mode):
|
|||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize("content_filter_response_exception", [True, False])
|
||||
@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_content_policy_fallbacks(
|
||||
sync_mode, content_filter_response_exception
|
||||
sync_mode, content_filter_response_exception, fallback_type
|
||||
):
|
||||
os.environ["LITELLM_LOG"] = "DEBUG"
|
||||
|
||||
|
@ -1152,6 +1153,14 @@ async def test_router_content_policy_fallbacks(
|
|||
"mock_response": "This works!",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "my-default-fallback-model",
|
||||
"litellm_params": {
|
||||
"model": "openai/my-fake-model",
|
||||
"api_key": "",
|
||||
"mock_response": "This works 2!",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "my-general-model",
|
||||
"litellm_params": {
|
||||
|
@ -1169,9 +1178,14 @@ async def test_router_content_policy_fallbacks(
|
|||
},
|
||||
},
|
||||
],
|
||||
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
|
||||
fallbacks=[{"claude-2": ["my-general-model"]}],
|
||||
context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
|
||||
content_policy_fallbacks=(
|
||||
[{"claude-2": ["my-fallback-model"]}]
|
||||
if fallback_type == "model-specific"
|
||||
else None
|
||||
),
|
||||
default_fallbacks=(
|
||||
["my-default-fallback-model"] if fallback_type == "default" else None
|
||||
),
|
||||
)
|
||||
|
||||
if sync_mode is True:
|
||||
|
|
|
@ -452,11 +452,17 @@ def test_update_usage(model_list):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"finish_reason, expected_error", [("content_filter", True), ("stop", False)]
|
||||
"finish_reason, expected_fallback", [("content_filter", True), ("stop", False)]
|
||||
)
|
||||
def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
|
||||
@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
|
||||
def test_should_raise_content_policy_error(
|
||||
model_list, finish_reason, expected_fallback, fallback_type
|
||||
):
|
||||
"""Test if the '_should_raise_content_policy_error' function is working correctly"""
|
||||
router = Router(model_list=model_list)
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
default_fallbacks=["gpt-4o"] if fallback_type == "default" else None,
|
||||
)
|
||||
|
||||
assert (
|
||||
router._should_raise_content_policy_error(
|
||||
|
@ -472,10 +478,14 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e
|
|||
usage={"total_tokens": 100},
|
||||
),
|
||||
kwargs={
|
||||
"content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
|
||||
"content_policy_fallbacks": (
|
||||
[{"gpt-3.5-turbo": "gpt-4o"}]
|
||||
if fallback_type == "model-specific"
|
||||
else None
|
||||
)
|
||||
},
|
||||
)
|
||||
is expected_error
|
||||
is expected_fallback
|
||||
)
|
||||
|
||||
|
||||
|
@ -1019,3 +1029,17 @@ async def test_pass_through_moderation_endpoint_factory(model_list):
|
|||
response = await router._pass_through_moderation_endpoint_factory(
|
||||
original_function=litellm.amoderation, input="this is valid good text"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_default_fallbacks, expected_result",
|
||||
[(True, True), (False, False)],
|
||||
)
|
||||
def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result):
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
default_fallbacks=(
|
||||
["my-default-fallback-model"] if has_default_fallbacks else None
|
||||
),
|
||||
)
|
||||
assert router._has_default_fallbacks() is expected_result
|
||||
|
|
|
@ -362,7 +362,7 @@ async def test_team_info():
|
|||
|
||||
try:
|
||||
await get_team_info(session=session, get_team=team_id, call_key=key)
|
||||
pytest.fail(f"Expected call to fail")
|
||||
pytest.fail("Expected call to fail")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue