feat(main.py): initial commit - refactoring google ai studio to just use vertex httpx

Uses the same calling logic for Google AI Studio and Vertex AI. This simplifies the logic and gives the Google AI Studio integration access to all of the Vertex AI features.
Krrish Dholakia 2024-06-17 13:31:46 -07:00
parent fcea4c22ad
commit be66800a98
7 changed files with 216 additions and 137 deletions
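
A minimal usage sketch of the unified path (hypothetical model name; assumes the standard litellm.completion API and that GEMINI_API_KEY is set, matching the get_secret("GEMINI_API_KEY") lookup in the diff below):

import os

import litellm

# Google AI Studio key; main.py resolves it via get_secret("GEMINI_API_KEY")
# (an older PALM_API_KEY still works, per the diff below).
os.environ["GEMINI_API_KEY"] = "..."  # placeholder

# The "gemini" provider now routes through VertexLLM (vertex_httpx.py), so it
# shares streaming, tools, and system-instruction handling with Vertex AI.
response = litellm.completion(
    model="gemini/gemini-pro",  # hypothetical model id, for illustration only
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)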


@@ -13,7 +13,10 @@ from litellm._logging import (
verbose_logger,
json_logs,
_turn_on_json,
log_level,
)
from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,


@@ -1,6 +1,8 @@
import logging, os, json
from logging import Formatter
import json
import logging
import os
import traceback
from logging import Formatter
set_verbose = False


@@ -1,41 +1,47 @@
# What is this?
## httpx client for vertex ai calls
## Initial implementation - covers gemini + image gen calls
from functools import partial
import os, types
import inspect
import json
from enum import Enum
import requests # type: ignore
import os
import time
from typing import Callable, Optional, Union, List, Any, Tuple
import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
from litellm.types.llms.vertex_ai import (
ContentType,
SystemInstructions,
PartType,
RequestBody,
GenerateContentResponseBody,
FunctionCallingConfig,
FunctionDeclaration,
Tools,
ToolConfig,
GenerationConfig,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.utils import GenericStreamingChunk
from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionResponseMessage,
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
GenerateContentResponseBody,
GenerationConfig,
PartType,
RequestBody,
SystemInstructions,
ToolConfig,
Tools,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
class VertexGeminiConfig:
@@ -414,9 +420,11 @@ class VertexLLM(BaseLLM):
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
@@ -449,7 +457,9 @@ class VertexLLM(BaseLLM):
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request())
@@ -482,6 +492,50 @@ class VertexLLM(BaseLLM):
return self._credentials.token, self.project_id
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return auth_header, url
async def async_streaming(
self,
model: str,
@@ -574,6 +628,9 @@ class VertexLLM(BaseLLM):
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
encoding,
logging_obj,
optional_params: dict,
@@ -582,20 +639,23 @@ class VertexLLM(BaseLLM):
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
gemini_api_key: Optional[str],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
### SET RUNTIME ENDPOINT ###
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:generateContent"
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
)
## TRANSFORMATION ##
# Separate system prompt from rest of message
@@ -609,14 +669,16 @@ class VertexLLM(BaseLLM):
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
system_instructions = SystemInstructions(parts=system_content_blocks)
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params
)
data = RequestBody(system_instruction=system_instructions, contents=content)
data = RequestBody(contents=content)
if len(system_content_blocks) > 0:
system_instructions = SystemInstructions(parts=system_content_blocks)
data["system_instruction"] = system_instructions
if tools is not None:
data["tools"] = tools
if tool_choice is not None:
@@ -626,8 +688,9 @@
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
## LOGGING
logging_obj.pre_call(
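
For reference, a rough sketch of the two request shapes the new _get_token_and_url helper produces (project, region, model, and key values are illustrative placeholders):

# Google AI Studio ("gemini"): API key in the query string, no Authorization header.
gemini_url = (
    "https://generativelanguage.googleapis.com/v1beta/"
    "models/gemini-pro:generateContent?key=YOUR_GEMINI_API_KEY"  # placeholder key
)
gemini_headers = {"Content-Type": "application/json; charset=utf-8"}

# Vertex AI ("vertex_ai_beta"): regional endpoint plus an OAuth bearer token
# obtained through _ensure_access_token / google.auth.
vertex_url = (
    "https://us-central1-aiplatform.googleapis.com/v1/"
    "projects/my-project/locations/us-central1/"
    "publishers/google/models/gemini-pro:generateContent"
)
vertex_headers = {
    "Content-Type": "application/json; charset=utf-8",
    "Authorization": "Bearer <access token>",  # placeholder token
}

# With stream=True, both paths switch the endpoint to :streamGenerateContent.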


@@ -7,107 +7,130 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union, BinaryIO
from typing_extensions import overload
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
import asyncio
import contextvars
import datetime
import inspect
import json
import os
import random
import sys
import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
import dotenv
import httpx
import openai
import tiktoken
from typing_extensions import overload
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
Logging,
client,
exception_type,
get_optional_params,
get_litellm_params,
Logging,
get_optional_params,
)
from litellm.utils import (
get_secret,
CustomStreamWrapper,
read_config_args,
completion_with_fallbacks,
get_llm_provider,
get_api_key,
mock_completion_streaming_obj,
Usage,
async_mock_completion_streaming_obj,
completion_with_fallbacks,
convert_to_model_response_object,
token_counter,
create_pretrained_tokenizer,
create_tokenizer,
Usage,
get_api_key,
get_llm_provider,
get_optional_params_embeddings,
get_optional_params_image_gen,
get_secret,
mock_completion_streaming_obj,
read_config_args,
supports_httpx_timeout,
token_counter,
)
from ._logging import verbose_logger
from .caching import disable_cache, enable_cache, update_cache
from .llms import (
anthropic_text,
together_ai,
ai21,
sagemaker,
bedrock,
triton,
huggingface_restapi,
replicate,
aleph_alpha,
nlp_cloud,
anthropic_text,
baseten,
vllm,
ollama,
ollama_chat,
cloudflare,
bedrock,
clarifai,
cloudflare,
cohere,
cohere_chat,
petals,
gemini,
huggingface_restapi,
maritalk,
nlp_cloud,
ollama,
ollama_chat,
oobabooga,
openrouter,
palm,
gemini,
petals,
replicate,
sagemaker,
together_ai,
triton,
vertex_ai,
vertex_ai_anthropic,
maritalk,
vllm,
watsonx,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
from .llms.databricks import DatabricksChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
function_call_prompt,
map_system_message_pt,
prompt_factory,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict, Union, Mapping
from .caching import enable_cache, disable_cache, update_cache
from .llms.triton import TritonChatCompletion
from .llms.vertex_httpx import VertexLLM
from .types.llms.openai import HttpxBinaryResponseContent
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
get_secret,
Choices,
CustomStreamWrapper,
TextCompletionStreamWrapper,
ModelResponse,
TextCompletionResponse,
TextChoices,
EmbeddingResponse,
ImageResponse,
read_config_args,
Choices,
Message,
ModelResponse,
TextChoices,
TextCompletionResponse,
TextCompletionStreamWrapper,
TranscriptionResponse,
get_secret,
read_config_args,
)
####### ENVIRONMENT VARIABLES ###################
@@ -1845,43 +1868,7 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "gemini":
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
# palm does not support streaming as yet :(
model_response = gemini.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
custom_prompt_dict=custom_prompt_dict,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
iter(model_response),
model,
custom_llm_provider="gemini",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "vertex_ai_beta":
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
@@ -1899,6 +1886,14 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
or get_secret("PALM_API_KEY") # older palm api key should also work
or litellm.api_key
)
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
@@ -1912,9 +1907,11 @@ def completion(
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
gemini_api_key=gemini_api_key,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
)
elif custom_llm_provider == "vertex_ai":
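
Streaming for Google AI Studio now follows the same httpx path as Vertex AI; a small sketch under the same assumptions as the example above:

import litellm

# stream=True is popped inside vertex_httpx.completion and switches the request
# to the :streamGenerateContent endpoint for both "gemini" and "vertex_ai_beta".
stream = litellm.completion(
    model="gemini/gemini-pro",  # hypothetical model id
    messages=[{"role": "user", "content": "Write a haiku"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")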


@@ -1,7 +1,10 @@
# conftest.py
import pytest, sys, os
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
@@ -18,6 +21,7 @@ def setup_and_teardown():
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
print("LITELLM_LOG - {}".format(os.getenv("LITELLM_LOG")))
import litellm
from litellm import Router


@@ -1,20 +1,26 @@
import sys, os, json
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import MagicMock, patch
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
# litellm.num_retries =3
litellm.cache = None
@@ -1472,7 +1478,9 @@ def test_ollama_image():
data is untouched.
"""
import io, base64
import base64
import io
from PIL import Image
def mock_post(url, **kwargs):


@@ -6972,7 +6972,9 @@ def exception_type(
exception_mapping_worked = True
if hasattr(original_exception, "request"):
raise APIConnectionError(
message=f"{str(original_exception)}",
message="{}\n{}".format(
str(original_exception), traceback.format_exc()
),
llm_provider=custom_llm_provider,
model=model,
request=original_exception.request,
@@ -7186,7 +7188,7 @@ def get_secret(
else:
raise ValueError(
f"Google KMS requires the encrypted secret to be encoded in base64"
)#fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,