mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(vertex_httpx.py): check if model supports system messages before sending separately
This commit is contained in:
parent a80520004e, commit 3d9ef689e7
7 changed files with 190 additions and 73 deletions
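In practice, the behavior this commit targets looks roughly like the sketch below: a Gemini request that carries a system message, which the handler in vertex_httpx.py now only splits out into a separate system instruction when the model is marked as supporting system messages. This is a minimal usage sketch, assuming Gemini/Vertex credentials are already configured; the model name is just an example taken from the test added below.

import litellm

# A request that includes a system prompt, mirroring the test added in this commit.
messages = [
    {"role": "system", "content": "Be a good bot!"},
    {"role": "user", "content": "Hey, how's it going?"},
]

# Internally, the Vertex/Gemini handler consults litellm.supports_system_messages()
# (added in this commit) before separating the system prompt from the conversation.
response = litellm.completion(model="gemini/gemini-1.5-pro", messages=messages)
print(response.choices[0].message.content)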
@@ -739,6 +739,7 @@ from .utils import (
    supports_function_calling,
    supports_parallel_function_calling,
    supports_vision,
    supports_system_messages,
    get_litellm_params,
    acreate,
    get_model_list,
@@ -12,7 +12,7 @@ if set_verbose is True:
    )
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "ERROR")
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: str = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
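For context, here is a small, self-contained sketch of how an environment-driven level such as LITELLM_LOG feeds a stream handler, as the hunk above does. The logger name and messages are illustrative only, not the library's exact wiring; the hunk itself shows the default moving between "ERROR" and "DEBUG".

import logging
import os

log_level = os.getenv("LITELLM_LOG", "ERROR")  # default shown here is illustrative
numeric_level = getattr(logging, log_level.upper(), logging.ERROR)

handler = logging.StreamHandler()
handler.setLevel(numeric_level)

logger = logging.getLogger("example")  # hypothetical logger name
logger.addHandler(handler)
logger.setLevel(numeric_level)

logger.debug("visible only when LITELLM_LOG=DEBUG")
logger.error("visible at both DEBUG and ERROR")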
@@ -18,6 +18,7 @@ import requests # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
@@ -659,17 +660,29 @@ class VertexLLM(BaseLLM):
        )

        ## TRANSFORMATION ##
        try:
            supports_system_message = litellm.supports_system_messages(
                model=model, custom_llm_provider=custom_llm_provider
            )
        except Exception as e:
            verbose_logger.error(
                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
                    str(e)
                )
            )
            supports_system_message = False
        # Separate system prompt from rest of message
        system_prompt_indices = []
        system_content_blocks: List[PartType] = []
        for idx, message in enumerate(messages):
            if message["role"] == "system":
                _system_content_block = PartType(text=message["content"])
                system_content_blocks.append(_system_content_block)
                system_prompt_indices.append(idx)
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)
        if supports_system_message is True:
            for idx, message in enumerate(messages):
                if message["role"] == "system":
                    _system_content_block = PartType(text=message["content"])
                    system_content_blocks.append(_system_content_block)
                    system_prompt_indices.append(idx)
            if len(system_prompt_indices) > 0:
                for idx in reversed(system_prompt_indices):
                    messages.pop(idx)
        content = _gemini_convert_messages_with_history(messages=messages)
        tools: Optional[Tools] = optional_params.pop("tools", None)
        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
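The transformation above boils down to the pattern sketched below: resolve the capability first (defaulting to False if the model cannot be looked up), and only then pull system messages out of the conversation. This is a simplified, standalone sketch of the same control flow; it uses plain strings instead of the handler's PartType blocks, and the function name is hypothetical.

from typing import List, Tuple

import litellm


def split_system_messages(
    model: str, custom_llm_provider: str, messages: List[dict]
) -> Tuple[List[dict], List[str]]:
    """Return (remaining_messages, system_texts), splitting only when supported."""
    try:
        supports_system_message = litellm.supports_system_messages(
            model=model, custom_llm_provider=custom_llm_provider
        )
    except Exception:
        # Same fallback as the handler: unknown models are treated as unsupported.
        supports_system_message = False

    system_texts: List[str] = []
    if supports_system_message is True:
        remaining = []
        for message in messages:
            if message["role"] == "system":
                system_texts.append(message["content"])
            else:
                remaining.append(message)
        return remaining, system_texts
    return messages, system_texts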
@@ -7,66 +7,86 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan

import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
    CustomHTTPTransport,
    AsyncCustomHTTPTransport,
)
from litellm.utils import (
    ModelResponse,
    CustomStreamWrapper,
    get_utc_datetime,
    calculate_max_parallel_requests,
    _is_region_eu,
)
import asyncio
import concurrent
import copy
from litellm._logging import verbose_router_logger
import datetime as datetime_og
import hashlib
import inspect
import json
import logging
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
    Deployment,
    ModelInfo,
    LiteLLM_Params,
    RouterErrors,
    updateDeployment,
    updateLiteLLMParams,
    RetryPolicy,
    AllowedFailsPolicy,
    AlertingConfig,
    DeploymentTypedDict,
    ModelGroupInfo,
    AssistantsTypedDict,
import random
import threading
import time
import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import (
    Any,
    BinaryIO,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

import httpx
import openai
from openai import AsyncOpenAI
from typing_extensions import overload

import litellm
from litellm._logging import verbose_router_logger
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
    AsyncCustomHTTPTransport,
    CustomHTTPTransport,
)
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
    AsyncCursorPage,
    Assistant,
    Thread,
    AssistantToolParam,
    AsyncCursorPage,
    Attachment,
    OpenAIMessage,
    Run,
    AssistantToolParam,
    Thread,
)
from litellm.types.router import (
    AlertingConfig,
    AllowedFailsPolicy,
    AssistantsTypedDict,
    Deployment,
    DeploymentTypedDict,
    LiteLLM_Params,
    ModelGroupInfo,
    ModelInfo,
    RetryPolicy,
    RouterErrors,
    updateDeployment,
    updateLiteLLMParams,
)
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
    CustomStreamWrapper,
    ModelResponse,
    _is_region_eu,
    calculate_max_parallel_requests,
    get_utc_datetime,
)
from litellm.scheduler import Scheduler, FlowItem
from typing import Iterable
from litellm.router_utils.handle_error import send_llm_exception_alert


class Router:
@@ -3114,6 +3134,7 @@ class Router:

            # proxy support
            import os

            import httpx

            # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
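The comment above refers to the standard proxy environment variables. A hedged sketch of that convention with httpx is shown below; the helper name is hypothetical, and the proxies mapping argument is the older httpx spelling (newer releases prefer proxy/mounts).

import os

import httpx


def client_with_env_proxies() -> httpx.Client:
    # Check HTTP_PROXY / HTTPS_PROXY and use them accordingly.
    proxies = {}
    if os.environ.get("HTTP_PROXY"):
        proxies["http://"] = os.environ["HTTP_PROXY"]
    if os.environ.get("HTTPS_PROXY"):
        proxies["https://"] = os.environ["HTTPS_PROXY"]
    return httpx.Client(proxies=proxies) if proxies else httpx.Client()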
@@ -3800,6 +3821,7 @@ class Router:
                litellm_provider=llm_provider,
                mode="chat",
                supported_openai_params=supported_openai_params,
                supports_system_messages=None,
            )

        if model_group_info is None:
@@ -3392,15 +3392,28 @@ def test_completion_deep_infra_mistral():


# Gemini tests
def test_completion_gemini():
@pytest.mark.parametrize(
    "model",
    [
        # "gemini-1.0-pro",
        "gemini-1.5-pro",
        # "gemini-1.5-flash",
    ],
)
def test_completion_gemini(model):
    litellm.set_verbose = True
    model_name = "gemini/gemini-1.5-pro-latest"
    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    model_name = "gemini/{}".format(model)
    messages = [
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ]
    try:
        response = completion(model=model_name, messages=messages)
        # Add any assertions here to check the response
        print(response)
        assert response.choices[0]["index"] == 0

        assert False
    except litellm.APIError as e:
        pass
    except Exception as e:
@@ -1,13 +1,22 @@
import json
import time
import uuid
import json
import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union

from typing import Dict, List, Literal, Optional, Tuple, Union

from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override

from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
from typing_extensions import Dict, Required, TypedDict, override

from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
        ]
    ]
    supported_openai_params: Required[Optional[List[str]]]
    supports_system_messages: Optional[bool]


class GenericStreamingChunk(TypedDict):
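Because supports_system_messages is declared Optional on a TypedDict with total=False, callers reading a ModelInfo should treat a missing key and an explicit None the same way. A small sketch of defensive access; the model name is only an example.

import litellm

info = litellm.get_model_info(model="gemini/gemini-1.5-pro")
# .get() keeps this safe whether the key is absent, None, or a bool.
if info.get("supports_system_messages") is True:
    print("system messages are explicitly supported")
else:
    print("support is unknown or not declared for this model")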
@@ -1823,6 +1823,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
    return False


def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
    """
    Check if the given model supports system messages and return a boolean value.

    Parameters:
    model (str): The model name to be checked.

    Returns:
    bool: True if the model supports system messages, False otherwise.

    Raises:
    Exception: If the given model is not found in model_prices_and_context_window.json.
    """
    try:
        model_info = litellm.get_model_info(
            model=model, custom_llm_provider=custom_llm_provider
        )
        if model_info.get("supports_system_messages", False) is True:
            return True
        return False
    except Exception:
        raise Exception(
            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
        )


def supports_function_calling(model: str) -> bool:
    """
    Check if the given model supports function calling and return a boolean value.
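Note that the new helper raises when the model is not in model_prices_and_context_window.json rather than returning False, so callers that only want a best-effort answer need the same try/except the Vertex handler uses. For example:

import litellm

try:
    ok = litellm.supports_system_messages(
        model="gemini-1.5-pro", custom_llm_provider="gemini"
    )
except Exception:
    # Unmapped models raise, so fall back to treating them as unsupported.
    ok = False

print("send system instruction separately:", ok)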
@@ -1838,7 +1864,7 @@ def supports_function_calling(model: str) -> bool:
    """
    if model in litellm.model_cost:
        model_info = litellm.model_cost[model]
        if model_info.get("supports_function_calling", False):
        if model_info.get("supports_function_calling", False) is True:
            return True
        return False
    else:
@@ -1862,7 +1888,7 @@ def supports_vision(model: str):
    """
    if model in litellm.model_cost:
        model_info = litellm.model_cost[model]
        if model_info.get("supports_vision", False):
        if model_info.get("supports_vision", False) is True:
            return True
        return False
    else:
@@ -1884,7 +1910,7 @@ def supports_parallel_function_calling(model: str):
    """
    if model in litellm.model_cost:
        model_info = litellm.model_cost[model]
        if model_info.get("supports_parallel_function_calling", False):
        if model_info.get("supports_parallel_function_calling", False) is True:
            return True
        return False
    else:
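The change from a truthy check to an explicit `is True` in these helpers is behavior-preserving for proper booleans but stricter for anything else. A quick illustration of the difference, assuming a hypothetical metadata entry that carries a badly typed value:

entry = {"supports_vision": "false"}  # hypothetical, mis-typed metadata value

truthy = bool(entry.get("supports_vision", False))    # True: any non-empty string is truthy
strict = entry.get("supports_vision", False) is True  # False: only the bool True passes

print(truthy, strict)  # True False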
@@ -4319,14 +4345,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
        )
    if custom_llm_provider == "huggingface":
        max_tokens = _get_max_position_embeddings(model_name=model)
        return {
            "max_tokens": max_tokens, # type: ignore
            "input_cost_per_token": 0,
            "output_cost_per_token": 0,
            "litellm_provider": "huggingface",
            "mode": "chat",
            "supported_openai_params": supported_openai_params,
        }
        return ModelInfo(
            max_tokens=max_tokens, # type: ignore
            max_input_tokens=None,
            max_output_tokens=None,
            input_cost_per_token=0,
            output_cost_per_token=0,
            litellm_provider="huggingface",
            mode="chat",
            supported_openai_params=supported_openai_params,
            supports_system_messages=None,
        )
    else:
        """
        Check if: (in order of specificity)
@@ -4361,6 +4390,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                    pass
                else:
                    raise Exception
            return ModelInfo(
                max_tokens=_model_info.get("max_tokens", None),
                max_input_tokens=_model_info.get("max_input_tokens", None),
                max_output_tokens=_model_info.get("max_output_tokens", None),
                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
                litellm_provider=_model_info.get(
                    "litellm_provider", custom_llm_provider
                ),
                mode=_model_info.get("mode"),
                supported_openai_params=supported_openai_params,
                supports_system_messages=_model_info.get(
                    "supports_system_messages", None
                ),
            )
            return _model_info
        elif split_model in litellm.model_cost:
            _model_info = litellm.model_cost[split_model]
@@ -4375,7 +4419,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                    pass
                else:
                    raise Exception
            return _model_info
            return ModelInfo(
                max_tokens=_model_info.get("max_tokens", None),
                max_input_tokens=_model_info.get("max_input_tokens", None),
                max_output_tokens=_model_info.get("max_output_tokens", None),
                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
                litellm_provider=_model_info.get(
                    "litellm_provider", custom_llm_provider
                ),
                mode=_model_info.get("mode"),
                supported_openai_params=supported_openai_params,
                supports_system_messages=_model_info.get(
                    "supports_system_messages", None
                ),
            )
        else:
            raise ValueError(
                "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
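For models that are not mapped at all, get_model_info raises with the message above; callers can surface that as a prompt to add an entry to model_prices_and_context_window.json, e.g.:

import litellm

try:
    info = litellm.get_model_info(model="my-custom-model")  # hypothetical unmapped model
except Exception as e:
    # Expected for unmapped models; the error points at
    # model_prices_and_context_window.json as the place to add an entry.
    print(f"model not mapped yet: {e}")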