Merge pull request #4246 from BerriAI/litellm_gemini_refactoring

feat(main.py): Gemini (Google AI Studio) - Support Function Calling, Inline images, etc.
Krish Dholakia 2024-06-17 19:51:12 -07:00 committed by GitHub
commit af2917d655
14 changed files with 427 additions and 174 deletions

View file

@@ -65,6 +65,7 @@ jobs:
 pip install "pydantic==2.7.1"
 pip install "diskcache==5.6.1"
 pip install "Pillow==10.3.0"
+pip install "ijson==3.2.3"
 - save_cache:
     paths:
       - ./venv
@@ -126,6 +127,7 @@ jobs:
 pip install jinja2
 pip install tokenizers
 pip install openai
+pip install ijson
 - run:
     name: Run tests
     command: |

View file

@@ -13,7 +13,10 @@ from litellm._logging import (
     verbose_logger,
     json_logs,
     _turn_on_json,
+    log_level,
 )
 from litellm.proxy._types import (
     KeyManagementSystem,
     KeyManagementSettings,
@@ -736,6 +739,7 @@ from .utils import (
     supports_function_calling,
     supports_parallel_function_calling,
     supports_vision,
+    supports_system_messages,
     get_litellm_params,
     acreate,
     get_model_list,

View file

@@ -12,7 +12,7 @@ if set_verbose is True:
     )
 json_logs = bool(os.getenv("JSON_LOGS", False))
 # Create a handler for the logger (you may need to adapt this based on your needs)
-log_level = os.getenv("LITELLM_LOG", "ERROR")
+log_level = os.getenv("LITELLM_LOG", "DEBUG")
 numeric_level: str = getattr(logging, log_level.upper())
 handler = logging.StreamHandler()
 handler.setLevel(numeric_level)
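Note: with this change the handler defaults to DEBUG whenever LITELLM_LOG is unset. A minimal sketch of overriding that default (the LITELLM_LOG environment variable is the one read above; everything else here is illustrative):

import os

# LITELLM_LOG must be set before litellm is imported, since litellm._logging
# reads it at import time to pick the handler's level.
os.environ["LITELLM_LOG"] = "ERROR"

import litellm  # noqa: E402  # logging now follows the LITELLM_LOG value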

View file

@ -1,14 +1,22 @@
import types ####################################
import traceback ######### DEPRECATED FILE ##########
####################################
# logic moved to `vertex_httpx.py` #
import copy import copy
import time import time
import traceback
import types
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version from packaging.version import Version
import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
class GeminiError(Exception): class GeminiError(Exception):
@ -186,8 +194,8 @@ def completion(
if _system_instruction and len(system_prompt) > 0: if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt _params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params) _model = genai.GenerativeModel(**_params)
if stream == True: if stream is True:
if acompletion == True: if acompletion is True:
async def async_streaming(): async def async_streaming():
try: try:

View file

@@ -1,41 +1,49 @@
 # What is this?
 ## httpx client for vertex ai calls
 ## Initial implementation - covers gemini + image gen calls
-from functools import partial
-import os, types
+import inspect
 import json
-from enum import Enum
-import requests  # type: ignore
+import os
 import time
-from typing import Callable, Optional, Union, List, Any, Tuple
+import types
+import uuid
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+import httpx  # type: ignore
+import ijson
+import requests  # type: ignore
+import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from .base import BaseLLM
-from litellm.types.llms.vertex_ai import (
-    ContentType,
-    SystemInstructions,
-    PartType,
-    RequestBody,
-    GenerateContentResponseBody,
-    FunctionCallingConfig,
-    FunctionDeclaration,
-    Tools,
-    ToolConfig,
-    GenerationConfig,
-)
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
-from litellm.types.utils import GenericStreamingChunk
 from litellm.types.llms.openai import (
-    ChatCompletionUsageBlock,
+    ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
-    ChatCompletionResponseMessage,
+    ChatCompletionUsageBlock,
 )
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    FunctionCallingConfig,
+    FunctionDeclaration,
+    GenerateContentResponseBody,
+    GenerationConfig,
+    PartType,
+    RequestBody,
+    SystemInstructions,
+    ToolConfig,
+    Tools,
+)
+from litellm.types.utils import GenericStreamingChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+from .base import BaseLLM
 class VertexGeminiConfig:
@@ -251,7 +259,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)
     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(chunk_size=2056)
+        streaming_response=response.aiter_bytes(), sync_stream=False
     )
     # LOGGING
     logging_obj.post_call(
@@ -282,7 +290,7 @@ def make_sync_call(
     raise VertexAIError(status_code=response.status_code, message=response.read())
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056)
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
     )
     # LOGGING
@@ -414,9 +422,11 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         import google.auth as google_auth
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
         if credentials is not None and isinstance(credentials, str):
             import google.oauth2.service_account
@@ -449,7 +459,9 @@ class VertexLLM(BaseLLM):
         return creds, project_id
     def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
         credentials.refresh(Request())
@@ -482,6 +494,50 @@ class VertexLLM(BaseLLM):
         return self._credentials.token, self.project_id
+    def _get_token_and_url(
+        self,
+        model: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+        Handles logic if it's google ai studio vs. vertex ai.
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            _gemini_model_name = "models/{}".format(model)
+            auth_header = None
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+        else:
+            auth_header, vertex_project = self._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+            ### SET RUNTIME ENDPOINT ###
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+        return auth_header, url
     async def async_streaming(
         self,
         model: str,
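For reference, a small illustrative script of the two endpoint shapes _get_token_and_url produces; the model, key, project, and region values below are placeholders, not values from this PR:

# Google AI Studio path: API key goes in the query string, no Authorization header.
model = "gemini-1.5-pro"
gemini_api_key = "AIza-example"  # placeholder

google_ai_studio_url = (
    "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}".format(
        model, gemini_api_key
    )
)

# Vertex AI path: OAuth bearer token from _ensure_access_token(), region-scoped host.
vertex_project, vertex_location = "my-project", "us-central1"  # placeholders
vertex_ai_url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}"
    f"/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
)

print(google_ai_studio_url)
print(vertex_ai_url)

Streaming swaps generateContent for streamGenerateContent in both cases.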
@@ -574,6 +630,9 @@ class VertexLLM(BaseLLM):
         messages: list,
         model_response: ModelResponse,
         print_verbose: Callable,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         encoding,
         logging_obj,
         optional_params: dict,
@@ -582,41 +641,58 @@ class VertexLLM(BaseLLM):
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
+        gemini_api_key: Optional[str],
         litellm_params=None,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        auth_header, vertex_project = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
-        ### SET RUNTIME ENDPOINT ###
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+        )
         ## TRANSFORMATION ##
+        try:
+            supports_system_message = litellm.supports_system_messages(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    str(e)
+                )
+            )
+            supports_system_message = False
         # Separate system prompt from rest of message
         system_prompt_indices = []
         system_content_blocks: List[PartType] = []
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                _system_content_block = PartType(text=message["content"])
-                system_content_blocks.append(_system_content_block)
-                system_prompt_indices.append(idx)
-        if len(system_prompt_indices) > 0:
-            for idx in reversed(system_prompt_indices):
-                messages.pop(idx)
-        system_instructions = SystemInstructions(parts=system_content_blocks)
+        if supports_system_message is True:
+            for idx, message in enumerate(messages):
+                if message["role"] == "system":
+                    _system_content_block = PartType(text=message["content"])
+                    system_content_blocks.append(_system_content_block)
+                    system_prompt_indices.append(idx)
+            if len(system_prompt_indices) > 0:
+                for idx in reversed(system_prompt_indices):
+                    messages.pop(idx)
         content = _gemini_convert_messages_with_history(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
-        data = RequestBody(system_instruction=system_instructions, contents=content)
+        data = RequestBody(contents=content)
+        if len(system_content_blocks) > 0:
+            system_instructions = SystemInstructions(parts=system_content_blocks)
+            data["system_instruction"] = system_instructions
         if tools is not None:
             data["tools"] = tools
         if tool_choice is not None:
@@ -626,8 +702,9 @@ class VertexLLM(BaseLLM):
         headers = {
             "Content-Type": "application/json; charset=utf-8",
-            "Authorization": f"Bearer {auth_header}",
         }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"
         ## LOGGING
         logging_obj.pre_call(
@@ -642,6 +719,25 @@ class VertexLLM(BaseLLM):
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
@@ -853,9 +949,13 @@ class VertexLLM(BaseLLM):
 class ModelResponseIterator:
-    def __init__(self, streaming_response):
+    def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.response_iterator = iter(self.streaming_response)
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -907,10 +1007,21 @@ class ModelResponseIterator:
     def __next__(self):
         try:
-            chunk = next(self.response_iterator)
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopIteration:
             raise StopIteration
         except ValueError as e:
@@ -924,9 +1035,20 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopAsyncIteration:
             raise StopAsyncIteration
         except ValueError as e:
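The iterator above pushes raw response bytes into an ijson coroutine, so each element of the streamed JSON array becomes available as soon as it closes rather than waiting for newline-delimited chunks. A standalone sketch of that pattern (the sample payload is made up):

import ijson

events = ijson.sendable_list()           # parsed top-level array elements land here
coro = ijson.items_coro(events, "item")  # "item" matches each element of a top-level JSON array

# Feed the stream in arbitrary byte-sized pieces, as an HTTP byte iterator would.
for chunk in [b'[{"candidates": [1]},', b' {"candidates"', b': [2]}]']:
    coro.send(chunk)
    while events:
        print(events.pop(0))             # complete objects pop out incrementally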

View file

@@ -1884,43 +1884,7 @@ def completion(
             )
             return response
         response = model_response
-    elif custom_llm_provider == "gemini":
-        gemini_api_key = (
-            api_key
-            or get_secret("GEMINI_API_KEY")
-            or get_secret("PALM_API_KEY")  # older palm api key should also work
-            or litellm.api_key
-        )
-        # palm does not support streaming as yet :(
-        model_response = gemini.completion(
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            api_key=gemini_api_key,
-            logging_obj=logging,
-            acompletion=acompletion,
-            custom_prompt_dict=custom_prompt_dict,
-        )
-        if (
-            "stream" in optional_params
-            and optional_params["stream"] == True
-            and acompletion == False
-        ):
-            response = CustomStreamWrapper(
-                iter(model_response),
-                model,
-                custom_llm_provider="gemini",
-                logging_obj=logging,
-            )
-            return response
-        response = model_response
-    elif custom_llm_provider == "vertex_ai_beta":
+    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
@@ -1938,6 +1902,14 @@ def completion(
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret("VERTEXAI_CREDENTIALS")
         )
+
+        gemini_api_key = (
+            api_key
+            or get_secret("GEMINI_API_KEY")
+            or get_secret("PALM_API_KEY")  # older palm api key should also work
+            or litellm.api_key
+        )
+
         new_params = deepcopy(optional_params)
         response = vertex_chat_completion.completion(  # type: ignore
             model=model,
@@ -1951,9 +1923,11 @@ def completion(
             vertex_location=vertex_ai_location,
             vertex_project=vertex_ai_project,
             vertex_credentials=vertex_credentials,
+            gemini_api_key=gemini_api_key,
             logging_obj=logging,
             acompletion=acompletion,
             timeout=timeout,
+            custom_llm_provider=custom_llm_provider,
         )
     elif custom_llm_provider == "vertex_ai":
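With this consolidation, Google AI Studio (gemini/...) traffic flows through the same vertex_httpx handler as vertex_ai_beta, so system messages, tool calling, and streaming share one code path. A minimal call mirroring the updated test_completion_gemini further down; the API key value is a placeholder:

import os

import litellm

os.environ["GEMINI_API_KEY"] = "your-google-ai-studio-key"  # placeholder

response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ],
)
print(response.choices[0].message.content)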

View file

@@ -7,66 +7,86 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan
-import copy, httpx
-from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
-from typing_extensions import overload
-import random, threading, time, traceback, uuid
-import litellm, openai, hashlib, json
-from litellm.caching import RedisCache, InMemoryCache, DualCache
-import datetime as datetime_og
-import logging, asyncio
-import inspect, concurrent
-from openai import AsyncOpenAI
-from collections import defaultdict
-from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
-from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
-from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    CustomHTTPTransport,
-    AsyncCustomHTTPTransport,
-)
-from litellm.utils import (
-    ModelResponse,
-    CustomStreamWrapper,
-    get_utc_datetime,
-    calculate_max_parallel_requests,
-    _is_region_eu,
-)
+import asyncio
+import concurrent
 import copy
-from litellm._logging import verbose_router_logger
+import datetime as datetime_og
+import hashlib
+import inspect
+import json
 import logging
-from litellm.types.utils import ModelInfo as ModelMapInfo
-from litellm.types.router import (
-    Deployment,
-    ModelInfo,
-    LiteLLM_Params,
-    RouterErrors,
-    updateDeployment,
-    updateLiteLLMParams,
-    RetryPolicy,
-    AllowedFailsPolicy,
-    AlertingConfig,
-    DeploymentTypedDict,
-    ModelGroupInfo,
-    AssistantsTypedDict,
-)
+import random
+import threading
+import time
+import traceback
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+)
+import httpx
+import openai
+from openai import AsyncOpenAI
+from typing_extensions import overload
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.llms.custom_httpx.azure_dall_e_2 import (
+    AsyncCustomHTTPTransport,
+    CustomHTTPTransport,
+)
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.handle_error import send_llm_exception_alert
+from litellm.scheduler import FlowItem, Scheduler
 from litellm.types.llms.openai import (
-    AsyncCursorPage,
     Assistant,
-    Thread,
+    AssistantToolParam,
+    AsyncCursorPage,
     Attachment,
     OpenAIMessage,
     Run,
-    AssistantToolParam,
+    Thread,
 )
+from litellm.types.router import (
+    AlertingConfig,
+    AllowedFailsPolicy,
+    AssistantsTypedDict,
+    Deployment,
+    DeploymentTypedDict,
+    LiteLLM_Params,
+    ModelGroupInfo,
+    ModelInfo,
+    RetryPolicy,
+    RouterErrors,
+    updateDeployment,
+    updateLiteLLMParams,
+)
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.utils import (
+    CustomStreamWrapper,
+    ModelResponse,
+    _is_region_eu,
+    calculate_max_parallel_requests,
+    get_utc_datetime,
+)
-from litellm.scheduler import Scheduler, FlowItem
-from typing import Iterable
-from litellm.router_utils.handle_error import send_llm_exception_alert
 class Router:
@@ -3114,6 +3134,7 @@ class Router:
             # proxy support
             import os
+            import httpx
             # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@@ -3800,6 +3821,7 @@ class Router:
                 litellm_provider=llm_provider,
                 mode="chat",
                 supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
             )
         if model_group_info is None:

View file

@@ -1,7 +1,10 @@
 # conftest.py
-import pytest, sys, os
 import importlib
+import os
+import sys
+import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
     sys.path.insert(
         0, os.path.abspath("../..")
     )  # Adds the project directory to the system path
+    print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
     import litellm
     from litellm import Router

View file

@@ -12,6 +12,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import os
+
 from unittest.mock import MagicMock, patch
 import pytest
@@ -3335,6 +3337,7 @@ def test_mistral_anyscale_stream():
 #### Test A121 ###################
+@pytest.mark.skip(reason="Local test")
 def test_completion_ai21():
     print("running ai21 j2light test")
     litellm.set_verbose = True
@@ -3390,10 +3393,21 @@ def test_completion_deep_infra_mistral():
 # Gemini tests
-def test_completion_gemini():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "gemini-1.0-pro",
+        "gemini-1.5-pro",
+        # "gemini-1.5-flash",
+    ],
+)
+def test_completion_gemini(model):
     litellm.set_verbose = True
-    model_name = "gemini/gemini-1.5-pro-latest"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    model_name = "gemini/{}".format(model)
+    messages = [
+        {"role": "system", "content": "Be a good bot!"},
+        {"role": "user", "content": "Hey, how's it going?"},
+    ]
     try:
         response = completion(model=model_name, messages=messages)
         # Add any assertions,here to check the response
@@ -3485,7 +3499,7 @@ def test_completion_palm_stream():
         pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3506,7 +3520,7 @@ def test_completion_watsonx():
         pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Skip test. account deleted.")
 def test_completion_stream_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3574,7 +3588,7 @@ def test_unified_auth_params(provider, model, project, region_name, token):
     assert value in translated_optional_params
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
@@ -3595,7 +3609,7 @@ async def test_acompletion_watsonx():
         pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_stream_watsonx():
     litellm.set_verbose = True

View file

@@ -1776,7 +1776,7 @@ def test_completion_sagemaker_stream():
         pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx_stream():
     litellm.set_verbose = True
     try:

View file

@@ -1,13 +1,22 @@
 import json
 import time
 import uuid
+import json
+import time
+import uuid
 from enum import Enum
 from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union
 from openai._models import BaseModel as OpenAIObject
 from pydantic import ConfigDict
 from typing_extensions import Dict, Required, TypedDict, override
+from ..litellm_core_utils.core_helpers import map_finish_reason
+from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from typing_extensions import Dict, Required, TypedDict, override
 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
         ]
     ]
     supported_openai_params: Required[Optional[List[str]]]
+    supports_system_messages: Optional[bool]
 class GenericStreamingChunk(TypedDict):

View file

@@ -1824,6 +1824,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
     return False
+def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports system messages and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports system messages, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        if model_info.get("supports_system_messages", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
@@ -1839,7 +1865,7 @@ def supports_function_calling(model: str) -> bool:
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_function_calling", False):
+        if model_info.get("supports_function_calling", False) is True:
             return True
         return False
     else:
@@ -1863,7 +1889,7 @@ def supports_vision(model):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_vision", False):
+        if model_info.get("supports_vision", False) is True:
             return True
         return False
     else:
@@ -1885,7 +1911,7 @@ def supports_parallel_function_calling(model):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_parallel_function_calling", False):
+        if model_info.get("supports_parallel_function_calling", False) is True:
             return True
         return False
     else:
@@ -4320,14 +4346,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         )
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
-            return {
-                "max_tokens": max_tokens,  # type: ignore
-                "input_cost_per_token": 0,
-                "output_cost_per_token": 0,
-                "litellm_provider": "huggingface",
-                "mode": "chat",
-                "supported_openai_params": supported_openai_params,
-            }
+            return ModelInfo(
+                max_tokens=max_tokens,  # type: ignore
+                max_input_tokens=None,
+                max_output_tokens=None,
+                input_cost_per_token=0,
+                output_cost_per_token=0,
+                litellm_provider="huggingface",
+                mode="chat",
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
+            )
         else:
             """
             Check if: (in order of specificity)
@@ -4348,7 +4377,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 pass
             else:
                 raise Exception
-            return _model_info
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_token_above_128k_tokens=_model_info.get(
+                    "input_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_token_above_128k_tokens=_model_info.get(
+                    "output_cost_per_token_above_128k_tokens", None
+                ),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+            )
         elif model in litellm.model_cost:
             _model_info = litellm.model_cost[model]
             _model_info["supported_openai_params"] = supported_openai_params
@@ -4362,7 +4411,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 pass
             else:
                 raise Exception
-            return _model_info
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_token_above_128k_tokens=_model_info.get(
+                    "input_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_token_above_128k_tokens=_model_info.get(
+                    "output_cost_per_token_above_128k_tokens", None
+                ),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+            )
         elif split_model in litellm.model_cost:
             _model_info = litellm.model_cost[split_model]
             _model_info["supported_openai_params"] = supported_openai_params
@@ -4376,7 +4445,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 pass
             else:
                 raise Exception
-            return _model_info
+            return ModelInfo(
+                max_tokens=_model_info.get("max_tokens", None),
+                max_input_tokens=_model_info.get("max_input_tokens", None),
+                max_output_tokens=_model_info.get("max_output_tokens", None),
+                input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_token_above_128k_tokens=_model_info.get(
+                    "input_cost_per_token_above_128k_tokens", None
+                ),
+                output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_token_above_128k_tokens=_model_info.get(
+                    "output_cost_per_token_above_128k_tokens", None
+                ),
+                litellm_provider=_model_info.get(
+                    "litellm_provider", custom_llm_provider
+                ),
+                mode=_model_info.get("mode"),
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=_model_info.get(
+                    "supports_system_messages", None
+                ),
+            )
         else:
             raise ValueError(
                 "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -7030,7 +7119,9 @@ def exception_type(
             exception_mapping_worked = True
             if hasattr(original_exception, "request"):
                 raise APIConnectionError(
-                    message=f"{str(original_exception)}",
+                    message="{}\n{}".format(
+                        str(original_exception), traceback.format_exc()
+                    ),
                     llm_provider=custom_llm_provider,
                     model=model,
                     request=original_exception.request,
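Taken together, the new supports_system_messages helper and the normalized ModelInfo return type can be exercised as below; this is a sketch assuming the gemini/gemini-1.5-pro entry in model_prices_and_context_window.json carries the supports_system_messages flag:

import litellm

# True when the model's map entry sets supports_system_messages.
print(litellm.supports_system_messages(model="gemini-1.5-pro", custom_llm_provider="gemini"))

# get_model_info now always returns a ModelInfo TypedDict with the same keys.
info = litellm.get_model_info(model="gemini/gemini-1.5-pro")
print(info["mode"], info["litellm_provider"], info.get("supports_system_messages"))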

View file

@ -27,6 +27,7 @@ jinja2 = "^3.1.2"
aiohttp = "*" aiohttp = "*"
requests = "^2.31.0" requests = "^2.31.0"
pydantic = "^2.0.0" pydantic = "^2.0.0"
ijson = "*"
uvicorn = {version = "^0.22.0", optional = true} uvicorn = {version = "^0.22.0", optional = true}
gunicorn = {version = "^22.0.0", optional = true} gunicorn = {version = "^22.0.0", optional = true}

View file

@ -44,4 +44,5 @@ aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls aioboto3==12.3.0 # for async sagemaker calls
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
pydantic==2.7.1 # proxy + openai req. pydantic==2.7.1 # proxy + openai req.
ijson==3.2.3 # for google ai studio streaming
#### ####