diff --git a/.circleci/config.yml b/.circleci/config.yml index 6052fcba9..d070c55dc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,6 +65,7 @@ jobs: pip install "pydantic==2.7.1" pip install "diskcache==5.6.1" pip install "Pillow==10.3.0" + pip install "ijson==3.2.3" - save_cache: paths: - ./venv @@ -126,6 +127,7 @@ jobs: pip install jinja2 pip install tokenizers pip install openai + pip install ijson - run: name: Run tests command: | diff --git a/litellm/__init__.py b/litellm/__init__.py index dbdb2222d..a99ed20aa 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -13,7 +13,10 @@ from litellm._logging import ( verbose_logger, json_logs, _turn_on_json, + log_level, ) + + from litellm.proxy._types import ( KeyManagementSystem, KeyManagementSettings, @@ -736,6 +739,7 @@ from .utils import ( supports_function_calling, supports_parallel_function_calling, supports_vision, + supports_system_messages, get_litellm_params, acreate, get_model_list, diff --git a/litellm/_logging.py b/litellm/_logging.py index e3844bec2..b0935066b 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -12,7 +12,7 @@ if set_verbose is True: ) json_logs = bool(os.getenv("JSON_LOGS", False)) # Create a handler for the logger (you may need to adapt this based on your needs) -log_level = os.getenv("LITELLM_LOG", "ERROR") +log_level = os.getenv("LITELLM_LOG", "DEBUG") numeric_level: str = getattr(logging, log_level.upper()) handler = logging.StreamHandler() handler.setLevel(numeric_level) diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index cfdf39eca..f48c4e29e 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -1,14 +1,22 @@ -import types -import traceback +#################################### +######### DEPRECATED FILE ########## +#################################### +# logic moved to `vertex_httpx.py` # + import copy import time +import traceback +import types from typing import Callable, Optional -from litellm.utils import ModelResponse, Choices, Message, Usage -import litellm + import httpx -from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt from packaging.version import Version + +import litellm from litellm import verbose_logger +from litellm.utils import Choices, Message, ModelResponse, Usage + +from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory class GeminiError(Exception): @@ -186,8 +194,8 @@ def completion( if _system_instruction and len(system_prompt) > 0: _params["system_instruction"] = system_prompt _model = genai.GenerativeModel(**_params) - if stream == True: - if acompletion == True: + if stream is True: + if acompletion is True: async def async_streaming(): try: diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index c9e48f3e1..479e9bf3e 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -1,41 +1,49 @@ # What is this? 
## httpx client for vertex ai calls ## Initial implementation - covers gemini + image gen calls -from functools import partial -import os, types +import inspect import json -from enum import Enum -import requests # type: ignore +import os import time -from typing import Callable, Optional, Union, List, Any, Tuple +import types +import uuid +from enum import Enum +from functools import partial +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import httpx # type: ignore +import ijson +import requests # type: ignore + +import litellm import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper +from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason -import litellm, uuid -import httpx, inspect # type: ignore from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from .base import BaseLLM -from litellm.types.llms.vertex_ai import ( - ContentType, - SystemInstructions, - PartType, - RequestBody, - GenerateContentResponseBody, - FunctionCallingConfig, - FunctionDeclaration, - Tools, - ToolConfig, - GenerationConfig, -) from litellm.llms.vertex_ai import _gemini_convert_messages_with_history -from litellm.types.utils import GenericStreamingChunk from litellm.types.llms.openai import ( - ChatCompletionUsageBlock, + ChatCompletionResponseMessage, ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, - ChatCompletionResponseMessage, + ChatCompletionUsageBlock, ) +from litellm.types.llms.vertex_ai import ( + ContentType, + FunctionCallingConfig, + FunctionDeclaration, + GenerateContentResponseBody, + GenerationConfig, + PartType, + RequestBody, + SystemInstructions, + ToolConfig, + Tools, +) +from litellm.types.utils import GenericStreamingChunk +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + +from .base import BaseLLM class VertexGeminiConfig: @@ -251,7 +259,7 @@ async def make_call( raise VertexAIError(status_code=response.status_code, message=response.text) completion_stream = ModelResponseIterator( - streaming_response=response.aiter_bytes(chunk_size=2056) + streaming_response=response.aiter_bytes(), sync_stream=False ) # LOGGING logging_obj.post_call( @@ -282,7 +290,7 @@ def make_sync_call( raise VertexAIError(status_code=response.status_code, message=response.read()) completion_stream = ModelResponseIterator( - streaming_response=response.iter_bytes(chunk_size=2056) + streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True ) # LOGGING @@ -414,9 +422,11 @@ class VertexLLM(BaseLLM): def load_auth( self, credentials: Optional[str], project_id: Optional[str] ) -> Tuple[Any, str]: - from google.auth.transport.requests import Request # type: ignore[import-untyped] - from google.auth.credentials import Credentials # type: ignore[import-untyped] import google.auth as google_auth + from google.auth.credentials import Credentials # type: ignore[import-untyped] + from google.auth.transport.requests import ( + Request, # type: ignore[import-untyped] + ) if credentials is not None and isinstance(credentials, str): import google.oauth2.service_account @@ -449,7 +459,9 @@ class VertexLLM(BaseLLM): return creds, project_id def refresh_auth(self, credentials: Any) -> None: - from google.auth.transport.requests import Request # type: ignore[import-untyped] + from google.auth.transport.requests import ( + Request, # type: ignore[import-untyped] + ) 
credentials.refresh(Request()) @@ -482,6 +494,50 @@ class VertexLLM(BaseLLM): return self._credentials.token, self.project_id + def _get_token_and_url( + self, + model: str, + gemini_api_key: Optional[str], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], + stream: Optional[bool], + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + ) -> Tuple[Optional[str], str]: + """ + Internal function. Returns the token and url for the call. + + Handles logic if it's google ai studio vs. vertex ai. + + Returns + token, url + """ + if custom_llm_provider == "gemini": + _gemini_model_name = "models/{}".format(model) + auth_header = None + endpoint = "generateContent" + if stream is True: + endpoint = "streamGenerateContent" + + url = ( + "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + _gemini_model_name, endpoint, gemini_api_key + ) + ) + else: + auth_header, vertex_project = self._ensure_access_token( + credentials=vertex_credentials, project_id=vertex_project + ) + vertex_location = self.get_vertex_region(vertex_region=vertex_location) + + ### SET RUNTIME ENDPOINT ### + endpoint = "generateContent" + if stream is True: + endpoint = "streamGenerateContent" + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + + return auth_header, url + async def async_streaming( self, model: str, @@ -574,6 +630,9 @@ class VertexLLM(BaseLLM): messages: list, model_response: ModelResponse, print_verbose: Callable, + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) encoding, logging_obj, optional_params: dict, @@ -582,41 +641,58 @@ class VertexLLM(BaseLLM): vertex_project: Optional[str], vertex_location: Optional[str], vertex_credentials: Optional[str], + gemini_api_key: Optional[str], litellm_params=None, logger_fn=None, extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: - - auth_header, vertex_project = self._ensure_access_token( - credentials=vertex_credentials, project_id=vertex_project - ) - vertex_location = self.get_vertex_region(vertex_region=vertex_location) stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore - ### SET RUNTIME ENDPOINT ### - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent" + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=gemini_api_key, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=stream, + custom_llm_provider=custom_llm_provider, + ) ## TRANSFORMATION ## + try: + supports_system_message = litellm.supports_system_messages( + model=model, custom_llm_provider=custom_llm_provider + ) + except Exception as e: + verbose_logger.error( + "Unable to identify if system message supported. Defaulting to 'False'. 
Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format( + str(e) + ) + ) + supports_system_message = False # Separate system prompt from rest of message system_prompt_indices = [] system_content_blocks: List[PartType] = [] - for idx, message in enumerate(messages): - if message["role"] == "system": - _system_content_block = PartType(text=message["content"]) - system_content_blocks.append(_system_content_block) - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for idx in reversed(system_prompt_indices): - messages.pop(idx) - system_instructions = SystemInstructions(parts=system_content_blocks) + if supports_system_message is True: + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block = PartType(text=message["content"]) + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) content = _gemini_convert_messages_with_history(messages=messages) tools: Optional[Tools] = optional_params.pop("tools", None) tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) generation_config: Optional[GenerationConfig] = GenerationConfig( **optional_params ) - data = RequestBody(system_instruction=system_instructions, contents=content) + data = RequestBody(contents=content) + if len(system_content_blocks) > 0: + system_instructions = SystemInstructions(parts=system_content_blocks) + data["system_instruction"] = system_instructions if tools is not None: data["tools"] = tools if tool_choice is not None: @@ -626,8 +702,9 @@ class VertexLLM(BaseLLM): headers = { "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {auth_header}", } + if auth_header is not None: + headers["Authorization"] = f"Bearer {auth_header}" ## LOGGING logging_obj.pre_call( @@ -642,6 +719,25 @@ class VertexLLM(BaseLLM): ### ROUTING (ASYNC, STREAMING, SYNC) if acompletion: + ### ASYNC STREAMING + if stream is True: + return self.async_streaming( + model=model, + messages=messages, + data=json.dumps(data), # type: ignore + api_base=url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, # type: ignore + ) ### ASYNC COMPLETION return self.async_completion( model=model, @@ -853,9 +949,13 @@ class VertexLLM(BaseLLM): class ModelResponseIterator: - def __init__(self, streaming_response): + def __init__(self, streaming_response, sync_stream: bool): self.streaming_response = streaming_response - self.response_iterator = iter(self.streaming_response) + if sync_stream: + self.response_iterator = iter(self.streaming_response) + + self.events = ijson.sendable_list() + self.coro = ijson.items_coro(self.events, "item") def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: try: @@ -907,10 +1007,21 @@ class ModelResponseIterator: def __next__(self): try: - chunk = next(self.response_iterator) - chunk = chunk.decode() - json_chunk = json.loads(chunk) - return self.chunk_parser(chunk=json_chunk) + chunk = self.response_iterator.__next__() + self.coro.send(chunk) + if self.events: + event = self.events[0] + json_chunk = event + self.events.clear() + return self.chunk_parser(chunk=json_chunk) + return 
GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) except StopIteration: raise StopIteration except ValueError as e: @@ -924,9 +1035,20 @@ class ModelResponseIterator: async def __anext__(self): try: chunk = await self.async_response_iterator.__anext__() - chunk = chunk.decode() - json_chunk = json.loads(chunk) - return self.chunk_parser(chunk=json_chunk) + self.coro.send(chunk) + if self.events: + event = self.events[0] + json_chunk = event + self.events.clear() + return self.chunk_parser(chunk=json_chunk) + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) except StopAsyncIteration: raise StopAsyncIteration except ValueError as e: diff --git a/litellm/main.py b/litellm/main.py index 8c71b9a97..f46a9578b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1884,43 +1884,7 @@ def completion( ) return response response = model_response - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key - or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work - or litellm.api_key - ) - - # palm does not support streaming as yet :( - model_response = gemini.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=gemini_api_key, - logging_obj=logging, - acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - iter(model_response), - model, - custom_llm_provider="gemini", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "vertex_ai_beta": + elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini": vertex_ai_project = ( optional_params.pop("vertex_project", None) or optional_params.pop("vertex_ai_project", None) @@ -1938,6 +1902,14 @@ def completion( or optional_params.pop("vertex_ai_credentials", None) or get_secret("VERTEXAI_CREDENTIALS") ) + + gemini_api_key = ( + api_key + or get_secret("GEMINI_API_KEY") + or get_secret("PALM_API_KEY") # older palm api key should also work + or litellm.api_key + ) + new_params = deepcopy(optional_params) response = vertex_chat_completion.completion( # type: ignore model=model, @@ -1951,9 +1923,11 @@ def completion( vertex_location=vertex_ai_location, vertex_project=vertex_ai_project, vertex_credentials=vertex_credentials, + gemini_api_key=gemini_api_key, logging_obj=logging, acompletion=acompletion, timeout=timeout, + custom_llm_provider=custom_llm_provider, ) elif custom_llm_provider == "vertex_ai": diff --git a/litellm/router.py b/litellm/router.py index cd6c9c16e..db38df29f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7,66 +7,86 @@ # # Thank you ! We ❤️ you! 
- Krrish & Ishaan -import copy, httpx -from datetime import datetime -from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict -from typing_extensions import overload -import random, threading, time, traceback, uuid -import litellm, openai, hashlib, json -from litellm.caching import RedisCache, InMemoryCache, DualCache -import datetime as datetime_og -import logging, asyncio -import inspect, concurrent -from openai import AsyncOpenAI -from collections import defaultdict -from litellm.router_strategy.least_busy import LeastBusyLoggingHandler -from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler -from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler -from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler -from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 -from litellm.llms.custom_httpx.azure_dall_e_2 import ( - CustomHTTPTransport, - AsyncCustomHTTPTransport, -) -from litellm.utils import ( - ModelResponse, - CustomStreamWrapper, - get_utc_datetime, - calculate_max_parallel_requests, - _is_region_eu, -) +import asyncio +import concurrent import copy -from litellm._logging import verbose_router_logger +import datetime as datetime_og +import hashlib +import inspect +import json import logging -from litellm.types.utils import ModelInfo as ModelMapInfo -from litellm.types.router import ( - Deployment, - ModelInfo, - LiteLLM_Params, - RouterErrors, - updateDeployment, - updateLiteLLMParams, - RetryPolicy, - AllowedFailsPolicy, - AlertingConfig, - DeploymentTypedDict, - ModelGroupInfo, - AssistantsTypedDict, +import random +import threading +import time +import traceback +import uuid +from collections import defaultdict +from datetime import datetime +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + Literal, + Optional, + Tuple, + TypedDict, + Union, ) + +import httpx +import openai +from openai import AsyncOpenAI +from typing_extensions import overload + +import litellm +from litellm._logging import verbose_router_logger +from litellm.caching import DualCache, InMemoryCache, RedisCache from litellm.integrations.custom_logger import CustomLogger from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.llms.custom_httpx.azure_dall_e_2 import ( + AsyncCustomHTTPTransport, + CustomHTTPTransport, +) +from litellm.router_strategy.least_busy import LeastBusyLoggingHandler +from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler +from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler +from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler +from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_utils.handle_error import send_llm_exception_alert +from litellm.scheduler import FlowItem, Scheduler from litellm.types.llms.openai import ( - AsyncCursorPage, Assistant, - Thread, + AssistantToolParam, + AsyncCursorPage, Attachment, OpenAIMessage, Run, - AssistantToolParam, + Thread, +) +from litellm.types.router import ( + AlertingConfig, + AllowedFailsPolicy, + AssistantsTypedDict, + Deployment, + DeploymentTypedDict, + LiteLLM_Params, + ModelGroupInfo, + ModelInfo, + RetryPolicy, + RouterErrors, + updateDeployment, + updateLiteLLMParams, +) +from litellm.types.utils import ModelInfo as ModelMapInfo +from litellm.utils import ( + CustomStreamWrapper, + ModelResponse, + _is_region_eu, + calculate_max_parallel_requests, + get_utc_datetime, ) 
-from litellm.scheduler import Scheduler, FlowItem -from typing import Iterable -from litellm.router_utils.handle_error import send_llm_exception_alert class Router: @@ -3114,6 +3134,7 @@ class Router: # proxy support import os + import httpx # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. @@ -3800,6 +3821,7 @@ class Router: litellm_provider=llm_provider, mode="chat", supported_openai_params=supported_openai_params, + supports_system_messages=None, ) if model_group_info is None: diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 8c2ce781f..244ea0754 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -1,7 +1,10 @@ # conftest.py -import pytest, sys, os import importlib +import os +import sys + +import pytest sys.path.insert( 0, os.path.abspath("../..") @@ -18,6 +21,7 @@ def setup_and_teardown(): sys.path.insert( 0, os.path.abspath("../..") ) # Adds the project directory to the system path + print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG"))) import litellm from litellm import Router diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 0cb1f7929..a94577032 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -12,6 +12,8 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path + +import os from unittest.mock import MagicMock, patch import pytest @@ -3335,6 +3337,7 @@ def test_mistral_anyscale_stream(): #### Test A121 ################### +@pytest.mark.skip(reason="Local test") def test_completion_ai21(): print("running ai21 j2light test") litellm.set_verbose = True @@ -3390,10 +3393,21 @@ def test_completion_deep_infra_mistral(): # Gemini tests -def test_completion_gemini(): +@pytest.mark.parametrize( + "model", + [ + # "gemini-1.0-pro", + "gemini-1.5-pro", + # "gemini-1.5-flash", + ], +) +def test_completion_gemini(model): litellm.set_verbose = True - model_name = "gemini/gemini-1.5-pro-latest" - messages = [{"role": "user", "content": "Hey, how's it going?"}] + model_name = "gemini/{}".format(model) + messages = [ + {"role": "system", "content": "Be a good bot!"}, + {"role": "user", "content": "Hey, how's it going?"}, + ] try: response = completion(model=model_name, messages=messages) # Add any assertions,here to check the response @@ -3485,7 +3499,7 @@ def test_completion_palm_stream(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.skip(reason="IBM closed account.") +@pytest.mark.skip(reason="Account deleted by IBM.") def test_completion_watsonx(): litellm.set_verbose = True model_name = "watsonx/ibm/granite-13b-chat-v2" @@ -3506,7 +3520,7 @@ def test_completion_watsonx(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.skip(reason="IBM closed account.") +@pytest.mark.skip(reason="Skip test. 
account deleted.") def test_completion_stream_watsonx(): litellm.set_verbose = True model_name = "watsonx/ibm/granite-13b-chat-v2" @@ -3574,7 +3588,7 @@ def test_unified_auth_params(provider, model, project, region_name, token): assert value in translated_optional_params -@pytest.mark.skip(reason="IBM closed account.") +@pytest.mark.skip(reason="Local test") @pytest.mark.asyncio async def test_acompletion_watsonx(): litellm.set_verbose = True @@ -3595,7 +3609,7 @@ async def test_acompletion_watsonx(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.skip(reason="IBM closed account.") +@pytest.mark.skip(reason="Local test") @pytest.mark.asyncio async def test_acompletion_stream_watsonx(): litellm.set_verbose = True diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1c402faa5..ecb21b9f2 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1776,7 +1776,7 @@ def test_completion_sagemaker_stream(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.skip(reason="IBM closed account.") +@pytest.mark.skip(reason="Account deleted by IBM.") def test_completion_watsonx_stream(): litellm.set_verbose = True try: diff --git a/litellm/types/utils.py b/litellm/types/utils.py index b7c0e318e..f021fcd34 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1,13 +1,22 @@ import json import time import uuid +import json +import time +import uuid from enum import Enum from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union + from openai._models import BaseModel as OpenAIObject from pydantic import ConfigDict from typing_extensions import Dict, Required, TypedDict, override +from ..litellm_core_utils.core_helpers import map_finish_reason +from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock +from typing_extensions import Dict, Required, TypedDict, override + from ..litellm_core_utils.core_helpers import map_finish_reason from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock @@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False): ] ] supported_openai_params: Required[Optional[List[str]]] + supports_system_messages: Optional[bool] class GenericStreamingChunk(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index 2c953b93e..882185ba8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1824,6 +1824,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool: return False +def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool: + """ + Check if the given model supports function calling and return a boolean value. + + Parameters: + model (str): The model name to be checked. + + Returns: + bool: True if the model supports function calling, False otherwise. + + Raises: + Exception: If the given model is not found in model_prices_and_context_window.json. + """ + try: + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + if model_info.get("supports_system_messages", False) is True: + return True + return False + except Exception: + raise Exception( + f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}." + ) + + def supports_function_calling(model: str) -> bool: """ Check if the given model supports function calling and return a boolean value. 
@@ -1839,7 +1865,7 @@ def supports_function_calling(model: str) -> bool: """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_function_calling", False): + if model_info.get("supports_function_calling", False) is True: return True return False else: @@ -1863,7 +1889,7 @@ def supports_vision(model: str): """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_vision", False): + if model_info.get("supports_vision", False) is True: return True return False else: @@ -1885,7 +1911,7 @@ def supports_parallel_function_calling(model: str): """ if model in litellm.model_cost: model_info = litellm.model_cost[model] - if model_info.get("supports_parallel_function_calling", False): + if model_info.get("supports_parallel_function_calling", False) is True: return True return False else: @@ -4320,14 +4346,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod ) if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) - return { - "max_tokens": max_tokens, # type: ignore - "input_cost_per_token": 0, - "output_cost_per_token": 0, - "litellm_provider": "huggingface", - "mode": "chat", - "supported_openai_params": supported_openai_params, - } + return ModelInfo( + max_tokens=max_tokens, # type: ignore + max_input_tokens=None, + max_output_tokens=None, + input_cost_per_token=0, + output_cost_per_token=0, + litellm_provider="huggingface", + mode="chat", + supported_openai_params=supported_openai_params, + supports_system_messages=None, + ) else: """ Check if: (in order of specificity) @@ -4348,7 +4377,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return _model_info + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + input_cost_per_token_above_128k_tokens=_model_info.get( + "input_cost_per_token_above_128k_tokens", None + ), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + output_cost_per_token_above_128k_tokens=_model_info.get( + "output_cost_per_token_above_128k_tokens", None + ), + litellm_provider=_model_info.get( + "litellm_provider", custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + ) elif model in litellm.model_cost: _model_info = litellm.model_cost[model] _model_info["supported_openai_params"] = supported_openai_params @@ -4362,7 +4411,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return _model_info + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + input_cost_per_token_above_128k_tokens=_model_info.get( + "input_cost_per_token_above_128k_tokens", None + ), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + output_cost_per_token_above_128k_tokens=_model_info.get( + "output_cost_per_token_above_128k_tokens", None + ), + litellm_provider=_model_info.get( + "litellm_provider", 
custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + ) elif split_model in litellm.model_cost: _model_info = litellm.model_cost[split_model] _model_info["supported_openai_params"] = supported_openai_params @@ -4376,7 +4445,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod pass else: raise Exception - return _model_info + return ModelInfo( + max_tokens=_model_info.get("max_tokens", None), + max_input_tokens=_model_info.get("max_input_tokens", None), + max_output_tokens=_model_info.get("max_output_tokens", None), + input_cost_per_token=_model_info.get("input_cost_per_token", 0), + input_cost_per_token_above_128k_tokens=_model_info.get( + "input_cost_per_token_above_128k_tokens", None + ), + output_cost_per_token=_model_info.get("output_cost_per_token", 0), + output_cost_per_token_above_128k_tokens=_model_info.get( + "output_cost_per_token_above_128k_tokens", None + ), + litellm_provider=_model_info.get( + "litellm_provider", custom_llm_provider + ), + mode=_model_info.get("mode"), + supported_openai_params=supported_openai_params, + supports_system_messages=_model_info.get( + "supports_system_messages", None + ), + ) else: raise ValueError( "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" @@ -7030,7 +7119,9 @@ def exception_type( exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( - message=f"{str(original_exception)}", + message="{}\n{}".format( + str(original_exception), traceback.format_exc() + ), llm_provider=custom_llm_provider, model=model, request=original_exception.request, diff --git a/pyproject.toml b/pyproject.toml index 08a2a549d..9a9304165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" pydantic = "^2.0.0" +ijson = "*" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^22.0.0", optional = true} diff --git a/requirements.txt b/requirements.txt index ab755fec3..fbf2bfc1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,4 +44,5 @@ aiohttp==3.9.0 # for network calls aioboto3==12.3.0 # for async sagemaker calls tenacity==8.2.3 # for retrying requests, when litellm.num_retries set pydantic==2.7.1 # proxy + openai req. +ijson==3.2.3 # for google ai studio streaming #### \ No newline at end of file
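Illustration (not part of the patch): `ijson` is added to `pyproject.toml`, `requirements.txt`, and the CI config above because the reworked `ModelResponseIterator` in `litellm/llms/vertex_httpx.py` parses the Vertex / Google AI Studio `streamGenerateContent` response, a single top-level JSON array delivered in arbitrary byte chunks, incrementally. Below is a minimal sketch of the push-parsing pattern it relies on (`ijson.sendable_list()` plus `ijson.items_coro(..., "item")`); the byte chunks are illustrative only.

```python
import ijson

# Illustrative byte stream: one JSON array split across network chunks,
# with the second element cut mid-object by the chunk boundary.
chunks = [b'[{"a": 1}, {"b"', b': 2}, {"c": 3}]']

events = ijson.sendable_list()           # completed top-level elements land here
coro = ijson.items_coro(events, "item")  # push-mode parser for array items

for chunk in chunks:
    coro.send(chunk)
    while events:             # drain whatever has finished parsing so far
        print(events.pop(0))  # -> {'a': 1}, then {'b': 2} and {'c': 3}
coro.close()
```

Because a network chunk can end mid-element, there is not always a completed item to emit, which is why the patched iterator returns an empty `GenericStreamingChunk` whenever `self.events` is still empty.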