forked from phoenix/litellm-mirror
Litellm dev 11 08 2024 (#6658)
* fix(deepseek/chat): convert content list to str. Fixes https://github.com/BerriAI/litellm/issues/6642
* test(test_deepseek_completion.py): implement base llm unit tests to increase robustness across providers
* fix(router.py): support content policy violation fallbacks with default fallbacks
* fix(opentelemetry.py): refactor to move otel imports behind a flag. Fixes https://github.com/BerriAI/litellm/issues/6636
* fix(opentelemetry.py): close span on successful completion
* fix(user_api_key_auth.py): allow user_role to default to None
* fix: mark flaky test
* fix(opentelemetry.py): move OpenTelemetryConfig.from_env() inside __init__; prevents otel errors raised just by importing the litellm class
* fix(user_api_key_auth.py): fix auth error
parent 1bef6457c7
commit 73531f4815
19 changed files with 287 additions and 34 deletions
@@ -1045,6 +1045,7 @@ from .llms.AzureOpenAI.azure import (
 from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
 from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
@@ -16,6 +16,7 @@ from litellm.types.utils import (
 )

 if TYPE_CHECKING:
+    from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
     from opentelemetry.trace import Span as _Span

     from litellm.proxy._types import (
@@ -24,10 +25,12 @@ if TYPE_CHECKING:
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth

     Span = _Span
+    SpanExporter = _SpanExporter
     UserAPIKeyAuth = _UserAPIKeyAuth
     ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
 else:
     Span = Any
+    SpanExporter = Any
     UserAPIKeyAuth = Any
     ManagementEndpointLoggingPayload = Any

@@ -44,7 +47,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"

 @dataclass
 class OpenTelemetryConfig:
-    from opentelemetry.sdk.trace.export import SpanExporter

     exporter: Union[str, SpanExporter] = "console"
     endpoint: Optional[str] = None
@@ -77,7 +79,7 @@ class OpenTelemetryConfig:
 class OpenTelemetry(CustomLogger):
     def __init__(
         self,
-        config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
+        config: Optional[OpenTelemetryConfig] = None,
         callback_name: Optional[str] = None,
         **kwargs,
     ):
@@ -85,6 +87,9 @@ class OpenTelemetry(CustomLogger):
         from opentelemetry.sdk.resources import Resource
         from opentelemetry.sdk.trace import TracerProvider

+        if config is None:
+            config = OpenTelemetryConfig.from_env()
+
         self.config = config
         self.OTEL_EXPORTER = self.config.exporter
         self.OTEL_ENDPOINT = self.config.endpoint
@@ -319,8 +324,8 @@ class OpenTelemetry(CustomLogger):

         span.end(end_time=self._to_ns(end_time))

-        # if parent_otel_span is not None:
-        #     parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+        if parent_otel_span is not None:
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))

     def _handle_failure(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode
@@ -700,10 +705,10 @@ class OpenTelemetry(CustomLogger):
             TraceContextTextMapPropagator,
         )

-        verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
         propagator = TraceContextTextMapPropagator()
-        _parent_context = propagator.extract(carrier={"traceparent": _traceparent})
-        verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
+        carrier = {"traceparent": _traceparent}
+        _parent_context = propagator.extract(carrier=carrier)

         return _parent_context

     def _get_span_context(self, kwargs):
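Note: with the refactor above, `OpenTelemetryConfig.from_env()` is no longer evaluated as a default argument, so importing litellm does not require any OTEL environment variables. A minimal usage sketch, assuming the `litellm.integrations.opentelemetry` module path:

# Sketch only: assumes the refactored __init__ above and the
# litellm.integrations.opentelemetry module path.
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig

# Explicit config: no environment variables are read.
otel_logger = OpenTelemetry(
    config=OpenTelemetryConfig(exporter="console", endpoint=None)
)

# No config: OpenTelemetryConfig.from_env() now runs inside __init__ rather
# than as an import-time default argument, so a missing or bad OTEL_* env var
# only surfaces when the logger is actually constructed.
lazy_logger = OpenTelemetry()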
@@ -3,7 +3,7 @@ Support for gpt model family
 """

 import types
-from typing import Optional, Union
+from typing import List, Optional, Union

 import litellm
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
@@ -163,3 +163,8 @@ class OpenAIGPTConfig:
             model=model,
             drop_params=drop_params,
         )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages
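The no-op `_transform_messages` added to `OpenAIGPTConfig` is the hook the rest of this commit builds on. A minimal sketch of a provider config overriding it (the class name below is hypothetical, not part of this commit):

# Hypothetical provider config built on the new hook (sketch, not litellm code).
from typing import List

from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.openai import AllMessageValues


class MyProviderChatConfig(OpenAIGPTConfig):  # hypothetical example class
    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        # Adjust messages for provider-specific quirks here, then defer to the
        # base class, which simply returns the messages unchanged.
        return super()._transform_messages(messages)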
@@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
             return True
         return False

-    def o1_prompt_factory(self, messages: List[AllMessageValues]):
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
         """
         Handles limitations of O-1 model family.
         - modalities: image => drop param (if user opts in to dropping param)
@@ -15,6 +15,7 @@ from pydantic import BaseModel
 from typing_extensions import overload, override

 import litellm
+from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.secret_managers.main import get_secret_str
@@ -24,6 +25,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     Message,
     ModelResponse,
+    ProviderConfigManager,
     TextCompletionResponse,
     Usage,
     convert_to_model_response_object,
@@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
                 messages=messages,
                 custom_llm_provider=custom_llm_provider,
             )
-            if (
-                litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
-                and messages is not None
-            ):
-                messages = litellm.openAIO1Config.o1_prompt_factory(
-                    messages=messages,
-                )
+            if messages is not None and custom_llm_provider is not None:
+                provider_config = ProviderConfigManager.get_provider_config(
+                    model=model, provider=LlmProviders(custom_llm_provider)
+                )
+                messages = provider_config._transform_messages(messages)

             for _ in range(
                 2
litellm/llms/deepseek/chat/transformation.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
+"""
+
+import types
+from typing import List, Optional, Tuple, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+from ...prompt_templates.common_utils import (
+    handle_messages_with_content_list_to_str_conversion,
+)
+
+
+class DeepSeekChatConfig(OpenAIGPTConfig):
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """
+        DeepSeek does not support content in list format.
+        """
+        messages = handle_messages_with_content_list_to_str_conversion(messages)
+        return super()._transform_messages(messages)
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or get_secret_str("DEEPSEEK_API_BASE")
+            or "https://api.deepseek.com/beta"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
+        return api_base, dynamic_api_key
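A usage sketch for the new DeepSeek config. The env var names and default base URL come from the diff above; the API key value is a placeholder:

# Sketch: exercising the two methods added above.
import os

from litellm.llms.deepseek.chat.transformation import DeepSeekChatConfig

config = DeepSeekChatConfig()

# Content lists are flattened to plain strings before the request is built.
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
messages = config._transform_messages(messages)
# messages[0]["content"] is now the plain string "Hello!"

# Credentials: explicit args win, then DEEPSEEK_API_BASE / DEEPSEEK_API_KEY,
# then the default beta base URL.
os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"  # placeholder value
api_base, api_key = config._get_openai_compatible_provider_info(
    api_base=None, api_key=None
)
# api_base == "https://api.deepseek.com/beta"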
@@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
 )


+def handle_messages_with_content_list_to_str_conversion(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Handles messages with content list conversion
+    """
+    for message in messages:
+        texts = convert_content_list_to_str(message=message)
+        if texts:
+            message["content"] = texts
+    return messages
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string
@@ -59,10 +59,10 @@ model_list:
       timeout: 300
       stream_timeout: 60

-# litellm_settings:
-#   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
-#   callbacks: ["otel", "prometheus"]
-#   default_redis_batch_cache_expiry: 10
+litellm_settings:
+  fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
+  callbacks: ["otel", "prometheus"]
+  default_redis_batch_cache_expiry: 10

 # litellm_settings:
 #   cache: True
@@ -703,12 +703,17 @@ async def user_api_key_auth(  # noqa: PLR0915
         )

         if is_master_key_valid:
-            _user_api_key_obj = UserAPIKeyAuth(
-                api_key=master_key,
+            _user_api_key_obj = _return_user_api_key_auth_obj(
+                user_obj=None,
                 user_role=LitellmUserRoles.PROXY_ADMIN,
-                user_id=litellm_proxy_admin_name,
+                api_key=master_key,
                 parent_otel_span=parent_otel_span,
-                **end_user_params,
+                valid_token_dict={
+                    **end_user_params,
+                    "user_id": litellm_proxy_admin_name,
+                },
+                route=route,
+                start_time=start_time,
             )
             await _cache_key_object(
                 hashed_token=hash_token(master_key),
@@ -1229,7 +1234,9 @@ def _return_user_api_key_auth_obj(
     valid_token_dict: dict,
     route: str,
     start_time: datetime,
+    user_role: Optional[LitellmUserRoles] = None,
 ) -> UserAPIKeyAuth:
+    traceback.print_stack()
     end_time = datetime.now()
     user_api_key_service_logger_obj.service_success_hook(
         service=ServiceTypes.AUTH,
@@ -1240,7 +1247,7 @@ def _return_user_api_key_auth_obj(
         parent_otel_span=parent_otel_span,
     )
     retrieved_user_role = (
-        _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
+        user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
     )

     user_api_key_kwargs = {
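The auth change above routes the master key through `_return_user_api_key_auth_obj` and lets an explicit `user_role` take precedence over the role derived from the user object. The precedence rule in isolation, as a self-contained sketch with a stand-in enum (not the proxy's own types):

# Self-contained sketch of the role-precedence change (stand-in enum, not the
# proxy's LitellmUserRoles).
from enum import Enum
from typing import Optional


class Role(str, Enum):
    PROXY_ADMIN = "proxy_admin"
    INTERNAL_USER = "internal_user"


def resolve_role(explicit: Optional[Role], derived: Optional[Role]) -> Role:
    # Mirrors: user_role or _get_user_role(user_obj=user_obj) or INTERNAL_USER
    return explicit or derived or Role.INTERNAL_USER


assert resolve_role(Role.PROXY_ADMIN, None) is Role.PROXY_ADMIN  # master-key path
assert resolve_role(None, None) is Role.INTERNAL_USER  # previous default behavior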
@@ -3558,6 +3558,15 @@ class Router:
             # Catch all - if any exceptions default to cooling down
             return True

+    def _has_default_fallbacks(self) -> bool:
+        if self.fallbacks is None:
+            return False
+        for fallback in self.fallbacks:
+            if isinstance(fallback, dict):
+                if "*" in fallback:
+                    return True
+        return False
+
     def _should_raise_content_policy_error(
         self, model: str, response: ModelResponse, kwargs: dict
     ) -> bool:
@@ -3574,6 +3583,7 @@ class Router:
         content_policy_fallbacks = kwargs.get(
             "content_policy_fallbacks", self.content_policy_fallbacks
        )

         ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
         if content_policy_fallbacks is not None:
             fallback_model_group = None
@@ -3584,6 +3594,8 @@ class Router:

             if fallback_model_group is not None:
                 return True
+            elif self._has_default_fallbacks():  # default fallbacks set
+                return True

         verbose_router_logger.info(
             "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(
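With `_has_default_fallbacks` in place, a content-policy error counts as recoverable whenever default fallbacks exist, even without a model-specific `content_policy_fallbacks` mapping. A sketch using mock models like the tests below (model names are placeholders):

# Sketch: default fallbacks make content-policy errors recoverable.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-default-fallback-model",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "",
                "mock_response": "This works 2!",
            },
        },
    ],
    default_fallbacks=["my-default-fallback-model"],
)

# default_fallbacks are registered as a {"*": [...]} entry in router.fallbacks,
# which is what _has_default_fallbacks checks for.
assert router._has_default_fallbacks() is True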
@@ -8252,3 +8252,22 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
         )

     return messages
+
+
+from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class ProviderConfigManager:
+    @staticmethod
+    def get_provider_config(
+        model: str, provider: litellm.LlmProviders
+    ) -> OpenAIGPTConfig:
+        """
+        Returns the provider config for a given provider.
+        """
+        if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config()
+        elif litellm.LlmProviders.DEEPSEEK == provider:
+            return litellm.DeepSeekChatConfig()
+
+        return OpenAIGPTConfig()
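A usage sketch for the new `ProviderConfigManager`, mirroring the call site in the OpenAI handler earlier in this diff (model strings are illustrative; `LlmProviders.OPENAI` is assumed to be a member of the provider enum):

# Sketch: how the dispatcher picks a config.
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager

# DeepSeek gets the content-list-flattening config...
deepseek_config = ProviderConfigManager.get_provider_config(
    model="deepseek-chat", provider=LlmProviders.DEEPSEEK
)

# ...while non-o1 OpenAI models fall back to the plain pass-through config.
default_config = ProviderConfigManager.get_provider_config(
    model="gpt-4o", provider=LlmProviders.OPENAI
)

messages = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
messages = deepseek_config._transform_messages(messages)  # content flattened to "hi"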
tests/llm_translation/base_llm_unit_tests.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, Mock, patch
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.exceptions import BadRequestError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import CustomStreamWrapper
+
+
+# test_example.py
+from abc import ABC, abstractmethod
+
+
+class BaseLLMChatTest(ABC):
+    """
+    Abstract base test class that enforces a common test across all test classes.
+    """
+
+    @abstractmethod
+    def get_base_completion_call_args(self) -> dict:
+        """Must return the base completion call args"""
+        pass
+
+    def test_content_list_handling(self):
+        """Check if content list is supported by LLM API"""
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Hello, how are you?"}],
+            }
+        ]
+        response = litellm.completion(
+            **base_completion_call_args,
+            messages=messages,
+        )
+        assert response is not None
tests/llm_translation/test_deepseek_completion.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+# Test implementation
+class TestDeepSeekChatCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {
+            "model": "deepseek/deepseek-chat",
+        }
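Adding coverage for another provider only requires supplying the base call args. A hypothetical example following the same pattern (not part of this commit; model string is illustrative):

# Hypothetical additional provider test, not part of this commit.
from base_llm_unit_tests import BaseLLMChatTest


class TestOpenAIMiniChatCompletion(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {"model": "gpt-4o-mini"}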
@@ -4526,6 +4526,7 @@ async def test_completion_ai21_chat():
     "stream",
     [False, True],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 def test_completion_response_ratelimit_headers(model, stream):
     response = completion(
         model=model,
tests/local_testing/test_opentelemetry_unit_tests.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+# What is this?
+## Unit tests for opentelemetry integration
+
+# What is this?
+## Unit test for presidio pii masking
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+import asyncio
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from unittest.mock import patch, MagicMock, AsyncMock
+
+
+@pytest.mark.asyncio
+async def test_opentelemetry_integration():
+    """
+    Unit test to confirm the parent otel span is ended
+    """
+
+    parent_otel_span = MagicMock()
+    litellm.callbacks = ["otel"]
+
+    await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Hey!",
+        metadata={"litellm_parent_otel_span": parent_otel_span},
+    )
+
+    await asyncio.sleep(1)
+
+    parent_otel_span.end.assert_called_once()
@@ -72,6 +72,19 @@ def test_litellm_proxy_server_config_no_general_settings():
         # Check if the response is successful
         assert response.status_code == 200
         assert response.json() == "I'm alive!"
+
+        # Test /chat/completions
+        response = requests.post(
+            "http://localhost:4000/chat/completions",
+            headers={"Authorization": "Bearer 1234567890"},
+            json={
+                "model": "test_openai_models",
+                "messages": [{"role": "user", "content": "Hello, how are you?"}],
+            },
+        )
+
+        assert response.status_code == 200
+
     except ImportError:
         pytest.fail("Failed to import litellm.proxy_server")
     except requests.ConnectionError:
@@ -1120,9 +1120,10 @@ async def test_client_side_fallbacks_list(sync_mode):

 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize("content_filter_response_exception", [True, False])
+@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
 @pytest.mark.asyncio
 async def test_router_content_policy_fallbacks(
-    sync_mode, content_filter_response_exception
+    sync_mode, content_filter_response_exception, fallback_type
 ):
     os.environ["LITELLM_LOG"] = "DEBUG"
@@ -1152,6 +1153,14 @@ async def test_router_content_policy_fallbacks(
                     "mock_response": "This works!",
                 },
             },
+            {
+                "model_name": "my-default-fallback-model",
+                "litellm_params": {
+                    "model": "openai/my-fake-model",
+                    "api_key": "",
+                    "mock_response": "This works 2!",
+                },
+            },
             {
                 "model_name": "my-general-model",
                 "litellm_params": {
@@ -1169,9 +1178,14 @@ async def test_router_content_policy_fallbacks(
                 },
            },
         ],
-        content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
-        fallbacks=[{"claude-2": ["my-general-model"]}],
-        context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
+        content_policy_fallbacks=(
+            [{"claude-2": ["my-fallback-model"]}]
+            if fallback_type == "model-specific"
+            else None
+        ),
+        default_fallbacks=(
+            ["my-default-fallback-model"] if fallback_type == "default" else None
+        ),
     )

     if sync_mode is True:
@@ -452,11 +452,17 @@ def test_update_usage(model_list):


 @pytest.mark.parametrize(
-    "finish_reason, expected_error", [("content_filter", True), ("stop", False)]
+    "finish_reason, expected_fallback", [("content_filter", True), ("stop", False)]
 )
-def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
+@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
+def test_should_raise_content_policy_error(
+    model_list, finish_reason, expected_fallback, fallback_type
+):
     """Test if the '_should_raise_content_policy_error' function is working correctly"""
-    router = Router(model_list=model_list)
+    router = Router(
+        model_list=model_list,
+        default_fallbacks=["gpt-4o"] if fallback_type == "default" else None,
+    )

     assert (
         router._should_raise_content_policy_error(
@@ -472,10 +478,14 @@ def test_should_raise_content_policy_error(
                 usage={"total_tokens": 100},
             ),
             kwargs={
-                "content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
+                "content_policy_fallbacks": (
+                    [{"gpt-3.5-turbo": "gpt-4o"}]
+                    if fallback_type == "model-specific"
+                    else None
+                )
             },
         )
-        is expected_error
+        is expected_fallback
     )
@@ -1019,3 +1029,17 @@ async def test_pass_through_moderation_endpoint_factory(model_list):
     response = await router._pass_through_moderation_endpoint_factory(
         original_function=litellm.amoderation, input="this is valid good text"
     )
+
+
+@pytest.mark.parametrize(
+    "has_default_fallbacks, expected_result",
+    [(True, True), (False, False)],
+)
+def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result):
+    router = Router(
+        model_list=model_list,
+        default_fallbacks=(
+            ["my-default-fallback-model"] if has_default_fallbacks else None
+        ),
+    )
+    assert router._has_default_fallbacks() is expected_result
@@ -362,7 +362,7 @@ async def test_team_info():

     try:
         await get_team_info(session=session, get_team=team_id, call_key=key)
-        pytest.fail(f"Expected call to fail")
+        pytest.fail("Expected call to fail")
     except Exception as e:
         pass