Litellm dev 11 08 2024 (#6658)

* fix(deepseek/chat): convert content list to str

Fixes https://github.com/BerriAI/litellm/issues/6642

* test(test_deepseek_completion.py): implement base llm unit tests

increase robustness across providers

* fix(router.py): support content policy violation fallbacks with default fallbacks

* fix(opentelemetry.py): refactor to move otel imports behing flag

Fixes https://github.com/BerriAI/litellm/issues/6636

* fix(opentelemtry.py): close span on success completion

* fix(user_api_key_auth.py): allow user_role to default to none

* fix: mark flaky test

* fix(opentelemetry.py): move otelconfig.from_env to inside the init

prevent otel errors raised just by importing the litellm class

* fix(user_api_key_auth.py): fix auth error
This commit is contained in:
Krish Dholakia 2024-11-08 22:07:17 +05:30 committed by GitHub
parent 1bef6457c7
commit 73531f4815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 287 additions and 34 deletions

View file

@ -1045,6 +1045,7 @@ from .llms.AzureOpenAI.azure import (
from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config

View file

@ -16,6 +16,7 @@ from litellm.types.utils import (
)
if TYPE_CHECKING:
from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
from opentelemetry.trace import Span as _Span
from litellm.proxy._types import (
@ -24,10 +25,12 @@ if TYPE_CHECKING:
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
SpanExporter = _SpanExporter
UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
else:
Span = Any
SpanExporter = Any
UserAPIKeyAuth = Any
ManagementEndpointLoggingPayload = Any
@ -44,7 +47,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
@dataclass
class OpenTelemetryConfig:
from opentelemetry.sdk.trace.export import SpanExporter
exporter: Union[str, SpanExporter] = "console"
endpoint: Optional[str] = None
@ -77,7 +79,7 @@ class OpenTelemetryConfig:
class OpenTelemetry(CustomLogger):
def __init__(
self,
config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
config: Optional[OpenTelemetryConfig] = None,
callback_name: Optional[str] = None,
**kwargs,
):
@ -85,6 +87,9 @@ class OpenTelemetry(CustomLogger):
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
if config is None:
config = OpenTelemetryConfig.from_env()
self.config = config
self.OTEL_EXPORTER = self.config.exporter
self.OTEL_ENDPOINT = self.config.endpoint
@ -319,8 +324,8 @@ class OpenTelemetry(CustomLogger):
span.end(end_time=self._to_ns(end_time))
# if parent_otel_span is not None:
# parent_otel_span.end(end_time=self._to_ns(datetime.now()))
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -700,10 +705,10 @@ class OpenTelemetry(CustomLogger):
TraceContextTextMapPropagator,
)
verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
propagator = TraceContextTextMapPropagator()
_parent_context = propagator.extract(carrier={"traceparent": _traceparent})
verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
carrier = {"traceparent": _traceparent}
_parent_context = propagator.extract(carrier=carrier)
return _parent_context
def _get_span_context(self, kwargs):

View file

@ -3,7 +3,7 @@ Support for gpt model family
"""
import types
from typing import Optional, Union
from typing import List, Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
@ -163,3 +163,8 @@ class OpenAIGPTConfig:
model=model,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
return messages

View file

@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
return True
return False
def o1_prompt_factory(self, messages: List[AllMessageValues]):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
Handles limitations of O-1 model family.
- modalities: image => drop param (if user opts in to dropping param)

View file

@ -15,6 +15,7 @@ from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
@ -24,6 +25,7 @@ from litellm.utils import (
CustomStreamWrapper,
Message,
ModelResponse,
ProviderConfigManager,
TextCompletionResponse,
Usage,
convert_to_model_response_object,
@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if (
litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
and messages is not None
):
messages = litellm.openAIO1Config.o1_prompt_factory(
messages=messages,
if messages is not None and custom_llm_provider is not None:
provider_config = ProviderConfigManager.get_provider_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
messages = provider_config._transform_messages(messages)
for _ in range(
2

View file

@ -0,0 +1,41 @@
"""
Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
"""
import types
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
)
class DeepSeekChatConfig(OpenAIGPTConfig):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
DeepSeek does not support content in list format.
"""
messages = handle_messages_with_content_list_to_str_conversion(messages)
return super()._transform_messages(messages)
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
api_base = (
api_base
or get_secret_str("DEEPSEEK_API_BASE")
or "https://api.deepseek.com/beta"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
return api_base, dynamic_api_key

View file

@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
)
def handle_messages_with_content_list_to_str_conversion(
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
"""
Handles messages with content list conversion
"""
for message in messages:
texts = convert_content_list_to_str(message=message)
if texts:
message["content"] = texts
return messages
def convert_content_list_to_str(message: AllMessageValues) -> str:
"""
- handles scenario where content is list and not string

View file

@ -59,10 +59,10 @@ model_list:
timeout: 300
stream_timeout: 60
# litellm_settings:
# fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
# callbacks: ["otel", "prometheus"]
# default_redis_batch_cache_expiry: 10
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
callbacks: ["otel", "prometheus"]
default_redis_batch_cache_expiry: 10
# litellm_settings:
# cache: True

View file

@ -703,12 +703,17 @@ async def user_api_key_auth( # noqa: PLR0915
)
if is_master_key_valid:
_user_api_key_obj = UserAPIKeyAuth(
api_key=master_key,
_user_api_key_obj = _return_user_api_key_auth_obj(
user_obj=None,
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id=litellm_proxy_admin_name,
api_key=master_key,
parent_otel_span=parent_otel_span,
valid_token_dict={
**end_user_params,
"user_id": litellm_proxy_admin_name,
},
route=route,
start_time=start_time,
)
await _cache_key_object(
hashed_token=hash_token(master_key),
@ -1229,7 +1234,9 @@ def _return_user_api_key_auth_obj(
valid_token_dict: dict,
route: str,
start_time: datetime,
user_role: Optional[LitellmUserRoles] = None,
) -> UserAPIKeyAuth:
traceback.print_stack()
end_time = datetime.now()
user_api_key_service_logger_obj.service_success_hook(
service=ServiceTypes.AUTH,
@ -1240,7 +1247,7 @@ def _return_user_api_key_auth_obj(
parent_otel_span=parent_otel_span,
)
retrieved_user_role = (
_get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
)
user_api_key_kwargs = {

View file

@ -3558,6 +3558,15 @@ class Router:
# Catch all - if any exceptions default to cooling down
return True
def _has_default_fallbacks(self) -> bool:
if self.fallbacks is None:
return False
for fallback in self.fallbacks:
if isinstance(fallback, dict):
if "*" in fallback:
return True
return False
def _should_raise_content_policy_error(
self, model: str, response: ModelResponse, kwargs: dict
) -> bool:
@ -3574,6 +3583,7 @@ class Router:
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
if content_policy_fallbacks is not None:
fallback_model_group = None
@ -3584,6 +3594,8 @@ class Router:
if fallback_model_group is not None:
return True
elif self._has_default_fallbacks(): # default fallbacks set
return True
verbose_router_logger.info(
"Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(

View file

@ -8252,3 +8252,22 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
)
return messages
from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
class ProviderConfigManager:
@staticmethod
def get_provider_config(
model: str, provider: litellm.LlmProviders
) -> OpenAIGPTConfig:
"""
Returns the provider config for a given provider.
"""
if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
return litellm.OpenAIO1Config()
elif litellm.LlmProviders.DEEPSEEK == provider:
return litellm.DeepSeekChatConfig()
return OpenAIGPTConfig()

View file

@ -0,0 +1,46 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
# test_example.py
from abc import ABC, abstractmethod
class BaseLLMChatTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_completion_call_args(self) -> dict:
"""Must return the base completion call args"""
pass
def test_content_list_handling(self):
"""Check if content list is supported by LLM API"""
base_completion_call_args = self.get_base_completion_call_args()
messages = [
{
"role": "user",
"content": [{"type": "text", "text": "Hello, how are you?"}],
}
]
response = litellm.completion(
**base_completion_call_args,
messages=messages,
)
assert response is not None

View file

@ -0,0 +1,9 @@
from base_llm_unit_tests import BaseLLMChatTest
# Test implementation
class TestDeepSeekChatCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {
"model": "deepseek/deepseek-chat",
}

View file

@ -4526,6 +4526,7 @@ async def test_completion_ai21_chat():
"stream",
[False, True],
)
@pytest.mark.flaky(retries=3, delay=1)
def test_completion_response_ratelimit_headers(model, stream):
response = completion(
model=model,

View file

@ -0,0 +1,41 @@
# What is this?
## Unit tests for opentelemetry integration
# What is this?
## Unit test for presidio pii masking
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
import asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock
@pytest.mark.asyncio
async def test_opentelemetry_integration():
"""
Unit test to confirm the parent otel span is ended
"""
parent_otel_span = MagicMock()
litellm.callbacks = ["otel"]
await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Hey!",
metadata={"litellm_parent_otel_span": parent_otel_span},
)
await asyncio.sleep(1)
parent_otel_span.end.assert_called_once()

View file

@ -72,6 +72,19 @@ def test_litellm_proxy_server_config_no_general_settings():
# Check if the response is successful
assert response.status_code == 200
assert response.json() == "I'm alive!"
# Test /chat/completions
response = requests.post(
"http://localhost:4000/chat/completions",
headers={"Authorization": "Bearer 1234567890"},
json={
"model": "test_openai_models",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
},
)
assert response.status_code == 200
except ImportError:
pytest.fail("Failed to import litellm.proxy_server")
except requests.ConnectionError:

View file

@ -1120,9 +1120,10 @@ async def test_client_side_fallbacks_list(sync_mode):
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("content_filter_response_exception", [True, False])
@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
@pytest.mark.asyncio
async def test_router_content_policy_fallbacks(
sync_mode, content_filter_response_exception
sync_mode, content_filter_response_exception, fallback_type
):
os.environ["LITELLM_LOG"] = "DEBUG"
@ -1152,6 +1153,14 @@ async def test_router_content_policy_fallbacks(
"mock_response": "This works!",
},
},
{
"model_name": "my-default-fallback-model",
"litellm_params": {
"model": "openai/my-fake-model",
"api_key": "",
"mock_response": "This works 2!",
},
},
{
"model_name": "my-general-model",
"litellm_params": {
@ -1169,9 +1178,14 @@ async def test_router_content_policy_fallbacks(
},
},
],
content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
fallbacks=[{"claude-2": ["my-general-model"]}],
context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
content_policy_fallbacks=(
[{"claude-2": ["my-fallback-model"]}]
if fallback_type == "model-specific"
else None
),
default_fallbacks=(
["my-default-fallback-model"] if fallback_type == "default" else None
),
)
if sync_mode is True:

View file

@ -452,11 +452,17 @@ def test_update_usage(model_list):
@pytest.mark.parametrize(
"finish_reason, expected_error", [("content_filter", True), ("stop", False)]
"finish_reason, expected_fallback", [("content_filter", True), ("stop", False)]
)
def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
def test_should_raise_content_policy_error(
model_list, finish_reason, expected_fallback, fallback_type
):
"""Test if the '_should_raise_content_policy_error' function is working correctly"""
router = Router(model_list=model_list)
router = Router(
model_list=model_list,
default_fallbacks=["gpt-4o"] if fallback_type == "default" else None,
)
assert (
router._should_raise_content_policy_error(
@ -472,10 +478,14 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e
usage={"total_tokens": 100},
),
kwargs={
"content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
"content_policy_fallbacks": (
[{"gpt-3.5-turbo": "gpt-4o"}]
if fallback_type == "model-specific"
else None
)
},
)
is expected_error
is expected_fallback
)
@ -1019,3 +1029,17 @@ async def test_pass_through_moderation_endpoint_factory(model_list):
response = await router._pass_through_moderation_endpoint_factory(
original_function=litellm.amoderation, input="this is valid good text"
)
@pytest.mark.parametrize(
"has_default_fallbacks, expected_result",
[(True, True), (False, False)],
)
def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result):
router = Router(
model_list=model_list,
default_fallbacks=(
["my-default-fallback-model"] if has_default_fallbacks else None
),
)
assert router._has_default_fallbacks() is expected_result

View file

@ -362,7 +362,7 @@ async def test_team_info():
try:
await get_team_info(session=session, get_team=team_id, call_key=key)
pytest.fail(f"Expected call to fail")
pytest.fail("Expected call to fail")
except Exception as e:
pass