Litellm dev 11 08 2024 (#6658)

* fix(deepseek/chat): convert content list to str

Fixes https://github.com/BerriAI/litellm/issues/6642

* test(test_deepseek_completion.py): implement base llm unit tests

increase robustness across providers

* fix(router.py): support content policy violation fallbacks with default fallbacks

* fix(opentelemetry.py): refactor to move otel imports behind the TYPE_CHECKING flag

Fixes https://github.com/BerriAI/litellm/issues/6636

* fix(opentelemetry.py): close parent span on successful completion

* fix(user_api_key_auth.py): allow user_role to default to None

* fix: mark flaky test

* fix(opentelemetry.py): move OpenTelemetryConfig.from_env() inside __init__

prevents otel errors from being raised just by importing the litellm OpenTelemetry class

* fix(user_api_key_auth.py): fix auth error
Authored by Krish Dholakia on 2024-11-08 22:07:17 +05:30; committed by GitHub
Parent: 1bef6457c7
Commit: 73531f4815
19 changed files with 287 additions and 34 deletions


@@ -1045,6 +1045,7 @@ from .llms.AzureOpenAI.azure import (
 from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
 from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config


@@ -16,6 +16,7 @@ from litellm.types.utils import (
 )
 
 if TYPE_CHECKING:
+    from opentelemetry.sdk.trace.export import SpanExporter as _SpanExporter
     from opentelemetry.trace import Span as _Span
 
     from litellm.proxy._types import (
@@ -24,10 +25,12 @@ if TYPE_CHECKING:
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
 
     Span = _Span
+    SpanExporter = _SpanExporter
     UserAPIKeyAuth = _UserAPIKeyAuth
     ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
 else:
     Span = Any
+    SpanExporter = Any
     UserAPIKeyAuth = Any
     ManagementEndpointLoggingPayload = Any
@@ -44,7 +47,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
 @dataclass
 class OpenTelemetryConfig:
-    from opentelemetry.sdk.trace.export import SpanExporter
 
     exporter: Union[str, SpanExporter] = "console"
     endpoint: Optional[str] = None
@@ -77,7 +79,7 @@ class OpenTelemetryConfig:
 class OpenTelemetry(CustomLogger):
     def __init__(
         self,
-        config: OpenTelemetryConfig = OpenTelemetryConfig.from_env(),
+        config: Optional[OpenTelemetryConfig] = None,
         callback_name: Optional[str] = None,
         **kwargs,
     ):
@@ -85,6 +87,9 @@ class OpenTelemetry(CustomLogger):
         from opentelemetry.sdk.resources import Resource
         from opentelemetry.sdk.trace import TracerProvider
 
+        if config is None:
+            config = OpenTelemetryConfig.from_env()
+
         self.config = config
         self.OTEL_EXPORTER = self.config.exporter
         self.OTEL_ENDPOINT = self.config.endpoint
@@ -319,8 +324,8 @@ class OpenTelemetry(CustomLogger):
         span.end(end_time=self._to_ns(end_time))
 
-        # if parent_otel_span is not None:
-        #     parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+        if parent_otel_span is not None:
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
 
     def _handle_failure(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode
@@ -700,10 +705,10 @@ class OpenTelemetry(CustomLogger):
            TraceContextTextMapPropagator,
        )
 
-        verbose_logger.debug("OpenTelemetry: GOT A TRACEPARENT {}".format(_traceparent))
        propagator = TraceContextTextMapPropagator()
-        _parent_context = propagator.extract(carrier={"traceparent": _traceparent})
-        verbose_logger.debug("OpenTelemetry: PARENT CONTEXT {}".format(_parent_context))
+        carrier = {"traceparent": _traceparent}
+        _parent_context = propagator.extract(carrier=carrier)
 
        return _parent_context
 
    def _get_span_context(self, kwargs):
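
The net effect of the constructor change above: OpenTelemetryConfig.from_env() is no longer evaluated as a default argument, so nothing OTEL-related runs until an instance is actually created. A minimal sketch of the intended usage, assuming the class is importable from litellm.integrations.opentelemetry (the import path is an assumption, not part of this diff):

    from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig

    # Explicit config: from_env() is never called, so no OTEL env vars are needed.
    otel_logger = OpenTelemetry(config=OpenTelemetryConfig(exporter="console"))

    # No config: from_env() now runs inside __init__, not at class-definition time.
    default_logger = OpenTelemetry()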


@@ -3,7 +3,7 @@ Support for gpt model family
 """
 
 import types
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import litellm
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
@@ -163,3 +163,8 @@ class OpenAIGPTConfig:
            model=model,
            drop_params=drop_params,
        )
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        return messages


@@ -108,7 +108,9 @@ class OpenAIO1Config(OpenAIGPTConfig):
            return True
        return False
 
-    def o1_prompt_factory(self, messages: List[AllMessageValues]):
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
        """
        Handles limitations of O-1 model family.
        - modalities: image => drop param (if user opts in to dropping param)


@@ -15,6 +15,7 @@ from pydantic import BaseModel
 from typing_extensions import overload, override
 
 import litellm
+from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.secret_managers.main import get_secret_str
@@ -24,6 +25,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     Message,
     ModelResponse,
+    ProviderConfigManager,
     TextCompletionResponse,
     Usage,
     convert_to_model_response_object,
@@ -701,13 +703,11 @@ class OpenAIChatCompletion(BaseLLM):
                messages=messages,
                custom_llm_provider=custom_llm_provider,
            )
-            if (
-                litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
-                and messages is not None
-            ):
-                messages = litellm.openAIO1Config.o1_prompt_factory(
-                    messages=messages,
-                )
+            if messages is not None and custom_llm_provider is not None:
+                provider_config = ProviderConfigManager.get_provider_config(
+                    model=model, provider=LlmProviders(custom_llm_provider)
+                )
+                messages = provider_config._transform_messages(messages)
 
            for _ in range(
                2


@@ -0,0 +1,41 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` to DeepSeek's `/v1/chat/completions`
+"""
+
+import types
+from typing import List, Optional, Tuple, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+from ...prompt_templates.common_utils import (
+    handle_messages_with_content_list_to_str_conversion,
+)
+
+
+class DeepSeekChatConfig(OpenAIGPTConfig):
+    def _transform_messages(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """
+        DeepSeek does not support content in list format.
+        """
+        messages = handle_messages_with_content_list_to_str_conversion(messages)
+        return super()._transform_messages(messages)
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or get_secret_str("DEEPSEEK_API_BASE")
+            or "https://api.deepseek.com/beta"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
+        return api_base, dynamic_api_key
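
A hedged end-to-end sketch of what this new config is meant to fix (issue #6642): list-style message content is flattened to a plain string before the request reaches DeepSeek. Assumes DEEPSEEK_API_KEY is set in the environment; the prompt is illustrative:

    import litellm

    response = litellm.completion(
        model="deepseek/deepseek-chat",
        messages=[
            {
                "role": "user",
                "content": [{"type": "text", "text": "Hello, how are you?"}],
            }
        ],
    )
    print(response.choices[0].message.content)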


@@ -24,6 +24,19 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
 )
 
 
+def handle_messages_with_content_list_to_str_conversion(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Handles messages with content list conversion
+    """
+    for message in messages:
+        texts = convert_content_list_to_str(message=message)
+        if texts:
+            message["content"] = texts
+    return messages
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string
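
A small illustration (not part of the diff) of what the helper above is expected to do; the module path is inferred from the relative import in the DeepSeek config:

    from litellm.llms.prompt_templates.common_utils import (
        handle_messages_with_content_list_to_str_conversion,
    )

    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Hello, how are you?"}],
        }
    ]
    flattened = handle_messages_with_content_list_to_str_conversion(messages)
    # expected result: [{"role": "user", "content": "Hello, how are you?"}]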


@@ -59,10 +59,10 @@ model_list:
       timeout: 300
       stream_timeout: 60
 
-# litellm_settings:
-#   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
-#   callbacks: ["otel", "prometheus"]
-#   default_redis_batch_cache_expiry: 10
+litellm_settings:
+  fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
+  callbacks: ["otel", "prometheus"]
+  default_redis_batch_cache_expiry: 10
 
 # litellm_settings:
 #   cache: True


@@ -703,12 +703,17 @@ async def user_api_key_auth(  # noqa: PLR0915
            )
 
        if is_master_key_valid:
-            _user_api_key_obj = UserAPIKeyAuth(
-                api_key=master_key,
-                user_role=LitellmUserRoles.PROXY_ADMIN,
-                user_id=litellm_proxy_admin_name,
-                parent_otel_span=parent_otel_span,
-                **end_user_params,
+            _user_api_key_obj = _return_user_api_key_auth_obj(
+                user_obj=None,
+                user_role=LitellmUserRoles.PROXY_ADMIN,
+                api_key=master_key,
+                parent_otel_span=parent_otel_span,
+                valid_token_dict={
+                    **end_user_params,
+                    "user_id": litellm_proxy_admin_name,
+                },
+                route=route,
+                start_time=start_time,
            )
            await _cache_key_object(
                hashed_token=hash_token(master_key),
@@ -1229,7 +1234,9 @@ def _return_user_api_key_auth_obj(
     valid_token_dict: dict,
     route: str,
     start_time: datetime,
+    user_role: Optional[LitellmUserRoles] = None,
 ) -> UserAPIKeyAuth:
+    traceback.print_stack()
     end_time = datetime.now()
     user_api_key_service_logger_obj.service_success_hook(
         service=ServiceTypes.AUTH,
@@ -1240,7 +1247,7 @@ def _return_user_api_key_auth_obj(
         parent_otel_span=parent_otel_span,
     )
     retrieved_user_role = (
-        _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
+        user_role or _get_user_role(user_obj=user_obj) or LitellmUserRoles.INTERNAL_USER
     )
 
     user_api_key_kwargs = {


@@ -3558,6 +3558,15 @@ class Router:
            # Catch all - if any exceptions default to cooling down
            return True
 
+    def _has_default_fallbacks(self) -> bool:
+        if self.fallbacks is None:
+            return False
+        for fallback in self.fallbacks:
+            if isinstance(fallback, dict):
+                if "*" in fallback:
+                    return True
+        return False
+
     def _should_raise_content_policy_error(
         self, model: str, response: ModelResponse, kwargs: dict
     ) -> bool:
@@ -3574,6 +3583,7 @@ class Router:
        content_policy_fallbacks = kwargs.get(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )
+
        ### ONLY RAISE ERROR IF CP FALLBACK AVAILABLE ###
        if content_policy_fallbacks is not None:
            fallback_model_group = None
@@ -3584,6 +3594,8 @@ class Router:
 
            if fallback_model_group is not None:
                return True
+        elif self._has_default_fallbacks():  # default fallbacks set
+            return True
 
        verbose_router_logger.info(
            "Content Policy Error occurred. No available fallbacks. Returning original response. model={}, content_policy_fallbacks={}".format(


@@ -8252,3 +8252,22 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
            )
 
    return messages
+
+
+from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class ProviderConfigManager:
+    @staticmethod
+    def get_provider_config(
+        model: str, provider: litellm.LlmProviders
+    ) -> OpenAIGPTConfig:
+        """
+        Returns the provider config for a given provider.
+        """
+        if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config()
+        elif litellm.LlmProviders.DEEPSEEK == provider:
+            return litellm.DeepSeekChatConfig()
+
+        return OpenAIGPTConfig()
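
Rough usage sketch of the new ProviderConfigManager, mirroring how the openai.py hunk above resolves a provider-specific message transformation (exact call sites in the handler may differ):

    import litellm
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_config(
        model="deepseek-chat", provider=litellm.LlmProviders.DEEPSEEK
    )
    # DeepSeekChatConfig._transform_messages flattens list content into a plain string
    messages = provider_config._transform_messages(
        [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
    )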


@@ -0,0 +1,46 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, Mock, patch
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.exceptions import BadRequestError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import CustomStreamWrapper
+
+# test_example.py
+from abc import ABC, abstractmethod
+
+
+class BaseLLMChatTest(ABC):
+    """
+    Abstract base test class that enforces a common test across all test classes.
+    """
+
+    @abstractmethod
+    def get_base_completion_call_args(self) -> dict:
+        """Must return the base completion call args"""
+        pass
+
+    def test_content_list_handling(self):
+        """Check if content list is supported by LLM API"""
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Hello, how are you?"}],
+            }
+        ]
+        response = litellm.completion(
+            **base_completion_call_args,
+            messages=messages,
+        )
+        assert response is not None


@@ -0,0 +1,9 @@
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+# Test implementation
+class TestDeepSeekChatCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {
+            "model": "deepseek/deepseek-chat",
+        }
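
The shared suite is meant to be reused by other provider test files; a hypothetical example (the class and model name below are illustrative, not part of this commit):

    from base_llm_unit_tests import BaseLLMChatTest


    class TestAnotherProviderChatCompletion(BaseLLMChatTest):
        def get_base_completion_call_args(self) -> dict:
            return {"model": "openai/gpt-4o-mini"}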


@@ -4526,6 +4526,7 @@ async def test_completion_ai21_chat():
     "stream",
     [False, True],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 def test_completion_response_ratelimit_headers(model, stream):
     response = completion(
         model=model,


@@ -0,0 +1,41 @@
+# What is this?
+## Unit tests for opentelemetry integration
+
+# What is this?
+## Unit test for presidio pii masking
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+import asyncio
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from unittest.mock import patch, MagicMock, AsyncMock
+
+
+@pytest.mark.asyncio
+async def test_opentelemetry_integration():
+    """
+    Unit test to confirm the parent otel span is ended
+    """
+    parent_otel_span = MagicMock()
+    litellm.callbacks = ["otel"]
+
+    await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response="Hey!",
+        metadata={"litellm_parent_otel_span": parent_otel_span},
+    )
+
+    await asyncio.sleep(1)
+
+    parent_otel_span.end.assert_called_once()


@@ -72,6 +72,19 @@ def test_litellm_proxy_server_config_no_general_settings():
        # Check if the response is successful
        assert response.status_code == 200
        assert response.json() == "I'm alive!"
+
+        # Test /chat/completions
+        response = requests.post(
+            "http://localhost:4000/chat/completions",
+            headers={"Authorization": "Bearer 1234567890"},
+            json={
+                "model": "test_openai_models",
+                "messages": [{"role": "user", "content": "Hello, how are you?"}],
+            },
+        )
+
+        assert response.status_code == 200
+
    except ImportError:
        pytest.fail("Failed to import litellm.proxy_server")
    except requests.ConnectionError:


@@ -1120,9 +1120,10 @@ async def test_client_side_fallbacks_list(sync_mode):
 
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize("content_filter_response_exception", [True, False])
+@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
 @pytest.mark.asyncio
 async def test_router_content_policy_fallbacks(
-    sync_mode, content_filter_response_exception
+    sync_mode, content_filter_response_exception, fallback_type
 ):
     os.environ["LITELLM_LOG"] = "DEBUG"
@@ -1152,6 +1153,14 @@ async def test_router_content_policy_fallbacks(
                    "mock_response": "This works!",
                },
            },
+            {
+                "model_name": "my-default-fallback-model",
+                "litellm_params": {
+                    "model": "openai/my-fake-model",
+                    "api_key": "",
+                    "mock_response": "This works 2!",
+                },
+            },
            {
                "model_name": "my-general-model",
                "litellm_params": {
@@ -1169,9 +1178,14 @@ async def test_router_content_policy_fallbacks(
                },
            },
        ],
-        content_policy_fallbacks=[{"claude-2": ["my-fallback-model"]}],
-        fallbacks=[{"claude-2": ["my-general-model"]}],
-        context_window_fallbacks=[{"claude-2": ["my-context-window-model"]}],
+        content_policy_fallbacks=(
+            [{"claude-2": ["my-fallback-model"]}]
+            if fallback_type == "model-specific"
+            else None
+        ),
+        default_fallbacks=(
+            ["my-default-fallback-model"] if fallback_type == "default" else None
+        ),
    )
 
    if sync_mode is True:


@@ -452,11 +452,17 @@ def test_update_usage(model_list):
 @pytest.mark.parametrize(
-    "finish_reason, expected_error", [("content_filter", True), ("stop", False)]
+    "finish_reason, expected_fallback", [("content_filter", True), ("stop", False)]
 )
-def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
+@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
+def test_should_raise_content_policy_error(
+    model_list, finish_reason, expected_fallback, fallback_type
+):
     """Test if the '_should_raise_content_policy_error' function is working correctly"""
-    router = Router(model_list=model_list)
+    router = Router(
+        model_list=model_list,
+        default_fallbacks=["gpt-4o"] if fallback_type == "default" else None,
+    )
 
     assert (
         router._should_raise_content_policy_error(
@@ -472,10 +478,14 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e
                usage={"total_tokens": 100},
            ),
            kwargs={
-                "content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
+                "content_policy_fallbacks": (
+                    [{"gpt-3.5-turbo": "gpt-4o"}]
+                    if fallback_type == "model-specific"
+                    else None
+                )
            },
        )
-        is expected_error
+        is expected_fallback
    )
@@ -1019,3 +1029,17 @@ async def test_pass_through_moderation_endpoint_factory(model_list):
    response = await router._pass_through_moderation_endpoint_factory(
        original_function=litellm.amoderation, input="this is valid good text"
    )
+
+
+@pytest.mark.parametrize(
+    "has_default_fallbacks, expected_result",
+    [(True, True), (False, False)],
+)
+def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result):
+    router = Router(
+        model_list=model_list,
+        default_fallbacks=(
+            ["my-default-fallback-model"] if has_default_fallbacks else None
+        ),
+    )
+    assert router._has_default_fallbacks() is expected_result


@@ -362,7 +362,7 @@ async def test_team_info():
        try:
            await get_team_info(session=session, get_team=team_id, call_key=key)
-            pytest.fail(f"Expected call to fail")
+            pytest.fail("Expected call to fail")
        except Exception as e:
            pass