forked from phoenix/litellm-mirror
Merge pull request #4246 from BerriAI/litellm_gemini_refactoring
feat(main.py): Gemini (Google AI Studio) - Support Function Calling, Inline images, etc.
Commit af2917d655
14 changed files with 427 additions and 174 deletions
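In practice this PR routes Google AI Studio (gemini/...) models through the shared Vertex httpx client, which is what adds OpenAI-style function calling and system-message support. A rough usage sketch of what the change enables (the model name and tool schema are illustrative, not taken from this diff; requires a GEMINI_API_KEY):

import os

import litellm

os.environ["GEMINI_API_KEY"] = "..."  # Google AI Studio key; PALM_API_KEY is also accepted

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[
        {"role": "system", "content": "Be a good bot!"},
        {"role": "user", "content": "What's the weather in Boston?"},
    ],
    tools=tools,
)
print(response.choices[0].message)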
@@ -65,6 +65,7 @@ jobs:
       pip install "pydantic==2.7.1"
       pip install "diskcache==5.6.1"
       pip install "Pillow==10.3.0"
+      pip install "ijson==3.2.3"
   - save_cache:
       paths:
         - ./venv

@@ -126,6 +127,7 @@ jobs:
       pip install jinja2
       pip install tokenizers
       pip install openai
+      pip install ijson
   - run:
       name: Run tests
      command: |
@@ -13,7 +13,10 @@ from litellm._logging import (
     verbose_logger,
     json_logs,
     _turn_on_json,
+    log_level,
 )

@@ -736,6 +739,7 @@ from .utils import (
     supports_function_calling,
     supports_parallel_function_calling,
     supports_vision,
+    supports_system_messages,
     get_litellm_params,
     acreate,
     get_model_list,
@@ -12,7 +12,7 @@ if set_verbose is True:
 )
 json_logs = bool(os.getenv("JSON_LOGS", False))
 # Create a handler for the logger (you may need to adapt this based on your needs)
-log_level = os.getenv("LITELLM_LOG", "ERROR")
+log_level = os.getenv("LITELLM_LOG", "DEBUG")
 numeric_level: str = getattr(logging, log_level.upper())
 handler = logging.StreamHandler()
 handler.setLevel(numeric_level)
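Note that the default log level flips from ERROR to DEBUG here; the LITELLM_LOG environment variable still takes precedence. A minimal sketch, assuming you want the quieter behaviour back (set it before litellm is imported, since _logging.py reads the variable at import time):

import os

os.environ["LITELLM_LOG"] = "ERROR"  # override the new DEBUG default

import litellm  # noqa: E402  - _logging.py reads LITELLM_LOG when this import runs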
@@ -1,14 +1,22 @@
-import types
-import traceback
+####################################
+######### DEPRECATED FILE ##########
+####################################
+# logic moved to `vertex_httpx.py` #
 
 import copy
 import time
+import traceback
+import types
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Choices, Message, Usage
-import litellm
 import httpx
-from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
 from packaging.version import Version
 
+import litellm
 from litellm import verbose_logger
+from litellm.utils import Choices, Message, ModelResponse, Usage
+
+from .prompt_templates.factory import custom_prompt, get_system_prompt, prompt_factory
 
 
 class GeminiError(Exception):

@@ -186,8 +194,8 @@ def completion(
     if _system_instruction and len(system_prompt) > 0:
         _params["system_instruction"] = system_prompt
     _model = genai.GenerativeModel(**_params)
-    if stream == True:
-        if acompletion == True:
+    if stream is True:
+        if acompletion is True:
 
             async def async_streaming():
                 try:
@@ -1,41 +1,49 @@
 # What is this?
 ## httpx client for vertex ai calls
 ## Initial implementation - covers gemini + image gen calls
-from functools import partial
-import os, types
+import inspect
 import json
-from enum import Enum
-import requests  # type: ignore
+import os
 import time
-from typing import Callable, Optional, Union, List, Any, Tuple
+import types
+import uuid
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import ijson
+import requests  # type: ignore
 
+import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from .base import BaseLLM
-from litellm.types.llms.vertex_ai import (
-    ContentType,
-    SystemInstructions,
-    PartType,
-    RequestBody,
-    GenerateContentResponseBody,
-    FunctionCallingConfig,
-    FunctionDeclaration,
-    Tools,
-    ToolConfig,
-    GenerationConfig,
-)
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
-from litellm.types.utils import GenericStreamingChunk
 from litellm.types.llms.openai import (
-    ChatCompletionUsageBlock,
+    ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
-    ChatCompletionResponseMessage,
+    ChatCompletionUsageBlock,
 )
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    FunctionCallingConfig,
+    FunctionDeclaration,
+    GenerateContentResponseBody,
+    GenerationConfig,
+    PartType,
+    RequestBody,
+    SystemInstructions,
+    ToolConfig,
+    Tools,
+)
+from litellm.types.utils import GenericStreamingChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM
 
 
 class VertexGeminiConfig:
@@ -251,7 +259,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(chunk_size=2056)
+        streaming_response=response.aiter_bytes(), sync_stream=False
     )
     # LOGGING
     logging_obj.post_call(
@@ -282,7 +290,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())
 
     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056)
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
    )
 
     # LOGGING
@@ -414,9 +422,11 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         import google.auth as google_auth
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
 
         if credentials is not None and isinstance(credentials, str):
             import google.oauth2.service_account
@@ -449,7 +459,9 @@ class VertexLLM(BaseLLM):
         return creds, project_id
 
     def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
 
         credentials.refresh(Request())
 
@@ -482,6 +494,50 @@ class VertexLLM(BaseLLM):
 
         return self._credentials.token, self.project_id
 
+    def _get_token_and_url(
+        self,
+        model: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            _gemini_model_name = "models/{}".format(model)
+            auth_header = None
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+        else:
+            auth_header, vertex_project = self._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+            ### SET RUNTIME ENDPOINT ###
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+        return auth_header, url
+
     async def async_streaming(
         self,
         model: str,
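For orientation, a rough sketch of the two endpoint shapes _get_token_and_url produces above (project, location and key placeholders are illustrative):

def sketch_url(provider: str, model: str, stream: bool) -> str:
    # Mirrors the format strings in _get_token_and_url; for illustration only.
    endpoint = "streamGenerateContent" if stream else "generateContent"
    if provider == "gemini":
        # Google AI Studio: API key in the query string, no bearer token
        return f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{endpoint}?key=<GEMINI_API_KEY>"
    # Vertex AI: regional endpoint; the bearer token travels separately in the Authorization header
    return f"https://us-central1-aiplatform.googleapis.com/v1/projects/<project>/locations/us-central1/publishers/google/models/{model}:{endpoint}"

print(sketch_url("gemini", "gemini-1.5-pro", stream=False))
print(sketch_url("vertex_ai_beta", "gemini-1.5-pro", stream=True))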
@@ -574,6 +630,9 @@ class VertexLLM(BaseLLM):
         messages: list,
         model_response: ModelResponse,
         print_verbose: Callable,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         encoding,
         logging_obj,
         optional_params: dict,
@@ -582,41 +641,58 @@ class VertexLLM(BaseLLM):
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
+        gemini_api_key: Optional[str],
         litellm_params=None,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
 
-        auth_header, vertex_project = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
 
-        ### SET RUNTIME ENDPOINT ###
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+        )
 
         ## TRANSFORMATION ##
+        try:
+            supports_system_message = litellm.supports_system_messages(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+                    str(e)
+                )
+            )
+            supports_system_message = False
         # Separate system prompt from rest of message
         system_prompt_indices = []
         system_content_blocks: List[PartType] = []
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                _system_content_block = PartType(text=message["content"])
-                system_content_blocks.append(_system_content_block)
-                system_prompt_indices.append(idx)
-        if len(system_prompt_indices) > 0:
-            for idx in reversed(system_prompt_indices):
-                messages.pop(idx)
-        system_instructions = SystemInstructions(parts=system_content_blocks)
+        if supports_system_message is True:
+            for idx, message in enumerate(messages):
+                if message["role"] == "system":
+                    _system_content_block = PartType(text=message["content"])
+                    system_content_blocks.append(_system_content_block)
+                    system_prompt_indices.append(idx)
+            if len(system_prompt_indices) > 0:
+                for idx in reversed(system_prompt_indices):
+                    messages.pop(idx)
         content = _gemini_convert_messages_with_history(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
-        data = RequestBody(system_instruction=system_instructions, contents=content)
+        data = RequestBody(contents=content)
+        if len(system_content_blocks) > 0:
+            system_instructions = SystemInstructions(parts=system_content_blocks)
+            data["system_instruction"] = system_instructions
         if tools is not None:
             data["tools"] = tools
         if tool_choice is not None:
@@ -626,8 +702,9 @@ class VertexLLM(BaseLLM):
 
         headers = {
             "Content-Type": "application/json; charset=utf-8",
-            "Authorization": f"Bearer {auth_header}",
         }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"
 
         ## LOGGING
         logging_obj.pre_call(
@@ -642,6 +719,25 @@ class VertexLLM(BaseLLM):
 
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
@@ -853,9 +949,13 @@ class VertexLLM(BaseLLM):
 
 
 class ModelResponseIterator:
-    def __init__(self, streaming_response):
+    def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.response_iterator = iter(self.streaming_response)
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")
 
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -907,10 +1007,21 @@ class ModelResponseIterator:
 
     def __next__(self):
         try:
-            chunk = next(self.response_iterator)
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopIteration:
             raise StopIteration
         except ValueError as e:
@@ -924,9 +1035,20 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopAsyncIteration:
             raise StopAsyncIteration
         except ValueError as e:
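The iterator now leans on ijson's push-style coroutine API: raw byte chunks are sent into items_coro, and completed top-level array elements land in the sendable_list, so a JSON object split across HTTP chunks is only surfaced once it is complete. A minimal standalone sketch of that pattern (the sample payload is illustrative):

import ijson

events = ijson.sendable_list()           # filled as complete array elements are parsed
coro = ijson.items_coro(events, "item")  # "item" matches each element of the top-level JSON array

# Feed the stream in arbitrary byte-sized pieces, as an HTTP client would deliver them.
pieces = [b'[{"candidates": [{"content": ', b'{"parts": [{"text": "hi"}]}}]}', b']']
for piece in pieces:
    coro.send(piece)
    while events:                        # each completed element becomes one streaming chunk
        print(events.pop(0))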
@@ -1884,43 +1884,7 @@ def completion(
             )
             return response
         response = model_response
-    elif custom_llm_provider == "gemini":
-        gemini_api_key = (
-            api_key
-            or get_secret("GEMINI_API_KEY")
-            or get_secret("PALM_API_KEY")  # older palm api key should also work
-            or litellm.api_key
-        )
-
-        # palm does not support streaming as yet :(
-        model_response = gemini.completion(
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            api_key=gemini_api_key,
-            logging_obj=logging,
-            acompletion=acompletion,
-            custom_prompt_dict=custom_prompt_dict,
-        )
-        if (
-            "stream" in optional_params
-            and optional_params["stream"] == True
-            and acompletion == False
-        ):
-            response = CustomStreamWrapper(
-                iter(model_response),
-                model,
-                custom_llm_provider="gemini",
-                logging_obj=logging,
-            )
-            return response
-        response = model_response
-    elif custom_llm_provider == "vertex_ai_beta":
+    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
@@ -1938,6 +1902,14 @@ def completion(
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret("VERTEXAI_CREDENTIALS")
         )
+
+        gemini_api_key = (
+            api_key
+            or get_secret("GEMINI_API_KEY")
+            or get_secret("PALM_API_KEY")  # older palm api key should also work
+            or litellm.api_key
+        )
+
         new_params = deepcopy(optional_params)
         response = vertex_chat_completion.completion(  # type: ignore
             model=model,
@@ -1951,9 +1923,11 @@ def completion(
             vertex_location=vertex_ai_location,
             vertex_project=vertex_ai_project,
             vertex_credentials=vertex_credentials,
+            gemini_api_key=gemini_api_key,
             logging_obj=logging,
             acompletion=acompletion,
             timeout=timeout,
+            custom_llm_provider=custom_llm_provider,
         )
 
     elif custom_llm_provider == "vertex_ai":
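With the gemini provider folded into this branch, streaming also goes through the new path (stream=True selects the streamGenerateContent endpoint in _get_token_and_url). A rough sketch, with an illustrative model name and a real GEMINI_API_KEY assumed:

import os

import litellm

os.environ["GEMINI_API_KEY"] = "..."

stream = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")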
@@ -7,66 +7,86 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan
 
-import copy, httpx
-from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
-from typing_extensions import overload
-import random, threading, time, traceback, uuid
-import litellm, openai, hashlib, json
-from litellm.caching import RedisCache, InMemoryCache, DualCache
-import datetime as datetime_og
-import logging, asyncio
-import inspect, concurrent
-from openai import AsyncOpenAI
-from collections import defaultdict
-from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
-from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
-from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
-from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
-from litellm.llms.custom_httpx.azure_dall_e_2 import (
-    CustomHTTPTransport,
-    AsyncCustomHTTPTransport,
-)
-from litellm.utils import (
-    ModelResponse,
-    CustomStreamWrapper,
-    get_utc_datetime,
-    calculate_max_parallel_requests,
-    _is_region_eu,
-)
+import asyncio
+import concurrent
 import copy
-from litellm._logging import verbose_router_logger
+import datetime as datetime_og
+import hashlib
+import inspect
+import json
 import logging
-from litellm.types.utils import ModelInfo as ModelMapInfo
-from litellm.types.router import (
-    Deployment,
-    ModelInfo,
-    LiteLLM_Params,
-    RouterErrors,
-    updateDeployment,
-    updateLiteLLMParams,
-    RetryPolicy,
-    AllowedFailsPolicy,
-    AlertingConfig,
-    DeploymentTypedDict,
-    ModelGroupInfo,
-    AssistantsTypedDict,
+import random
+import threading
+import time
+import traceback
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
 )
 
+import httpx
+import openai
+from openai import AsyncOpenAI
+from typing_extensions import overload
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.llms.custom_httpx.azure_dall_e_2 import (
+    AsyncCustomHTTPTransport,
+    CustomHTTPTransport,
+)
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.handle_error import send_llm_exception_alert
+from litellm.scheduler import FlowItem, Scheduler
 from litellm.types.llms.openai import (
-    AsyncCursorPage,
     Assistant,
-    Thread,
+    AssistantToolParam,
+    AsyncCursorPage,
     Attachment,
     OpenAIMessage,
     Run,
-    AssistantToolParam,
+    Thread,
+)
+from litellm.types.router import (
+    AlertingConfig,
+    AllowedFailsPolicy,
+    AssistantsTypedDict,
+    Deployment,
+    DeploymentTypedDict,
+    LiteLLM_Params,
+    ModelGroupInfo,
+    ModelInfo,
+    RetryPolicy,
+    RouterErrors,
+    updateDeployment,
+    updateLiteLLMParams,
+)
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.utils import (
+    CustomStreamWrapper,
+    ModelResponse,
+    _is_region_eu,
+    calculate_max_parallel_requests,
+    get_utc_datetime,
 )
-from litellm.scheduler import Scheduler, FlowItem
-from typing import Iterable
-from litellm.router_utils.handle_error import send_llm_exception_alert
 
 
 class Router:
@@ -3114,6 +3134,7 @@ class Router:
 
             # proxy support
             import os
+
             import httpx
 
             # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
@@ -3800,6 +3821,7 @@ class Router:
                 litellm_provider=llm_provider,
                 mode="chat",
                 supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
             )
 
         if model_group_info is None:
@@ -1,7 +1,10 @@
 # conftest.py
 
-import pytest, sys, os
 import importlib
+import os
+import sys
+
+import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
     sys.path.insert(
         0, os.path.abspath("../..")
     )  # Adds the project directory to the system path
+    print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
     import litellm
     from litellm import Router
 
@@ -12,6 +12,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+
+import os
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -3335,6 +3337,7 @@ def test_mistral_anyscale_stream():
 
 
 #### Test A121 ###################
+@pytest.mark.skip(reason="Local test")
 def test_completion_ai21():
     print("running ai21 j2light test")
     litellm.set_verbose = True
@@ -3390,10 +3393,21 @@ def test_completion_deep_infra_mistral():
 
 
 # Gemini tests
-def test_completion_gemini():
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "gemini-1.0-pro",
+        "gemini-1.5-pro",
+        # "gemini-1.5-flash",
+    ],
+)
+def test_completion_gemini(model):
     litellm.set_verbose = True
-    model_name = "gemini/gemini-1.5-pro-latest"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    model_name = "gemini/{}".format(model)
+    messages = [
+        {"role": "system", "content": "Be a good bot!"},
+        {"role": "user", "content": "Hey, how's it going?"},
+    ]
     try:
         response = completion(model=model_name, messages=messages)
         # Add any assertions,here to check the response
@@ -3485,7 +3499,7 @@ def test_completion_palm_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3506,7 +3520,7 @@ def test_completion_watsonx():
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Skip test. account deleted.")
 def test_completion_stream_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -3574,7 +3588,7 @@ def test_unified_auth_params(provider, model, project, region_name, token):
     assert value in translated_optional_params
 
 
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
@@ -3595,7 +3609,7 @@ async def test_acompletion_watsonx():
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Local test")
 @pytest.mark.asyncio
 async def test_acompletion_stream_watsonx():
     litellm.set_verbose = True
@@ -1776,7 +1776,7 @@ def test_completion_sagemaker_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.skip(reason="IBM closed account.")
+@pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx_stream():
     litellm.set_verbose = True
     try:
@@ -1,13 +1,22 @@
 import json
 import time
 import uuid
+import json
+import time
+import uuid
 from enum import Enum
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
 from openai._models import BaseModel as OpenAIObject
 from pydantic import ConfigDict
 from typing_extensions import Dict, Required, TypedDict, override
 
+from ..litellm_core_utils.core_helpers import map_finish_reason
+from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from typing_extensions import Dict, Required, TypedDict, override
 
 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -60,6 +69,7 @@ class ModelInfo(TypedDict, total=False):
         ]
     ]
     supported_openai_params: Required[Optional[List[str]]]
+    supports_system_messages: Optional[bool]
 
 
 class GenericStreamingChunk(TypedDict):
litellm/utils.py (121 changes)
@@ -1824,6 +1824,32 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
     return False
 
 
+def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports function calling and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports function calling, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model_info = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        if model_info.get("supports_system_messages", False) is True:
+            return True
+        return False
+    except Exception:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
+        )
+
+
 def supports_function_calling(model: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
|
||||||
"""
|
"""
|
||||||
if model in litellm.model_cost:
|
if model in litellm.model_cost:
|
||||||
model_info = litellm.model_cost[model]
|
model_info = litellm.model_cost[model]
|
||||||
if model_info.get("supports_function_calling", False):
|
if model_info.get("supports_function_calling", False) is True:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
@@ -1863,7 +1889,7 @@ def supports_vision(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_vision", False):
+        if model_info.get("supports_vision", False) is True:
             return True
         return False
     else:
@@ -1885,7 +1911,7 @@ def supports_parallel_function_calling(model: str):
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
-        if model_info.get("supports_parallel_function_calling", False):
+        if model_info.get("supports_parallel_function_calling", False) is True:
             return True
         return False
     else:
@@ -4320,14 +4346,17 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             )
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
-            return {
-                "max_tokens": max_tokens,  # type: ignore
-                "input_cost_per_token": 0,
-                "output_cost_per_token": 0,
-                "litellm_provider": "huggingface",
-                "mode": "chat",
-                "supported_openai_params": supported_openai_params,
-            }
+            return ModelInfo(
+                max_tokens=max_tokens,  # type: ignore
+                max_input_tokens=None,
+                max_output_tokens=None,
+                input_cost_per_token=0,
+                output_cost_per_token=0,
+                litellm_provider="huggingface",
+                mode="chat",
+                supported_openai_params=supported_openai_params,
+                supports_system_messages=None,
+            )
         else:
             """
             Check if: (in order of specificity)
@@ -4348,7 +4377,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
-                return _model_info
+                return ModelInfo(
+                    max_tokens=_model_info.get("max_tokens", None),
+                    max_input_tokens=_model_info.get("max_input_tokens", None),
+                    max_output_tokens=_model_info.get("max_output_tokens", None),
+                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    input_cost_per_token_above_128k_tokens=_model_info.get(
+                        "input_cost_per_token_above_128k_tokens", None
+                    ),
+                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    output_cost_per_token_above_128k_tokens=_model_info.get(
+                        "output_cost_per_token_above_128k_tokens", None
+                    ),
+                    litellm_provider=_model_info.get(
+                        "litellm_provider", custom_llm_provider
+                    ),
+                    mode=_model_info.get("mode"),
+                    supported_openai_params=supported_openai_params,
+                    supports_system_messages=_model_info.get(
+                        "supports_system_messages", None
+                    ),
+                )
             elif model in litellm.model_cost:
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
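Since get_model_info now returns a structured ModelInfo rather than the raw cost-map dict, callers can read the new flag uniformly. A rough sketch (model name illustrative):

import litellm

info = litellm.get_model_info(model="gemini/gemini-1.5-pro")
print(info["supported_openai_params"])
print(info.get("supports_system_messages"))  # None when the cost map does not specify it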
|
@ -4362,7 +4411,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
return _model_info
|
return ModelInfo(
|
||||||
|
max_tokens=_model_info.get("max_tokens", None),
|
||||||
|
max_input_tokens=_model_info.get("max_input_tokens", None),
|
||||||
|
max_output_tokens=_model_info.get("max_output_tokens", None),
|
||||||
|
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
|
||||||
|
input_cost_per_token_above_128k_tokens=_model_info.get(
|
||||||
|
"input_cost_per_token_above_128k_tokens", None
|
||||||
|
),
|
||||||
|
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
|
||||||
|
output_cost_per_token_above_128k_tokens=_model_info.get(
|
||||||
|
"output_cost_per_token_above_128k_tokens", None
|
||||||
|
),
|
||||||
|
litellm_provider=_model_info.get(
|
||||||
|
"litellm_provider", custom_llm_provider
|
||||||
|
),
|
||||||
|
mode=_model_info.get("mode"),
|
||||||
|
supported_openai_params=supported_openai_params,
|
||||||
|
supports_system_messages=_model_info.get(
|
||||||
|
"supports_system_messages", None
|
||||||
|
),
|
||||||
|
)
|
||||||
elif split_model in litellm.model_cost:
|
elif split_model in litellm.model_cost:
|
||||||
_model_info = litellm.model_cost[split_model]
|
_model_info = litellm.model_cost[split_model]
|
||||||
_model_info["supported_openai_params"] = supported_openai_params
|
_model_info["supported_openai_params"] = supported_openai_params
|
||||||
|
@@ -4376,7 +4445,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
-                return _model_info
+                return ModelInfo(
+                    max_tokens=_model_info.get("max_tokens", None),
+                    max_input_tokens=_model_info.get("max_input_tokens", None),
+                    max_output_tokens=_model_info.get("max_output_tokens", None),
+                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    input_cost_per_token_above_128k_tokens=_model_info.get(
+                        "input_cost_per_token_above_128k_tokens", None
+                    ),
+                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    output_cost_per_token_above_128k_tokens=_model_info.get(
+                        "output_cost_per_token_above_128k_tokens", None
+                    ),
+                    litellm_provider=_model_info.get(
+                        "litellm_provider", custom_llm_provider
+                    ),
+                    mode=_model_info.get("mode"),
+                    supported_openai_params=supported_openai_params,
+                    supports_system_messages=_model_info.get(
+                        "supports_system_messages", None
+                    ),
+                )
             else:
                 raise ValueError(
                     "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -7030,7 +7119,9 @@ def exception_type(
             exception_mapping_worked = True
             if hasattr(original_exception, "request"):
                 raise APIConnectionError(
-                    message=f"{str(original_exception)}",
+                    message="{}\n{}".format(
+                        str(original_exception), traceback.format_exc()
+                    ),
                     llm_provider=custom_llm_provider,
                     model=model,
                     request=original_exception.request,
@@ -27,6 +27,7 @@ jinja2 = "^3.1.2"
 aiohttp = "*"
 requests = "^2.31.0"
 pydantic = "^2.0.0"
+ijson = "*"
 
 uvicorn = {version = "^0.22.0", optional = true}
 gunicorn = {version = "^22.0.0", optional = true}
@@ -44,4 +44,5 @@ aiohttp==3.9.0 # for network calls
 aioboto3==12.3.0 # for async sagemaker calls
 tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
 pydantic==2.7.1 # proxy + openai req.
+ijson==3.2.3 # for google ai studio streaming
 ####