diff --git a/litellm/__init__.py b/litellm/__init__.py
index 2c4845c6ec..6aee920c50 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -739,6 +739,7 @@ from .utils import (
     supports_function_calling,
     supports_parallel_function_calling,
     supports_vision,
+    supports_system_messages,
     get_litellm_params,
     acreate,
     get_model_list,
diff --git a/litellm/_logging.py b/litellm/_logging.py
index c4d7c035a0..a98d85e1c4 100644
--- a/litellm/_logging.py
+++ b/litellm/_logging.py
@@ -12,7 +12,7 @@ if set_verbose is True:
     )
 json_logs = bool(os.getenv("JSON_LOGS", False))
 # Create a handler for the logger (you may need to adapt this based on your needs)
-log_level = os.getenv("LITELLM_LOG", "ERROR")
+log_level = os.getenv("LITELLM_LOG", "DEBUG")
 numeric_level: str = getattr(logging, log_level.upper())
 handler = logging.StreamHandler()
 handler.setLevel(numeric_level)
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index f3640c27c9..479e9bf3e2 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -18,6 +18,7 @@ import requests  # type: ignore
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
+from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
@@ -659,17 +660,29 @@ class VertexLLM(BaseLLM):
             )

         ## TRANSFORMATION ##
+        try:
+            supports_system_message = litellm.supports_system_messages(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    str(e)
+                )
+            )
+            supports_system_message = False
         # Separate system prompt from rest of message
         system_prompt_indices = []
         system_content_blocks: List[PartType] = []
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                _system_content_block = PartType(text=message["content"])
-                system_content_blocks.append(_system_content_block)
-                system_prompt_indices.append(idx)
-        if len(system_prompt_indices) > 0:
-            for idx in reversed(system_prompt_indices):
-                messages.pop(idx)
+        if supports_system_message is True:
+            for idx, message in enumerate(messages):
+                if message["role"] == "system":
+                    _system_content_block = PartType(text=message["content"])
+                    system_content_blocks.append(_system_content_block)
+                    system_prompt_indices.append(idx)
+            if len(system_prompt_indices) > 0:
+                for idx in reversed(system_prompt_indices):
+                    messages.pop(idx)
         content = _gemini_convert_messages_with_history(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
diff --git a/litellm/router.py b/litellm/router.py
index cd6c9c16eb..db38df29f0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -7,66 +7,86 @@
 #
 #  Thank you ! We ❤️ you! - Krrish & Ishaan

-import copy, httpx
-from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
-from typing_extensions import overload
-import random, threading, time, traceback, uuid
-import litellm, openai, hashlib, json
-from litellm.caching import RedisCache, InMemoryCache, DualCache
-import datetime as datetime_og
-import logging, asyncio
-import inspect, concurrent
-from openai import AsyncOpenAI
-from collections import defaultdict
-from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
-from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
-from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    CustomHTTPTransport,
-    AsyncCustomHTTPTransport,
-)
-from litellm.utils import (
-    ModelResponse,
-    CustomStreamWrapper,
-    get_utc_datetime,
-    calculate_max_parallel_requests,
-    _is_region_eu,
-)
+import asyncio
+import concurrent
 import copy
-from litellm._logging import verbose_router_logger
+import datetime as datetime_og
+import hashlib
+import inspect
+import json
 import logging
-from litellm.types.utils import ModelInfo as ModelMapInfo
-from litellm.types.router import (
-    Deployment,
-    ModelInfo,
-    LiteLLM_Params,
-    RouterErrors,
-    updateDeployment,
-    updateLiteLLMParams,
-    RetryPolicy,
-    AllowedFailsPolicy,
-    AlertingConfig,
-    DeploymentTypedDict,
-    ModelGroupInfo,
-    AssistantsTypedDict,
+import random
+import threading
+import time
+import traceback
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
 )
+
+import httpx
+import openai
+from openai import AsyncOpenAI
+from typing_extensions import overload
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.llms.custom_httpx.azure_dall_e_2 import (
+    AsyncCustomHTTPTransport,
+    CustomHTTPTransport,
+)
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.handle_error import send_llm_exception_alert
+from litellm.scheduler import FlowItem, Scheduler
 from litellm.types.llms.openai import (
-    AsyncCursorPage,
     Assistant,
-    Thread,
+    AssistantToolParam,
+    AsyncCursorPage,
     Attachment,
     OpenAIMessage,
     Run,
-    AssistantToolParam,
+    Thread,
+)
+from litellm.types.router import (
+    AlertingConfig,
+    AllowedFailsPolicy,
+    AssistantsTypedDict,
+    Deployment,
+    DeploymentTypedDict,
+    LiteLLM_Params,
+    ModelGroupInfo,
+    ModelInfo,
+    RetryPolicy,
+    RouterErrors,
+    updateDeployment,
+    updateLiteLLMParams,
+)
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.utils import (
+    CustomStreamWrapper,
+    ModelResponse,
+    _is_region_eu,
+    calculate_max_parallel_requests,
+    get_utc_datetime,
 )
-from litellm.scheduler import Scheduler, FlowItem
-from typing import Iterable
-from litellm.router_utils.handle_error import send_llm_exception_alert


 class Router:
@@ -3114,6 +3134,7 @@ class Router:

             # proxy support
             import os
+            import httpx

             # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.

@@ -3800,6 +3821,7 @@ class Router:
                     litellm_provider=llm_provider,
                     mode="chat",
                     supported_openai_params=supported_openai_params,
+                    supports_system_messages=None,
                 )

             if model_group_info is None:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index a60dd85070..f1ee63564c 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -3392,15 +3392,28 @@ def test_completion_deep_infra_mistral():


 # Gemini tests
-def test_completion_gemini():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "gemini-1.0-pro",
+        "gemini-1.5-pro",
+        # "gemini-1.5-flash",
+    ],
+)
+def test_completion_gemini(model):
     litellm.set_verbose = True
-    model_name = "gemini/gemini-1.5-pro-latest"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    model_name = "gemini/{}".format(model)
+    messages = [
+        {"role": "system", "content": "Be a good bot!"},
+        {"role": "user", "content": "Hey, how's it going?"},
+    ]
     try:
         response = completion(model=model_name, messages=messages)
         # Add any assertions,here to check the response
         print(response)
         assert response.choices[0]["index"] == 0
+
+        assert False
     except litellm.APIError as e:
         pass
     except Exception as e:
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index b7c0e318e4..f021fcd345 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1,13 +1,22 @@
 import json
 import time
 import uuid
+import json
+import time
+import uuid
 from enum import Enum
 from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union
+

 from openai._models import BaseModel as OpenAIObject
 from pydantic import ConfigDict
 from typing_extensions import Dict, Required, TypedDict, override
+from ..litellm_core_utils.core_helpers import map_finish_reason
+from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from typing_extensions import Dict, Required, TypedDict, override
+

 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock

@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
         ]
     ]
     supported_openai_params: Required[Optional[List[str]]]
+    supports_system_messages: Optional[bool]


 class GenericStreamingChunk(TypedDict):
diff --git a/litellm/utils.py b/litellm/utils.py
index 2a0b568918..8b640b16d3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1823,6 +1823,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
     return False


+def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports system messages and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports system messages, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+ """ + try: + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + if model_info.get("supports_system_messages", False) is True: + return True + return False + except Exception: + raise Exception( + f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}." + ) + + def supports_function_calling(model: str) -> bool: """ Check if the given model supports function calling and return a boolean value. @@ -1838,7 +1864,7 @@ def supports_function_calling(model: str) -> bool: """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_function_calling", False): + if model_info.get("supports_function_calling", False) is True: return True return False else: @@ -1862,7 +1888,7 @@ def supports_vision(model: str): """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_vision", False): + if model_info.get("supports_vision", False) is True: return True return False else: @@ -1884,7 +1910,7 @@ def supports_parallel_function_calling(model: str): """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_parallel_function_calling", False): + if model_info.get("supports_parallel_function_calling", False) is True: return True return False else: @@ -4319,14 +4345,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod ) if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) - return { - "max_tokens": max_tokens, # type: ignore - "input_cost_per_token": 0, - "output_cost_per_token": 0, - "litellm_provider": "huggingface", - "mode": "chat", - "supported_openai_params": supported_openai_params, - } + return ModelInfo( + max_tokens=max_tokens, # type: ignore + max_input_tokens=None, + max_output_tokens=None, + input_cost_per_token=0, + output_cost_per_token=0, + litellm_provider="huggingface", + mode="chat", + supported_openai_params=supported_openai_params, + supports_system_messages=None, + ) else: """ Check if: (in order of specificity) @@ -4361,6 +4390,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + litellm_provider=_model_info.get( + "litellm_provider", custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + ) return _model_info elif split_model in litellm.model_cost: _model_info = litellm.model_cost[split_model] @@ -4375,7 +4419,21 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return _model_info + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + litellm_provider=_model_info.get( + "litellm_provider", 
custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + ) else: raise ValueError( "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
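
For reference, a minimal usage sketch of the new litellm.supports_system_messages helper, mirroring the fallback pattern added in vertex_httpx.py above. The model and provider names below are illustrative assumptions; whether the lookup succeeds depends on the model's entry in model_prices_and_context_window.json.

# Minimal sketch (illustrative model/provider names, not from the diff)
import litellm

try:
    supports_system_message = litellm.supports_system_messages(
        model="gemini-1.5-pro", custom_llm_provider="vertex_ai"
    )
except Exception:
    # Unmapped models raise, so fall back to the safe default.
    supports_system_message = False

messages = [
    {"role": "system", "content": "Be a good bot!"},
    {"role": "user", "content": "Hey, how's it going?"},
]
if supports_system_message is False:
    # Drop system prompts for models that cannot accept them.
    messages = [m for m in messages if m["role"] != "system"]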