fix(vertex_httpx.py): check if model supports system messages before sending separately

Krrish Dholakia 2024-06-17 17:30:38 -07:00
parent a80520004e
commit 3d9ef689e7
7 changed files with 190 additions and 73 deletions
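
In plain terms, the Vertex/Gemini request path now asks whether the target model supports system messages before it splits the system prompt out of the message list; if the capability lookup fails, the error is logged and the model is treated as not supporting them. A minimal caller-side sketch, assuming the gemini-1.5-pro model name used in the updated test and a valid GEMINI_API_KEY in the environment:

import litellm

# Hypothetical example: model and messages mirror the updated test_completion_gemini.
# If the model is marked as supporting system messages, the system prompt is sent
# separately; otherwise it is not split out.
response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ],
)
print(response.choices[0].message.content)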

View file

@@ -739,6 +739,7 @@ from .utils import (
supports_function_calling,
supports_parallel_function_calling,
supports_vision,
supports_system_messages,
get_litellm_params,
acreate,
get_model_list,

View file

@@ -12,7 +12,7 @@ if set_verbose is True:
)
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "ERROR")
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: str = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
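
The hunk above only touches the default value handed to os.getenv("LITELLM_LOG", ...); the level is still read from the LITELLM_LOG environment variable when the module is imported. A small sketch of overriding it, under that assumption, noting the variable must be set before litellm is imported:

import os

# Must run before `import litellm`, since the log level above is read at import time.
os.environ["LITELLM_LOG"] = "ERROR"

import litellm  # imported after setting the env var on purpose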

View file

@@ -18,6 +18,7 @@ import requests  # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
@@ -659,9 +660,21 @@ class VertexLLM(BaseLLM):
)
## TRANSFORMATION ##
try:
supports_system_message = litellm.supports_system_messages(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception as e:
verbose_logger.error(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
str(e)
)
)
supports_system_message = False
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
if supports_system_message is True:
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
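
The guard added above, probe the capability and fall back to "not supported" whenever the model cannot be identified, can be reproduced on its own. A minimal sketch, assuming litellm.supports_system_messages raises for models missing from model_prices_and_context_window.json, as in the supports_system_messages definition further down:

import litellm
from litellm import verbose_logger

def model_supports_system_messages(model: str, provider: str) -> bool:
    # Hypothetical helper mirroring the new guard: any lookup failure is
    # treated as "no system-message support" instead of failing the request.
    try:
        return litellm.supports_system_messages(
            model=model, custom_llm_provider=provider
        )
    except Exception as e:
        verbose_logger.error(
            "Unable to identify if system message supported. "
            "Defaulting to 'False'. Received error message - {}".format(str(e))
        )
        return False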

View file

@@ -7,66 +7,86 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import asyncio
import concurrent
import copy
from litellm._logging import verbose_router_logger
import datetime as datetime_og
import hashlib
import inspect
import json
import logging
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
Deployment,
ModelInfo,
LiteLLM_Params,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
import random
import threading
import time
import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import (
Any,
BinaryIO,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import httpx
import openai
from openai import AsyncOpenAI
from typing_extensions import overload
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
AsyncCustomHTTPTransport,
CustomHTTPTransport,
)
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
AssistantToolParam,
AsyncCursorPage,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
Thread,
)
from litellm.types.router import (
AlertingConfig,
AllowedFailsPolicy,
AssistantsTypedDict,
Deployment,
DeploymentTypedDict,
LiteLLM_Params,
ModelGroupInfo,
ModelInfo,
RetryPolicy,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
)
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
_is_region_eu,
calculate_max_parallel_requests,
get_utc_datetime,
)
from litellm.scheduler import Scheduler, FlowItem
from typing import Iterable
from litellm.router_utils.handle_error import send_llm_exception_alert
class Router:
@@ -3114,6 +3134,7 @@ class Router:
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@@ -3800,6 +3821,7 @@ class Router:
litellm_provider=llm_provider,
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
if model_group_info is None:

View file

@@ -3392,15 +3392,28 @@ def test_completion_deep_infra_mistral():
# Gemini tests
def test_completion_gemini():
@pytest.mark.parametrize(
"model",
[
# "gemini-1.0-pro",
"gemini-1.5-pro",
# "gemini-1.5-flash",
],
)
def test_completion_gemini(model):
litellm.set_verbose = True
model_name = "gemini/gemini-1.5-pro-latest"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
model_name = "gemini/{}".format(model)
messages = [
{"role": "system", "content": "Be a good bot!"},
{"role": "user", "content": "Hey, how's it going?"},
]
try:
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
assert response.choices[0]["index"] == 0
assert False
except litellm.APIError as e:
pass
except Exception as e:

View file

@@ -1,13 +1,22 @@
import json
import time
import uuid
import json
import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
]
]
supported_openai_params: Required[Optional[List[str]]]
supports_system_messages: Optional[bool]
class GenericStreamingChunk(TypedDict):
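
With supports_system_messages now part of ModelInfo, callers can read the flag straight from get_model_info(). The field is Optional: when the mapping does not specify it, the value comes back as None rather than False. A short sketch, under the assumption that the gemini-1.5-pro / vertex_ai pair exists in the model map:

import litellm

# supports_system_messages is Optional[bool]; None means the model map does not say either way.
info = litellm.get_model_info(model="gemini-1.5-pro", custom_llm_provider="vertex_ai")
print(info.get("supports_system_messages"))  # expected True, False, or None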

View file

@@ -1823,6 +1823,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
return False
def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports system messages and return a boolean value.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (Optional[str]): The provider to be checked against.
Returns:
bool: True if the model supports system messages, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
"""
try:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_system_messages", False) is True:
return True
return False
except Exception:
raise Exception(
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
)
def supports_function_calling(model: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
@@ -1838,7 +1864,7 @@ def supports_function_calling(model: str) -> bool:
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_function_calling", False):
if model_info.get("supports_function_calling", False) is True:
return True
return False
else:
@@ -1862,7 +1888,7 @@ def supports_vision(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_vision", False):
if model_info.get("supports_vision", False) is True:
return True
return False
else:
@@ -1884,7 +1910,7 @@ def supports_parallel_function_calling(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_parallel_function_calling", False):
if model_info.get("supports_parallel_function_calling", False) is True:
return True
return False
else:
@@ -4319,14 +4345,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
)
if custom_llm_provider == "huggingface":
max_tokens = _get_max_position_embeddings(model_name=model)
return {
"max_tokens": max_tokens, # type: ignore
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "huggingface",
"mode": "chat",
"supported_openai_params": supported_openai_params,
}
return ModelInfo(
max_tokens=max_tokens, # type: ignore
max_input_tokens=None,
max_output_tokens=None,
input_cost_per_token=0,
output_cost_per_token=0,
litellm_provider="huggingface",
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
else:
"""
Check if: (in order of specificity)
@@ -4361,6 +4390,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
return _model_info
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
@@ -4375,7 +4419,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"