forked from phoenix/litellm-mirror
feat(main.py): initial commit - refactoring google ai studio to just use vertex httpx
Uses the same calling logic for Google AI Studio and Vertex AI. This simplifies the logic and gives the Google AI Studio integration all of the Vertex AI features.
This commit is contained in:
parent fcea4c22ad
commit be66800a98
7 changed files with 216 additions and 137 deletions
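As a hedged illustration of the unified call path described above (grounded in the _get_token_and_url and completion changes in the diff below; the exact model strings and keyword arguments are assumptions drawn from litellm's public API, not a verbatim excerpt from this commit):

# Hedged sketch (not part of the diff): after this refactor, Google AI Studio
# ("gemini") and Vertex AI requests are assumed to share the same httpx-based
# handler; only the auth header and the endpoint URL differ.
import litellm

# Google AI Studio: key-based auth (e.g. GEMINI_API_KEY), generativelanguage endpoint
studio_response = litellm.completion(
    model="gemini/gemini-pro",  # assumed example model string
    messages=[{"role": "user", "content": "Say hello"}],
)

# Vertex AI (httpx/beta path): OAuth token auth, region-scoped aiplatform endpoint
vertex_response = litellm.completion(
    model="vertex_ai_beta/gemini-pro",  # assumed example model string
    messages=[{"role": "user", "content": "Say hello"}],
    vertex_project="my-gcp-project",  # assumed example project
    vertex_location="us-central1",  # assumed example region
)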
@@ -13,7 +13,10 @@ from litellm._logging import (
    verbose_logger,
    json_logs,
    _turn_on_json,
    log_level,
)


from litellm.proxy._types import (
    KeyManagementSystem,
    KeyManagementSettings,
@@ -1,6 +1,8 @@
import logging, os, json
from logging import Formatter
import json
import logging
import os
import traceback
from logging import Formatter

set_verbose = False
@@ -1,41 +1,47 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import inspect
import json
from enum import Enum
import requests  # type: ignore
import os
import time
from typing import Callable, Optional, Union, List, Any, Tuple
import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import httpx  # type: ignore
import requests  # type: ignore

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm, uuid
import httpx, inspect  # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
    ContentType,
    SystemInstructions,
    PartType,
    RequestBody,
    GenerateContentResponseBody,
    FunctionCallingConfig,
    FunctionDeclaration,
    Tools,
    ToolConfig,
    GenerationConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
    ChatCompletionUsageBlock,
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionResponseMessage,
    ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
    ContentType,
    FunctionCallingConfig,
    FunctionDeclaration,
    GenerateContentResponseBody,
    GenerationConfig,
    PartType,
    RequestBody,
    SystemInstructions,
    ToolConfig,
    Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

from .base import BaseLLM


class VertexGeminiConfig:
@@ -414,9 +420,11 @@ class VertexLLM(BaseLLM):
    def load_auth(
        self, credentials: Optional[str], project_id: Optional[str]
    ) -> Tuple[Any, str]:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        from google.auth.transport.requests import (
            Request,  # type: ignore[import-untyped]
        )

        if credentials is not None and isinstance(credentials, str):
            import google.oauth2.service_account
@@ -449,7 +457,9 @@ class VertexLLM(BaseLLM):
        return creds, project_id

    def refresh_auth(self, credentials: Any) -> None:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.transport.requests import (
            Request,  # type: ignore[import-untyped]
        )

        credentials.refresh(Request())
@@ -482,6 +492,50 @@ class VertexLLM(BaseLLM):
        return self._credentials.token, self.project_id

    def _get_token_and_url(
        self,
        model: str,
        gemini_api_key: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    ) -> Tuple[Optional[str], str]:
        """
        Internal function. Returns the token and url for the call.

        Handles logic if it's google ai studio vs. vertex ai.

        Returns
            token, url
        """
        if custom_llm_provider == "gemini":
            _gemini_model_name = "models/{}".format(model)
            auth_header = None
            endpoint = "generateContent"
            if stream is True:
                endpoint = "streamGenerateContent"

            url = (
                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
                    _gemini_model_name, endpoint, gemini_api_key
                )
            )
        else:
            auth_header, vertex_project = self._ensure_access_token(
                credentials=vertex_credentials, project_id=vertex_project
            )
            vertex_location = self.get_vertex_region(vertex_region=vertex_location)

            ### SET RUNTIME ENDPOINT ###
            endpoint = "generateContent"
            if stream is True:
                endpoint = "streamGenerateContent"
            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

        return auth_header, url

    async def async_streaming(
        self,
        model: str,
@@ -574,6 +628,9 @@ class VertexLLM(BaseLLM):
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
        encoding,
        logging_obj,
        optional_params: dict,
@@ -582,20 +639,23 @@ class VertexLLM(BaseLLM):
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        gemini_api_key: Optional[str],
        litellm_params=None,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:

        auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials, project_id=vertex_project
        )
        vertex_location = self.get_vertex_region(vertex_region=vertex_location)
        stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

        ### SET RUNTIME ENDPOINT ###
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=stream,
            custom_llm_provider=custom_llm_provider,
        )

        ## TRANSFORMATION ##
        # Separate system prompt from rest of message
@@ -609,14 +669,16 @@ class VertexLLM(BaseLLM):
        if len(system_prompt_indices) > 0:
            for idx in reversed(system_prompt_indices):
                messages.pop(idx)
        system_instructions = SystemInstructions(parts=system_content_blocks)
        content = _gemini_convert_messages_with_history(messages=messages)
        tools: Optional[Tools] = optional_params.pop("tools", None)
        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
        generation_config: Optional[GenerationConfig] = GenerationConfig(
            **optional_params
        )
        data = RequestBody(system_instruction=system_instructions, contents=content)
        data = RequestBody(contents=content)
        if len(system_content_blocks) > 0:
            system_instructions = SystemInstructions(parts=system_content_blocks)
            data["system_instruction"] = system_instructions
        if tools is not None:
            data["tools"] = tools
        if tool_choice is not None:
@@ -626,8 +688,9 @@

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }
        if auth_header is not None:
            headers["Authorization"] = f"Bearer {auth_header}"

        ## LOGGING
        logging_obj.pre_call(

litellm/main.py (173 changed lines)
@@ -7,107 +7,130 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan

import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial

import dotenv, traceback, random, asyncio, time, contextvars
import asyncio
import contextvars
import datetime
import inspect
import json
import os
import random
import sys
import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import (
    Any,
    BinaryIO,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Union,
)

import dotenv
import httpx
import openai
import tiktoken
from typing_extensions import overload

import litellm
from ._logging import verbose_logger
from litellm import (  # type: ignore
    Logging,
    client,
    exception_type,
    get_optional_params,
    get_litellm_params,
    Logging,
    get_optional_params,
)
from litellm.utils import (
    get_secret,
    CustomStreamWrapper,
    read_config_args,
    completion_with_fallbacks,
    get_llm_provider,
    get_api_key,
    mock_completion_streaming_obj,
    Usage,
    async_mock_completion_streaming_obj,
    completion_with_fallbacks,
    convert_to_model_response_object,
    token_counter,
    create_pretrained_tokenizer,
    create_tokenizer,
    Usage,
    get_api_key,
    get_llm_provider,
    get_optional_params_embeddings,
    get_optional_params_image_gen,
    get_secret,
    mock_completion_streaming_obj,
    read_config_args,
    supports_httpx_timeout,
    token_counter,
)

from ._logging import verbose_logger
from .caching import disable_cache, enable_cache, update_cache
from .llms import (
    anthropic_text,
    together_ai,
    ai21,
    sagemaker,
    bedrock,
    triton,
    huggingface_restapi,
    replicate,
    aleph_alpha,
    nlp_cloud,
    anthropic_text,
    baseten,
    vllm,
    ollama,
    ollama_chat,
    cloudflare,
    bedrock,
    clarifai,
    cloudflare,
    cohere,
    cohere_chat,
    petals,
    gemini,
    huggingface_restapi,
    maritalk,
    nlp_cloud,
    ollama,
    ollama_chat,
    oobabooga,
    openrouter,
    palm,
    gemini,
    petals,
    replicate,
    sagemaker,
    together_ai,
    triton,
    vertex_ai,
    vertex_ai_anthropic,
    maritalk,
    vllm,
    watsonx,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
from .llms.databricks import DatabricksChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
    prompt_factory,
    custom_prompt,
    function_call_prompt,
    map_system_message_pt,
    prompt_factory,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
from .llms.triton import TritonChatCompletion
from .llms.vertex_httpx import VertexLLM
from .types.llms.openai import HttpxBinaryResponseContent

encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
    get_secret,
    Choices,
    CustomStreamWrapper,
    TextCompletionStreamWrapper,
    ModelResponse,
    TextCompletionResponse,
    TextChoices,
    EmbeddingResponse,
    ImageResponse,
    read_config_args,
    Choices,
    Message,
    ModelResponse,
    TextChoices,
    TextCompletionResponse,
    TextCompletionStreamWrapper,
    TranscriptionResponse,
    get_secret,
    read_config_args,
)

####### ENVIRONMENT VARIABLES ###################
@@ -1845,43 +1868,7 @@ def completion(
            )
            return response
        response = model_response
    elif custom_llm_provider == "gemini":
        gemini_api_key = (
            api_key
            or get_secret("GEMINI_API_KEY")
            or get_secret("PALM_API_KEY")  # older palm api key should also work
            or litellm.api_key
        )

        # palm does not support streaming as yet :(
        model_response = gemini.completion(
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            api_key=gemini_api_key,
            logging_obj=logging,
            acompletion=acompletion,
            custom_prompt_dict=custom_prompt_dict,
        )
        if (
            "stream" in optional_params
            and optional_params["stream"] == True
            and acompletion == False
        ):
            response = CustomStreamWrapper(
                iter(model_response),
                model,
                custom_llm_provider="gemini",
                logging_obj=logging,
            )
            return response
        response = model_response
    elif custom_llm_provider == "vertex_ai_beta":
    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
        vertex_ai_project = (
            optional_params.pop("vertex_project", None)
            or optional_params.pop("vertex_ai_project", None)
@@ -1899,6 +1886,14 @@ def completion(
            or optional_params.pop("vertex_ai_credentials", None)
            or get_secret("VERTEXAI_CREDENTIALS")
        )

        gemini_api_key = (
            api_key
            or get_secret("GEMINI_API_KEY")
            or get_secret("PALM_API_KEY")  # older palm api key should also work
            or litellm.api_key
        )

        new_params = deepcopy(optional_params)
        response = vertex_chat_completion.completion(  # type: ignore
            model=model,
@@ -1912,9 +1907,11 @@ def completion(
            vertex_location=vertex_ai_location,
            vertex_project=vertex_ai_project,
            vertex_credentials=vertex_credentials,
            gemini_api_key=gemini_api_key,
            logging_obj=logging,
            acompletion=acompletion,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
        )

    elif custom_llm_provider == "vertex_ai":
@@ -1,7 +1,10 @@
# conftest.py

import pytest, sys, os
import importlib
import os
import sys

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path
    print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
    import litellm
    from litellm import Router
@@ -1,20 +1,26 @@
import sys, os, json
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import os, io
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os
from unittest.mock import MagicMock, patch

import pytest

import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

# litellm.num_retries =3
litellm.cache = None
@@ -1472,7 +1478,9 @@ def test_ollama_image():
    data is untouched.
    """

    import io, base64
    import base64
    import io

    from PIL import Image

    def mock_post(url, **kwargs):
@@ -6972,7 +6972,9 @@ def exception_type(
            exception_mapping_worked = True
            if hasattr(original_exception, "request"):
                raise APIConnectionError(
                    message=f"{str(original_exception)}",
                    message="{}\n{}".format(
                        str(original_exception), traceback.format_exc()
                    ),
                    llm_provider=custom_llm_provider,
                    model=model,
                    request=original_exception.request,
@@ -7186,7 +7188,7 @@ def get_secret(
            else:
                raise ValueError(
                    f"Google KMS requires the encrypted secret to be encoded in base64"
                )#fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
                )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
            response = client.decrypt(
                request={
                    "name": litellm._google_kms_resource_name,