Merge pull request #4246 from BerriAI/litellm_gemini_refactoring

feat(main.py): Gemini (Google AI Studio) - Support Function Calling, Inline images, etc.
Krish Dholakia 2024-06-17 19:51:12 -07:00 committed by GitHub
commit af2917d655
14 changed files with 427 additions and 174 deletions
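
For context, a minimal sketch of what this change enables from the caller's side; the tool schema and weather function are illustrative placeholders, not taken from this diff, while the model string and system message mirror the updated test below:

import litellm  # reads GEMINI_API_KEY from the environment for the gemini/ route

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "What's the weather in Boston?"},
    ],
    tools=tools,
)
print(response.choices[0].message)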

View file

@ -65,6 +65,7 @@ jobs:
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "ijson==3.2.3"
- save_cache:
paths:
- ./venv
@ -126,6 +127,7 @@ jobs:
pip install jinja2
pip install tokenizers
pip install openai
pip install ijson
- run:
name: Run tests
command: |

View file

@ -13,7 +13,10 @@ from litellm._logging import (
verbose_logger,
json_logs,
_turn_on_json,
log_level,
)
from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
@ -736,6 +739,7 @@ from .utils import (
supports_function_calling,
supports_parallel_function_calling,
supports_vision,
supports_system_messages,
get_litellm_params,
acreate,
get_model_list,

View file

@ -12,7 +12,7 @@ if set_verbose is True:
)
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "ERROR")
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: str = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
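
A minimal sketch of how the LITELLM_LOG variable read above interacts with the new DEBUG default; the override value is illustrative, and it must be set before litellm is imported because the module reads it at import time:

import os

os.environ["LITELLM_LOG"] = "ERROR"  # illustrative override of the new DEBUG default
import litellm  # litellm._logging reads LITELLM_LOG when the package is imported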

View file

@ -1,14 +1,22 @@
import types
import traceback
####################################
######### DEPRECATED FILE ##########
####################################
# logic moved to `vertex_httpx.py` #
import copy
import time
import traceback
import types
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
import litellm
from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
class GeminiError(Exception):
@ -186,8 +194,8 @@ def completion(
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True:
if acompletion == True:
if stream is True:
if acompletion is True:
async def async_streaming():
try:

View file

@ -1,41 +1,49 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import inspect
import json
from enum import Enum
import requests # type: ignore
import os
import time
from typing import Callable, Optional, Union, List, Any, Tuple
import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import ijson
import requests # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
ContentType,
SystemInstructions,
PartType,
RequestBody,
GenerateContentResponseBody,
FunctionCallingConfig,
FunctionDeclaration,
Tools,
ToolConfig,
GenerationConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionResponseMessage,
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
GenerateContentResponseBody,
GenerationConfig,
PartType,
RequestBody,
SystemInstructions,
ToolConfig,
Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
class VertexGeminiConfig:
@ -251,7 +259,7 @@ async def make_call(
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(chunk_size=2056)
streaming_response=response.aiter_bytes(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
@ -282,7 +290,7 @@ def make_sync_call(
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056)
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
)
# LOGGING
@ -414,9 +422,11 @@ class VertexLLM(BaseLLM):
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
@ -449,7 +459,9 @@ class VertexLLM(BaseLLM):
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request())
@ -482,6 +494,50 @@ class VertexLLM(BaseLLM):
return self._credentials.token, self.project_id
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles the logic for Google AI Studio vs. Vertex AI.
Returns
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return auth_header, url
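# Illustration only (not part of this diff): the two URL shapes produced above.
# For custom_llm_provider="gemini" (Google AI Studio) there is no auth header and the
# key rides on the query string, e.g.
#   https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:generateContent?key=<GEMINI_API_KEY>
# For Vertex AI a bearer token is returned alongside a regional endpoint, e.g.
#   https://us-central1-aiplatform.googleapis.com/v1/projects/<project>/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent
# (model, project, and region here are placeholders).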
async def async_streaming(
self,
model: str,
@ -574,6 +630,9 @@ class VertexLLM(BaseLLM):
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
encoding,
logging_obj,
optional_params: dict,
@ -582,41 +641,58 @@ class VertexLLM(BaseLLM):
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
gemini_api_key: Optional[str],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
### SET RUNTIME ENDPOINT ###
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
)
## TRANSFORMATION ##
try:
supports_system_message = litellm.supports_system_messages(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception as e:
verbose_logger.error(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
str(e)
)
)
supports_system_message = False
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
system_instructions = SystemInstructions(parts=system_content_blocks)
if supports_system_message is True:
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)
data = RequestBody(system_instruction=system_instructions, contents=content)
data = RequestBody(contents=content)
if len(system_content_blocks) > 0:
system_instructions = SystemInstructions(parts=system_content_blocks)
data["system_instruction"] = system_instructions
if tools is not None:
data["tools"] = tools
if tool_choice is not None:
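# Illustration only (not from this diff): for input messages such as
#   [{"role": "system", "content": "Be a good bot!"}, {"role": "user", "content": "Hi"}]
# the transformation above yields a request body along the lines of
#   {
#       "system_instruction": {"parts": [{"text": "Be a good bot!"}]},
#       "contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
#   }
# when the model supports system messages; otherwise the system message stays in
# `messages` and flows through _gemini_convert_messages_with_history with the rest.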
@ -626,8 +702,9 @@ class VertexLLM(BaseLLM):
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
## LOGGING
logging_obj.pre_call(
@ -642,6 +719,25 @@ class VertexLLM(BaseLLM):
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=json.dumps(data), # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
### ASYNC COMPLETION
return self.async_completion(
model=model,
@ -853,9 +949,13 @@ class VertexLLM(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response):
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
self.response_iterator = iter(self.streaming_response)
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@ -907,10 +1007,21 @@ class ModelResponseIterator:
def __next__(self):
try:
chunk = next(self.response_iterator)
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
@ -924,9 +1035,20 @@ class ModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
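
A standalone sketch of the ijson pattern the iterator above relies on; the byte chunks are made up, the real input is Google's streamGenerateContent response, a JSON array delivered across arbitrary chunk boundaries:

import ijson

events = ijson.sendable_list()           # completed top-level array items land here
coro = ijson.items_coro(events, "item")  # "item" matches each element of the array

# hypothetical chunks, split in the middle of the first object
for chunk in (b'[{"text": "Hel', b'lo"}, {"text": "world"}]'):
    coro.send(chunk)
    while events:
        print(events.pop(0))             # {'text': 'Hello'}, then {'text': 'world'}

This mirrors the __next__/__anext__ logic above: feed each raw chunk into the coroutine, emit a parsed chunk only once a complete JSON object has accumulated, and return an empty GenericStreamingChunk otherwise.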

View file

@ -1884,43 +1884,7 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "gemini":
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
# palm does not support streaming as yet :(
model_response = gemini.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
custom_prompt_dict=custom_prompt_dict,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
iter(model_response),
model,
custom_llm_provider="gemini",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "vertex_ai_beta":
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
@ -1938,6 +1902,14 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
@ -1951,9 +1923,11 @@ def completion(
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
gemini_api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
)
elif custom_llm_provider == "vertex_ai":

View file

@ -7,66 +7,86 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import asyncio
import concurrent
import copy
from litellm._logging import verbose_router_logger
import datetime as datetime_og
import hashlib
import inspect
import json
import logging
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
Deployment,
ModelInfo,
LiteLLM_Params,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
import random
import threading
import time
import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import (
Any,
BinaryIO,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import httpx
import openai
from openai import AsyncOpenAI
from typing_extensions import overload
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
AsyncCustomHTTPTransport,
CustomHTTPTransport,
)
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
AssistantToolParam,
AsyncCursorPage,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
Thread,
)
from litellm.types.router import (
AlertingConfig,
AllowedFailsPolicy,
AssistantsTypedDict,
Deployment,
DeploymentTypedDict,
LiteLLM_Params,
ModelGroupInfo,
ModelInfo,
RetryPolicy,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
)
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
_is_region_eu,
calculate_max_parallel_requests,
get_utc_datetime,
)
from litellm.scheduler import Scheduler, FlowItem
from typing import Iterable
from litellm.router_utils.handle_error import send_llm_exception_alert
class Router:
@ -3114,6 +3134,7 @@ class Router:
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@ -3800,6 +3821,7 @@ class Router:
litellm_provider=llm_provider,
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
if model_group_info is None:

View file

@ -1,7 +1,10 @@
# conftest.py
import pytest, sys, os
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
@ -18,6 +21,7 @@ def setup_and_teardown():
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
import litellm
from litellm import Router

View file

@ -12,6 +12,8 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import MagicMock, patch
import pytest
@ -3335,6 +3337,7 @@ def test_mistral_anyscale_stream():
#### Test A121 ###################
@pytest.mark.skip(reason="Local test")
def test_completion_ai21():
print("running ai21 j2light test")
litellm.set_verbose = True
@ -3390,10 +3393,21 @@ def test_completion_deep_infra_mistral():
# Gemini tests
def test_completion_gemini():
@pytest.mark.parametrize(
"model",
[
# "gemini-1.0-pro",
"gemini-1.5-pro",
# "gemini-1.5-flash",
],
)
def test_completion_gemini(model):
litellm.set_verbose = True
model_name = "gemini/gemini-1.5-pro-latest"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
model_name = "gemini/{}".format(model)
messages = [
{"role": "system", "content": "Be a good bot!"},
{"role": "user", "content": "Hey, how's it going?"},
]
try:
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
@ -3485,7 +3499,7 @@ def test_completion_palm_stream():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
@ -3506,7 +3520,7 @@ def test_completion_watsonx():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Skip test. account deleted.")
def test_completion_stream_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
@ -3574,7 +3588,7 @@ def test_unified_auth_params(provider, model, project, region_name, token):
assert value in translated_optional_params
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_acompletion_watsonx():
litellm.set_verbose = True
@ -3595,7 +3609,7 @@ async def test_acompletion_watsonx():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_acompletion_stream_watsonx():
litellm.set_verbose = True

View file

@ -1776,7 +1776,7 @@ def test_completion_sagemaker_stream():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx_stream():
litellm.set_verbose = True
try:

View file

@ -1,13 +1,22 @@
import json
import time
import uuid
import json
import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
]
]
supported_openai_params: Required[Optional[List[str]]]
supports_system_messages: Optional[bool]
class GenericStreamingChunk(TypedDict):

View file

@ -1824,6 +1824,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
return False
def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports system messages and return a boolean value.
Parameters:
model (str): The model name to be checked.
Returns:
bool: True if the model supports system messages, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
"""
try:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_system_messages", False) is True:
return True
return False
except Exception:
raise Exception(
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
)
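# Illustrative usage (model name is an example, not taken from this diff):
#   litellm.supports_system_messages(model="gemini-1.5-pro", custom_llm_provider="gemini")
# returns True only when the `supports_system_messages` flag is set for the model in
# model_prices_and_context_window.json, and raises if the model is not mapped there.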
def supports_function_calling(model: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
@ -1839,7 +1865,7 @@ def supports_function_calling(model: str) -> bool:
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_function_calling", False):
if model_info.get("supports_function_calling", False) is True:
return True
return False
else:
@ -1863,7 +1889,7 @@ def supports_vision(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_vision", False):
if model_info.get("supports_vision", False) is True:
return True
return False
else:
@ -1885,7 +1911,7 @@ def supports_parallel_function_calling(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_parallel_function_calling", False):
if model_info.get("supports_parallel_function_calling", False) is True:
return True
return False
else:
@ -4320,14 +4346,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
)
if custom_llm_provider == "huggingface":
max_tokens = _get_max_position_embeddings(model_name=model)
return {
"max_tokens": max_tokens, # type: ignore
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "huggingface",
"mode": "chat",
"supported_openai_params": supported_openai_params,
}
return ModelInfo(
max_tokens=max_tokens, # type: ignore
max_input_tokens=None,
max_output_tokens=None,
input_cost_per_token=0,
output_cost_per_token=0,
litellm_provider="huggingface",
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
else:
"""
Check if: (in order of specificity)
@ -4348,7 +4377,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif model in litellm.model_cost:
_model_info = litellm.model_cost[model]
_model_info["supported_openai_params"] = supported_openai_params
@ -4362,7 +4411,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
_model_info["supported_openai_params"] = supported_openai_params
@ -4376,7 +4445,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@ -7030,7 +7119,9 @@ def exception_type(
exception_mapping_worked = True
if hasattr(original_exception, "request"):
raise APIConnectionError(
message=f"{str(original_exception)}",
message="{}\n{}".format(
str(original_exception), traceback.format_exc()
),
llm_provider=custom_llm_provider,
model=model,
request=original_exception.request,

View file

@ -27,6 +27,7 @@ jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
ijson = "*"
uvicorn = {version = "^0.22.0", optional = true}
gunicorn = {version = "^22.0.0", optional = true}

View file

@ -44,4 +44,5 @@ aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
pydantic==2.7.1 # proxy + openai req.
ijson==3.2.3 # for google ai studio streaming
####