Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(vertex_httpx.py): check if model supports system messages before sending separately

Commit: cc1ec55e5b (parent: 18d9dcc4db)
7 changed files with 190 additions and 73 deletions
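The commit adds a litellm.supports_system_messages(model, custom_llm_provider) helper and has the Vertex/Gemini path consult it before splitting the system prompt out of the message list. A minimal caller-side sketch of the same idea; the model name and provider below are illustrative, and it assumes the model is present in model_prices_and_context_window.json and that a Gemini API key is configured:

import litellm

model = "gemini-1.5-pro"   # illustrative; must exist in model_prices_and_context_window.json
provider = "gemini"        # illustrative provider name

if litellm.supports_system_messages(model=model, custom_llm_provider=provider):
    messages = [
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ]
else:
    # fold the instruction into the user turn for models without system-message support
    messages = [{"role": "user", "content": "Be a good bot! Hey, how's it going?"}]

response = litellm.completion(model="{}/{}".format(provider, model), messages=messages)
print(response)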
@@ -739,6 +739,7 @@ from .utils import (
     supports_function_calling,
     supports_parallel_function_calling,
     supports_vision,
+    supports_system_messages,
     get_litellm_params,
     acreate,
     get_model_list,
@@ -12,7 +12,7 @@ if set_verbose is True:
 )
 json_logs = bool(os.getenv("JSON_LOGS", False))
 # Create a handler for the logger (you may need to adapt this based on your needs)
-log_level = os.getenv("LITELLM_LOG", "ERROR")
+log_level = os.getenv("LITELLM_LOG", "DEBUG")
 numeric_level: str = getattr(logging, log_level.upper())
 handler = logging.StreamHandler()
 handler.setLevel(numeric_level)
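The _logging.py hunk flips the fallback log level from ERROR to DEBUG when LITELLM_LOG is unset. A minimal sketch of pinning the quieter level from application code, assuming (as the module-level os.getenv call above suggests) the variable is read once at import time:

import os

os.environ.setdefault("LITELLM_LOG", "ERROR")  # set before importing litellm; the level is resolved at import

import litellm  # noqa: E402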
@@ -18,6 +18,7 @@ import requests  # type: ignore
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
+from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
@@ -659,9 +660,21 @@ class VertexLLM(BaseLLM):
         )

         ## TRANSFORMATION ##
+        try:
+            supports_system_message = litellm.supports_system_messages(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    str(e)
+                )
+            )
+            supports_system_message = False
         # Separate system prompt from rest of message
         system_prompt_indices = []
         system_content_blocks: List[PartType] = []
+        if supports_system_message is True:
             for idx, message in enumerate(messages):
                 if message["role"] == "system":
                     _system_content_block = PartType(text=message["content"])
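In the transformation above, system-role messages are only pulled into separate content blocks when the model is known to support them. A stripped-down sketch of that splitting step in plain Python, independent of the Vertex-specific PartType wrapper (the helper name is illustrative):

from typing import Dict, List, Tuple

def split_system_messages(
    messages: List[Dict[str, str]],
) -> Tuple[List[str], List[Dict[str, str]]]:
    # Collect system-message text separately and keep the remaining turns in order,
    # mirroring how the hunk above builds system_content_blocks before conversion.
    system_texts = [m["content"] for m in messages if m["role"] == "system"]
    remaining = [m for m in messages if m["role"] != "system"]
    return system_texts, remaining

system_texts, chat_messages = split_system_messages(
    [
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ]
)
print(system_texts)    # ['Be a good bot!']
print(chat_messages)   # [{'role': 'user', 'content': "Hey, how's it going?"}]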
@@ -7,66 +7,86 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan

-import copy, httpx
-from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
-from typing_extensions import overload
-import random, threading, time, traceback, uuid
-import litellm, openai, hashlib, json
-from litellm.caching import RedisCache, InMemoryCache, DualCache
-import datetime as datetime_og
-import logging, asyncio
-import inspect, concurrent
-from openai import AsyncOpenAI
-from collections import defaultdict
-from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
-from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
-from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    CustomHTTPTransport,
-    AsyncCustomHTTPTransport,
-)
-from litellm.utils import (
-    ModelResponse,
-    CustomStreamWrapper,
-    get_utc_datetime,
-    calculate_max_parallel_requests,
-    _is_region_eu,
-)
+import asyncio
+import concurrent
 import copy
-from litellm._logging import verbose_router_logger
+import datetime as datetime_og
+import hashlib
+import inspect
+import json
 import logging
-from litellm.types.utils import ModelInfo as ModelMapInfo
-from litellm.types.router import (
-    Deployment,
-    ModelInfo,
-    LiteLLM_Params,
-    RouterErrors,
-    updateDeployment,
-    updateLiteLLMParams,
-    RetryPolicy,
-    AllowedFailsPolicy,
-    AlertingConfig,
-    DeploymentTypedDict,
-    ModelGroupInfo,
-    AssistantsTypedDict,
+import random
+import threading
+import time
+import traceback
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
 )

+import httpx
+import openai
+from openai import AsyncOpenAI
+from typing_extensions import overload
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.llms.custom_httpx.azure_dall_e_2 import (
+    AsyncCustomHTTPTransport,
+    CustomHTTPTransport,
+)
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.handle_error import send_llm_exception_alert
+from litellm.scheduler import FlowItem, Scheduler
 from litellm.types.llms.openai import (
-    AsyncCursorPage,
     Assistant,
-    Thread,
+    AssistantToolParam,
+    AsyncCursorPage,
     Attachment,
     OpenAIMessage,
     Run,
-    AssistantToolParam,
+    Thread,
+)
+from litellm.types.router import (
+    AlertingConfig,
+    AllowedFailsPolicy,
+    AssistantsTypedDict,
+    Deployment,
+    DeploymentTypedDict,
+    LiteLLM_Params,
+    ModelGroupInfo,
+    ModelInfo,
+    RetryPolicy,
+    RouterErrors,
+    updateDeployment,
+    updateLiteLLMParams,
+)
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.utils import (
+    CustomStreamWrapper,
+    ModelResponse,
+    _is_region_eu,
+    calculate_max_parallel_requests,
+    get_utc_datetime,
 )
-from litellm.scheduler import Scheduler, FlowItem
-from typing import Iterable
-from litellm.router_utils.handle_error import send_llm_exception_alert


 class Router:
@@ -3114,6 +3134,7 @@ class Router:

         # proxy support
         import os
+
         import httpx

         # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@@ -3800,6 +3821,7 @@ class Router:
                 litellm_provider=llm_provider,
                 mode="chat",
                 supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
             )

         if model_group_info is None:
@@ -3392,15 +3392,28 @@ def test_completion_deep_infra_mistral():


 # Gemini tests
-def test_completion_gemini():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "gemini-1.0-pro",
+        "gemini-1.5-pro",
+        # "gemini-1.5-flash",
+    ],
+)
+def test_completion_gemini(model):
     litellm.set_verbose = True
-    model_name = "gemini/gemini-1.5-pro-latest"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    model_name = "gemini/{}".format(model)
+    messages = [
+        {"role": "system", "content": "Be a good bot!"},
+        {"role": "user", "content": "Hey, how's it going?"},
+    ]
     try:
         response = completion(model=model_name, messages=messages)
         # Add any assertions,here to check the response
         print(response)
         assert response.choices[0]["index"] == 0
+
+        assert False
     except litellm.APIError as e:
         pass
     except Exception as e:
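With the parametrization in place the Gemini test runs once per model in the list; a single case can be selected with pytest -k test_completion_gemini (or a full node id such as test_completion_gemini[gemini-1.5-pro]). Running it live assumes a Gemini API key is configured in the environment.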
@@ -1,13 +1,22 @@
 import json
 import time
 import uuid
+import json
+import time
+import uuid
 from enum import Enum
 from typing import Dict, List, Literal, Optional, Tuple, Union

+from typing import Dict, List, Literal, Optional, Tuple, Union
+
 from openai._models import BaseModel as OpenAIObject
 from pydantic import ConfigDict
 from typing_extensions import Dict, Required, TypedDict, override

+from ..litellm_core_utils.core_helpers import map_finish_reason
+from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from typing_extensions import Dict, Required, TypedDict, override
+
 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock

@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
         ]
     ]
     supported_openai_params: Required[Optional[List[str]]]
+    supports_system_messages: Optional[bool]


 class GenericStreamingChunk(TypedDict):
@@ -1823,6 +1823,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
     return False


+def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports function calling and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports function calling, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        if model_info.get("supports_system_messages", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
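Because the new helper raises when a model is missing from model_prices_and_context_window.json rather than returning False, callers that cannot guarantee the model is mapped should wrap the call, as vertex_httpx.py does above. A minimal sketch of that defensive pattern (the wrapper name and model string are illustrative):

import litellm

def system_messages_ok(model: str, provider: str) -> bool:
    # Treat an unmapped model as "no system-message support" instead of propagating the error.
    try:
        return litellm.supports_system_messages(model=model, custom_llm_provider=provider)
    except Exception:
        return False

print(system_messages_ok("some-unmapped-model", "openai"))  # False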
@@ -1838,7 +1864,7 @@ def supports_function_calling(model: str) -> bool:
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_function_calling", False):
+        if model_info.get("supports_function_calling", False) is True:
             return True
         return False
     else:
@@ -1862,7 +1888,7 @@ def supports_vision(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_vision", False):
+        if model_info.get("supports_vision", False) is True:
             return True
         return False
     else:
@@ -1884,7 +1910,7 @@ def supports_parallel_function_calling(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_parallel_function_calling", False):
+        if model_info.get("supports_parallel_function_calling", False) is True:
             return True
         return False
     else:
@@ -4319,14 +4345,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
     )
     if custom_llm_provider == "huggingface":
         max_tokens = _get_max_position_embeddings(model_name=model)
-        return {
-            "max_tokens": max_tokens,  # type: ignore
-            "input_cost_per_token": 0,
-            "output_cost_per_token": 0,
-            "litellm_provider": "huggingface",
-            "mode": "chat",
-            "supported_openai_params": supported_openai_params,
-        }
+        return ModelInfo(
+            max_tokens=max_tokens,  # type: ignore
+            max_input_tokens=None,
+            max_output_tokens=None,
+            input_cost_per_token=0,
+            output_cost_per_token=0,
+            litellm_provider="huggingface",
+            mode="chat",
+            supported_openai_params=supported_openai_params,
+            supports_system_messages=None,
+        )
     else:
         """
         Check if: (in order of specificity)
@@ -4361,6 +4390,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 pass
             else:
                 raise Exception
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+            )
             return _model_info
         elif split_model in litellm.model_cost:
             _model_info = litellm.model_cost[split_model]
@@ -4375,7 +4419,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 pass
             else:
                 raise Exception
-            return _model_info
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+            )
         else:
             raise ValueError(
                 "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
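get_model_info now returns a ModelInfo TypedDict that carries supports_system_messages alongside the token-limit and cost fields. A short sketch of reading it, assuming the model below is mapped in model_prices_and_context_window.json (the name is illustrative):

import litellm

info = litellm.get_model_info(model="gpt-4", custom_llm_provider="openai")
print(info["supported_openai_params"])
print(info.get("supports_system_messages"))  # may be None if the map has no entry for this flag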