Merge pull request #4246 from BerriAI/litellm_gemini_refactoring
feat(main.py): Gemini (Google AI Studio) - Support Function Calling, Inline images, etc.
Commit af2917d655
14 changed files with 427 additions and 174 deletions
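For context, a minimal usage sketch of the Google AI Studio route this PR wires up (illustrative values, mirroring the updated test further down; assumes GEMINI_API_KEY is set in the environment):

import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-pro",  # the "gemini/" prefix routes to Google AI Studio
    messages=[
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "Hey, how's it going?"},
    ],
)
print(response.choices[0].message.content)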
@@ -65,6 +65,7 @@ jobs:
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "ijson==3.2.3"
- save_cache:
paths:
- ./venv
@@ -126,6 +127,7 @@ jobs:
pip install jinja2
pip install tokenizers
pip install openai
pip install ijson
- run:
name: Run tests
command: |
@@ -13,7 +13,10 @@ from litellm._logging import (
verbose_logger,
json_logs,
_turn_on_json,
log_level,
)
from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
@@ -736,6 +739,7 @@ from .utils import (
supports_function_calling,
supports_parallel_function_calling,
supports_vision,
supports_system_messages,
get_litellm_params,
acreate,
get_model_list,
@@ -12,7 +12,7 @@ if set_verbose is True:
)
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "ERROR")
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: str = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
@@ -1,14 +1,22 @@
import types
import traceback
####################################
######### DEPRECATED FILE ##########
####################################
# logic moved to `vertex_httpx.py` #
import copy
import time
import traceback
import types
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
from packaging.version import Version
import litellm
from litellm import verbose_logger
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
class GeminiError(Exception):
@@ -186,8 +194,8 @@ def completion(
if _system_instruction and len(system_prompt) > 0:
_params["system_instruction"] = system_prompt
_model = genai.GenerativeModel(**_params)
if stream == True:
if acompletion == True:
if stream is True:
if acompletion is True:
async def async_streaming():
try:
@@ -1,41 +1,49 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import inspect
import json
from enum import Enum
import requests # type: ignore
import os
import time
from typing import Callable, Optional, Union, List, Any, Tuple
import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import ijson
import requests # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
ContentType,
SystemInstructions,
PartType,
RequestBody,
GenerateContentResponseBody,
FunctionCallingConfig,
FunctionDeclaration,
Tools,
ToolConfig,
GenerationConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionResponseMessage,
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
GenerateContentResponseBody,
GenerationConfig,
PartType,
RequestBody,
SystemInstructions,
ToolConfig,
Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
class VertexGeminiConfig:
@@ -251,7 +259,7 @@ async def make_call(
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(chunk_size=2056)
streaming_response=response.aiter_bytes(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
@@ -282,7 +290,7 @@ def make_sync_call(
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056)
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
)
# LOGGING
@@ -414,9 +422,11 @@ class VertexLLM(BaseLLM):
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
@@ -449,7 +459,9 @@ class VertexLLM(BaseLLM):
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request())
@@ -482,6 +494,50 @@ class VertexLLM(BaseLLM):
return self._credentials.token, self.project_id
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return auth_header, url
async def async_streaming(
self,
model: str,
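For reference, a sketch of the two URL shapes `_get_token_and_url` builds above; the model, project, and region values below are placeholders, not taken from the patch:

model = "gemini-1.5-pro"  # illustrative
endpoint = "generateContent"  # "streamGenerateContent" when stream=True

# Google AI Studio ("gemini"): no OAuth token, the API key rides in the query string.
gemini_url = "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}".format(
    model, endpoint, "<GEMINI_API_KEY>"
)

# Vertex AI: OAuth bearer token in the Authorization header, regional aiplatform host.
vertex_url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/<PROJECT_ID>/"
    "locations/us-central1/publishers/google/models/{}:{}".format(model, endpoint)
)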
@@ -574,6 +630,9 @@
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
encoding,
logging_obj,
optional_params: dict,
@@ -582,41 +641,58 @@
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
gemini_api_key: Optional[str],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
### SET RUNTIME ENDPOINT ###
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
)
## TRANSFORMATION ##
try:
supports_system_message = litellm.supports_system_messages(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception as e:
verbose_logger.error(
"Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
str(e)
)
)
supports_system_message = False
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[PartType] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
system_instructions = SystemInstructions(parts=system_content_blocks)
if supports_system_message is True:
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = PartType(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)
data = RequestBody(system_instruction=system_instructions, contents=content)
data = RequestBody(contents=content)
if len(system_content_blocks) > 0:
system_instructions = SystemInstructions(parts=system_content_blocks)
data["system_instruction"] = system_instructions
if tools is not None:
data["tools"] = tools
if tool_choice is not None:
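For context, roughly the payload shape the transformation above produces for a chat with one system message; the field names follow the RequestBody/SystemInstructions/PartType types used in the patch, and the values are illustrative only:

data = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hey, how's it going?"}]},
    ],
    # present only when a system message was passed and the model supports it
    "system_instruction": {"parts": [{"text": "Be a good bot!"}]},
    # "tools": [...]  # added only when the caller passed OpenAI-style tools
}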
@@ -626,8 +702,9 @@
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
## LOGGING
logging_obj.pre_call(
@@ -642,6 +719,25 @@
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=json.dumps(data), # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
### ASYNC COMPLETION
return self.async_completion(
model=model,
@@ -853,9 +949,13 @@
class ModelResponseIterator:
def __init__(self, streaming_response):
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
self.response_iterator = iter(self.streaming_response)
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
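For context, a minimal sketch of the ijson coroutine pattern the iterator above relies on: byte chunks of a streamed JSON array are pushed into items_coro, and each completed top-level element lands in the sendable list.

import ijson

events = ijson.sendable_list()           # completed array elements accumulate here
coro = ijson.items_coro(events, "item")  # "item" matches each element of a top-level array

# Chunk boundaries do not need to line up with JSON object boundaries.
for chunk in (b'[{"candidates": [', b'{"content": "hi"}]}', b","):
    coro.send(chunk)
    while events:
        print(events.pop(0))  # e.g. {'candidates': [{'content': 'hi'}]}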
@@ -907,10 +1007,21 @@ class ModelResponseIterator:
def __next__(self):
try:
chunk = next(self.response_iterator)
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
@@ -924,9 +1035,20 @@ class ModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
@@ -1884,43 +1884,7 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "gemini":
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
# palm does not support streaming as yet :(
model_response = gemini.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
custom_prompt_dict=custom_prompt_dict,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
iter(model_response),
model,
custom_llm_provider="gemini",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "vertex_ai_beta":
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
@@ -1938,6 +1902,14 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
@@ -1951,9 +1923,11 @@ def completion(
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
gemini_api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
)
elif custom_llm_provider == "vertex_ai":
@@ -7,66 +7,86 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import asyncio
import concurrent
import copy
from litellm._logging import verbose_router_logger
import datetime as datetime_og
import hashlib
import inspect
import json
import logging
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
Deployment,
ModelInfo,
LiteLLM_Params,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
import random
import threading
import time
import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import (
Any,
BinaryIO,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import httpx
import openai
from openai import AsyncOpenAI
from typing_extensions import overload
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
AsyncCustomHTTPTransport,
CustomHTTPTransport,
)
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
AssistantToolParam,
AsyncCursorPage,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
Thread,
)
from litellm.types.router import (
AlertingConfig,
AllowedFailsPolicy,
AssistantsTypedDict,
Deployment,
DeploymentTypedDict,
LiteLLM_Params,
ModelGroupInfo,
ModelInfo,
RetryPolicy,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
)
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
_is_region_eu,
calculate_max_parallel_requests,
get_utc_datetime,
)
from litellm.scheduler import Scheduler, FlowItem
from typing import Iterable
from litellm.router_utils.handle_error import send_llm_exception_alert
class Router:
@@ -3114,6 +3134,7 @@ class Router:
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@@ -3800,6 +3821,7 @@ class Router:
litellm_provider=llm_provider,
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
if model_group_info is None:
@@ -1,7 +1,10 @@
# conftest.py
import pytest, sys, os
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
import litellm
from litellm import Router
@@ -12,6 +12,8 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import MagicMock, patch
import pytest
@@ -3335,6 +3337,7 @@ def test_mistral_anyscale_stream():
#### Test A121 ###################
@pytest.mark.skip(reason="Local test")
def test_completion_ai21():
print("running ai21 j2light test")
litellm.set_verbose = True
@@ -3390,10 +3393,21 @@ def test_completion_deep_infra_mistral():
# Gemini tests
def test_completion_gemini():
@pytest.mark.parametrize(
"model",
[
# "gemini-1.0-pro",
"gemini-1.5-pro",
# "gemini-1.5-flash",
],
)
def test_completion_gemini(model):
litellm.set_verbose = True
model_name = "gemini/gemini-1.5-pro-latest"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
model_name = "gemini/{}".format(model)
messages = [
{"role": "system", "content": "Be a good bot!"},
{"role": "user", "content": "Hey, how's it going?"},
]
try:
response = completion(model=model_name, messages=messages)
# Add any assertions,here to check the response
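Since the PR title advertises function calling for Google AI Studio, here is a hedged sketch of what an OpenAI-style tool call through the same route could look like; the tool schema is illustrative and not part of this patch:

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool, for illustration
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)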
@@ -3485,7 +3499,7 @@ def test_completion_palm_stream():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3506,7 +3520,7 @@ def test_completion_watsonx():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Skip test. account deleted.")
def test_completion_stream_watsonx():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3574,7 +3588,7 @@ def test_unified_auth_params(provider, model, project, region_name, token):
assert value in translated_optional_params
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_acompletion_watsonx():
litellm.set_verbose = True
@@ -3595,7 +3609,7 @@ async def test_acompletion_watsonx():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_acompletion_stream_watsonx():
litellm.set_verbose = True
@@ -1776,7 +1776,7 @@ def test_completion_sagemaker_stream():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="IBM closed account.")
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx_stream():
litellm.set_verbose = True
try:
@@ -1,13 +1,22 @@
import json
import time
import uuid
import json
import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
]
]
supported_openai_params: Required[Optional[List[str]]]
supports_system_messages: Optional[bool]
class GenericStreamingChunk(TypedDict):
121 litellm/utils.py
@@ -1824,6 +1824,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
return False
def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports system messages and return a boolean value.
Parameters:
model (str): The model name to be checked.
Returns:
bool: True if the model supports system messages, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
"""
try:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_system_messages", False) is True:
return True
return False
except Exception:
raise Exception(
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
)
def supports_function_calling(model: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
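For reference, a quick usage sketch of the new helper; the model/provider pair is illustrative and must already be mapped in model_prices_and_context_window.json, otherwise the helper raises:

import litellm

# Returns True only when the model's entry in the cost map sets supports_system_messages;
# raises if the model/provider pair is not mapped at all.
print(litellm.supports_system_messages(model="gemini-1.5-pro", custom_llm_provider="gemini"))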
@@ -1839,7 +1865,7 @@ def supports_function_calling(model: str) -> bool:
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_function_calling", False):
if model_info.get("supports_function_calling", False) is True:
return True
return False
else:
@@ -1863,7 +1889,7 @@ def supports_vision(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_vision", False):
if model_info.get("supports_vision", False) is True:
return True
return False
else:
@@ -1885,7 +1911,7 @@ def supports_parallel_function_calling(model: str):
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_parallel_function_calling", False):
if model_info.get("supports_parallel_function_calling", False) is True:
return True
return False
else:
@@ -4320,14 +4346,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
)
if custom_llm_provider == "huggingface":
max_tokens = _get_max_position_embeddings(model_name=model)
return {
"max_tokens": max_tokens, # type: ignore
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "huggingface",
"mode": "chat",
"supported_openai_params": supported_openai_params,
}
return ModelInfo(
max_tokens=max_tokens, # type: ignore
max_input_tokens=None,
max_output_tokens=None,
input_cost_per_token=0,
output_cost_per_token=0,
litellm_provider="huggingface",
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
)
else:
"""
Check if: (in order of specificity)
@@ -4348,7 +4377,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif model in litellm.model_cost:
_model_info = litellm.model_cost[model]
_model_info["supported_openai_params"] = supported_openai_params
@@ -4362,7 +4411,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
_model_info["supported_openai_params"] = supported_openai_params
@@ -4376,7 +4445,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return _model_info
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -7030,7 +7119,9 @@ def exception_type(
exception_mapping_worked = True
if hasattr(original_exception, "request"):
raise APIConnectionError(
message=f"{str(original_exception)}",
message="{}\n{}".format(
str(original_exception), traceback.format_exc()
),
llm_provider=custom_llm_provider,
model=model,
request=original_exception.request,
@@ -27,6 +27,7 @@ jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
ijson = "*"
uvicorn = {version = "^0.22.0", optional = true}
gunicorn = {version = "^22.0.0", optional = true}
@@ -44,4 +44,5 @@ aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
pydantic==2.7.1 # proxy + openai req.
ijson==3.2.3 # for google ai studio streaming
####