forked from phoenix/litellm-mirror
feat(main.py): initial commit - refactoring google ai studio to just use vertex httpx
Uses the same calling logic for Google AI Studio and Vertex AI. This simplifies the logic and gives the Google AI Studio integration all of the Vertex AI features.
This commit is contained in:
parent fcea4c22ad
commit be66800a98
7 changed files with 216 additions and 137 deletions
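For context, a rough sketch of what this unification enables once the commit lands (the model names and the vertex_ai_beta prefix below are illustrative assumptions based on this diff, not an exact API reference):

import litellm

# Google AI Studio: authenticates with an API key (GEMINI_API_KEY, or the older
# PALM_API_KEY), now routed through the shared vertex_httpx handler.
studio_response = litellm.completion(
    model="gemini/gemini-pro",  # assumed model name, for illustration only
    messages=[{"role": "user", "content": "Hello"}],
)

# Vertex AI (beta path): authenticates with Google Cloud credentials
# (VERTEXAI_CREDENTIALS / vertex_project / vertex_location) instead of a key,
# but goes through the same VertexLLM completion code path.
vertex_response = litellm.completion(
    model="vertex_ai_beta/gemini-pro",  # assumed model name, for illustration only
    messages=[{"role": "user", "content": "Hello"}],
)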
@@ -13,7 +13,10 @@ from litellm._logging import (
     verbose_logger,
     json_logs,
     _turn_on_json,
+    log_level,
 )

 from litellm.proxy._types import (
     KeyManagementSystem,
     KeyManagementSettings,
@@ -1,6 +1,8 @@
-import logging, os, json
-from logging import Formatter
+import json
+import logging
+import os
 import traceback
+from logging import Formatter

 set_verbose = False
@@ -1,41 +1,47 @@
 # What is this?
 ## httpx client for vertex ai calls
 ## Initial implementation - covers gemini + image gen calls
-from functools import partial
-import os, types
+import inspect
 import json
-from enum import Enum
-import requests  # type: ignore
+import os
 import time
-from typing import Callable, Optional, Union, List, Any, Tuple
+import types
+import uuid
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from .base import BaseLLM
-from litellm.types.llms.vertex_ai import (
-    ContentType,
-    SystemInstructions,
-    PartType,
-    RequestBody,
-    GenerateContentResponseBody,
-    FunctionCallingConfig,
-    FunctionDeclaration,
-    Tools,
-    ToolConfig,
-    GenerationConfig,
-)
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
-from litellm.types.utils import GenericStreamingChunk
 from litellm.types.llms.openai import (
-    ChatCompletionUsageBlock,
+    ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
-    ChatCompletionResponseMessage,
+    ChatCompletionUsageBlock,
 )
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    FunctionCallingConfig,
+    FunctionDeclaration,
+    GenerateContentResponseBody,
+    GenerationConfig,
+    PartType,
+    RequestBody,
+    SystemInstructions,
+    ToolConfig,
+    Tools,
+)
+from litellm.types.utils import GenericStreamingChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM


 class VertexGeminiConfig:
@@ -414,9 +420,11 @@ class VertexLLM(BaseLLM):
     def load_auth(
         self, credentials: Optional[str], project_id: Optional[str]
     ) -> Tuple[Any, str]:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
         import google.auth as google_auth
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )

         if credentials is not None and isinstance(credentials, str):
             import google.oauth2.service_account
@@ -449,7 +457,9 @@ class VertexLLM(BaseLLM):
         return creds, project_id

     def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )

         credentials.refresh(Request())

@@ -482,6 +492,50 @@ class VertexLLM(BaseLLM):

         return self._credentials.token, self.project_id

+    def _get_token_and_url(
+        self,
+        model: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            _gemini_model_name = "models/{}".format(model)
+            auth_header = None
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+        else:
+            auth_header, vertex_project = self._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+            ### SET RUNTIME ENDPOINT ###
+            endpoint = "generateContent"
+            if stream is True:
+                endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+        return auth_header, url
+
     async def async_streaming(
         self,
         model: str,
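For reference, the two endpoint shapes that _get_token_and_url chooses between look roughly like this (model, project, and location values are placeholders, not taken from the diff):

# Google AI Studio (custom_llm_provider == "gemini"): API-key auth, no bearer token.
studio_url = (
    "https://generativelanguage.googleapis.com/v1beta/"
    "models/<MODEL>:generateContent?key=<GEMINI_API_KEY>"
)

# Vertex AI (vertex_ai / vertex_ai_beta): OAuth bearer token from _ensure_access_token().
vertex_url = (
    "https://<LOCATION>-aiplatform.googleapis.com/v1/projects/<PROJECT>/"
    "locations/<LOCATION>/publishers/google/models/<MODEL>:generateContent"
)

# In both cases, stream=True swaps ":generateContent" for ":streamGenerateContent".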
@@ -574,6 +628,9 @@ class VertexLLM(BaseLLM):
         messages: list,
         model_response: ModelResponse,
         print_verbose: Callable,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         encoding,
         logging_obj,
         optional_params: dict,
@@ -582,20 +639,23 @@
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
+        gemini_api_key: Optional[str],
         litellm_params=None,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        auth_header, vertex_project = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

-        ### SET RUNTIME ENDPOINT ###
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+        )

         ## TRANSFORMATION ##
         # Separate system prompt from rest of message
@@ -609,14 +669,16 @@
         if len(system_prompt_indices) > 0:
             for idx in reversed(system_prompt_indices):
                 messages.pop(idx)
-        system_instructions = SystemInstructions(parts=system_content_blocks)
         content = _gemini_convert_messages_with_history(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
-        data = RequestBody(system_instruction=system_instructions, contents=content)
+        data = RequestBody(contents=content)
+        if len(system_content_blocks) > 0:
+            system_instructions = SystemInstructions(parts=system_content_blocks)
+            data["system_instruction"] = system_instructions
         if tools is not None:
             data["tools"] = tools
         if tool_choice is not None:
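As a rough illustration of the request body built above (the field names come from this diff; the exact nesting of parts is an assumption based on the Gemini API, not shown in the diff):

# Hypothetical shape of `data` after the transformation step, for a conversation
# with one system message and one user message.
data = {
    "contents": [
        {"role": "user", "parts": [{"text": "What is LiteLLM?"}]},
    ],
    # Present only when the original messages contained a system prompt:
    "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
}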
@@ -626,8 +688,9 @@

         headers = {
             "Content-Type": "application/json; charset=utf-8",
-            "Authorization": f"Bearer {auth_header}",
         }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"

         ## LOGGING
         logging_obj.pre_call(
litellm/main.py: 173 changed lines
@@ -7,107 +7,130 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan

-import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any, Literal, Union, BinaryIO
-from typing_extensions import overload
-from functools import partial
-import dotenv, traceback, random, asyncio, time, contextvars
+import asyncio
+import contextvars
+import datetime
+import inspect
+import json
+import os
+import random
+import sys
+import threading
+import time
+import traceback
+import uuid
+from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
+from functools import partial
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+)

+import dotenv
 import httpx
+import openai
+import tiktoken
+from typing_extensions import overload

 import litellm
-from ._logging import verbose_logger
 from litellm import (  # type: ignore
+    Logging,
     client,
     exception_type,
-    get_optional_params,
     get_litellm_params,
-    Logging,
+    get_optional_params,
 )
 from litellm.utils import (
-    get_secret,
     CustomStreamWrapper,
-    read_config_args,
-    completion_with_fallbacks,
-    get_llm_provider,
-    get_api_key,
-    mock_completion_streaming_obj,
+    Usage,
     async_mock_completion_streaming_obj,
+    completion_with_fallbacks,
     convert_to_model_response_object,
-    token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,
-    Usage,
+    get_api_key,
+    get_llm_provider,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    get_secret,
+    mock_completion_streaming_obj,
+    read_config_args,
     supports_httpx_timeout,
+    token_counter,
 )

+from ._logging import verbose_logger
+from .caching import disable_cache, enable_cache, update_cache
 from .llms import (
-    anthropic_text,
-    together_ai,
     ai21,
-    sagemaker,
-    bedrock,
-    triton,
-    huggingface_restapi,
-    replicate,
     aleph_alpha,
-    nlp_cloud,
+    anthropic_text,
     baseten,
-    vllm,
-    ollama,
-    ollama_chat,
-    cloudflare,
+    bedrock,
     clarifai,
+    cloudflare,
     cohere,
     cohere_chat,
-    petals,
+    gemini,
+    huggingface_restapi,
+    maritalk,
+    nlp_cloud,
+    ollama,
+    ollama_chat,
     oobabooga,
     openrouter,
     palm,
-    gemini,
+    petals,
+    replicate,
+    sagemaker,
+    together_ai,
+    triton,
     vertex_ai,
     vertex_ai_anthropic,
-    maritalk,
+    vllm,
     watsonx,
 )
-from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
-from .llms.azure import AzureChatCompletion
-from .llms.databricks import DatabricksChatCompletion
-from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
+from .llms.azure import AzureChatCompletion
+from .llms.azure_text import AzureTextCompletion
+from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
+from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
+from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
-from .llms.vertex_httpx import VertexLLM
-from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
-    prompt_factory,
     custom_prompt,
     function_call_prompt,
     map_system_message_pt,
+    prompt_factory,
 )
-import tiktoken
-from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict, Union, Mapping
-from .caching import enable_cache, disable_cache, update_cache
+from .llms.triton import TritonChatCompletion
+from .llms.vertex_httpx import VertexLLM
 from .types.llms.openai import HttpxBinaryResponseContent

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
-    get_secret,
+    Choices,
     CustomStreamWrapper,
-    TextCompletionStreamWrapper,
-    ModelResponse,
-    TextCompletionResponse,
-    TextChoices,
     EmbeddingResponse,
     ImageResponse,
-    read_config_args,
-    Choices,
     Message,
+    ModelResponse,
+    TextChoices,
+    TextCompletionResponse,
+    TextCompletionStreamWrapper,
     TranscriptionResponse,
+    get_secret,
+    read_config_args,
 )

 ####### ENVIRONMENT VARIABLES ###################
@@ -1845,43 +1868,7 @@ def completion(
             )
             return response
         response = model_response
-    elif custom_llm_provider == "gemini":
-        gemini_api_key = (
-            api_key
-            or get_secret("GEMINI_API_KEY")
-            or get_secret("PALM_API_KEY")  # older palm api key should also work
-            or litellm.api_key
-        )
-
-        # palm does not support streaming as yet :(
-        model_response = gemini.completion(
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            api_key=gemini_api_key,
-            logging_obj=logging,
-            acompletion=acompletion,
-            custom_prompt_dict=custom_prompt_dict,
-        )
-        if (
-            "stream" in optional_params
-            and optional_params["stream"] == True
-            and acompletion == False
-        ):
-            response = CustomStreamWrapper(
-                iter(model_response),
-                model,
-                custom_llm_provider="gemini",
-                logging_obj=logging,
-            )
-            return response
-        response = model_response
-    elif custom_llm_provider == "vertex_ai_beta":
+    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
@@ -1899,6 +1886,14 @@ def completion(
             or optional_params.pop("vertex_ai_credentials", None)
             or get_secret("VERTEXAI_CREDENTIALS")
         )
+
+        gemini_api_key = (
+            api_key
+            or get_secret("GEMINI_API_KEY")
+            or get_secret("PALM_API_KEY")  # older palm api key should also work
+            or litellm.api_key
+        )
+
         new_params = deepcopy(optional_params)
         response = vertex_chat_completion.completion(  # type: ignore
             model=model,
@@ -1912,9 +1907,11 @@ def completion(
             vertex_location=vertex_ai_location,
             vertex_project=vertex_ai_project,
             vertex_credentials=vertex_credentials,
+            gemini_api_key=gemini_api_key,
             logging_obj=logging,
             acompletion=acompletion,
             timeout=timeout,
+            custom_llm_provider=custom_llm_provider,
         )

     elif custom_llm_provider == "vertex_ai":
@@ -1,7 +1,10 @@
 # conftest.py

-import pytest, sys, os
 import importlib
+import os
+import sys
+
+import pytest

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
     sys.path.insert(
         0, os.path.abspath("../..")
     )  # Adds the project directory to the system path
+    print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
     import litellm
     from litellm import Router
@@ -1,20 +1,26 @@
-import sys, os, json
+import json
+import os
+import sys
 import traceback

 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import os
+from unittest.mock import MagicMock, patch
+
 import pytest

 import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-from unittest.mock import patch, MagicMock
-from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

 # litellm.num_retries =3
 litellm.cache = None
@@ -1472,7 +1478,9 @@ def test_ollama_image():
     data is untouched.
     """

-    import io, base64
+    import base64
+    import io
+
     from PIL import Image

     def mock_post(url, **kwargs):
@@ -6972,7 +6972,9 @@ def exception_type(
        exception_mapping_worked = True
        if hasattr(original_exception, "request"):
            raise APIConnectionError(
-                message=f"{str(original_exception)}",
+                message="{}\n{}".format(
+                    str(original_exception), traceback.format_exc()
+                ),
                llm_provider=custom_llm_provider,
                model=model,
                request=original_exception.request,