feat(main.py): initial commit - refactoring google ai studio to just use vertex httpx

Uses the same calling logic for Google AI Studio and Vertex AI. This simplifies the routing logic and gives the Google AI Studio integration all of the Vertex AI features.
Krrish Dholakia 2024-06-17 13:31:46 -07:00
parent fcea4c22ad
commit be66800a98
7 changed files with 216 additions and 137 deletions
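
In effect, "gemini" becomes an alias for the Vertex httpx code path: both providers dispatch to the shared VertexLLM handler in llms/vertex_httpx.py, and only the credential style differs. A minimal sketch of the two call paths from the caller's side (model names and credential values here are illustrative, not taken from this diff):

import os

import litellm

# Google AI Studio: keyed auth, resolved from GEMINI_API_KEY
# (or the legacy PALM_API_KEY) via the fallback chain in this commit.
os.environ["GEMINI_API_KEY"] = "your-api-key"
response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# Vertex AI: project/location-based auth; same underlying handler.
response = litellm.completion(
    model="vertex_ai_beta/gemini-pro",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    vertex_project="my-gcp-project",  # illustrative
    vertex_location="us-central1",  # illustrative
)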

@@ -7,107 +7,130 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan
-import os, openai, sys, json, inspect, uuid, datetime, threading
-from typing import Any, Literal, Union, BinaryIO
-from typing_extensions import overload
-from functools import partial
-import dotenv, traceback, random, asyncio, time, contextvars
+import asyncio
+import contextvars
+import datetime
+import inspect
+import json
+import os
+import random
+import sys
+import threading
+import time
+import traceback
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
+from functools import partial
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+)
+import dotenv
 import httpx
+import openai
+import tiktoken
+from typing_extensions import overload
 import litellm
-from ._logging import verbose_logger
 from litellm import (  # type: ignore
+    Logging,
     client,
     exception_type,
-    get_optional_params,
     get_litellm_params,
-    Logging,
+    get_optional_params,
 )
 from litellm.utils import (
-    get_secret,
     CustomStreamWrapper,
-    read_config_args,
-    completion_with_fallbacks,
-    get_llm_provider,
-    get_api_key,
-    mock_completion_streaming_obj,
+    Usage,
     async_mock_completion_streaming_obj,
+    completion_with_fallbacks,
     convert_to_model_response_object,
-    token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,
-    Usage,
+    get_api_key,
+    get_llm_provider,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    get_secret,
+    mock_completion_streaming_obj,
+    read_config_args,
     supports_httpx_timeout,
+    token_counter,
 )
+from ._logging import verbose_logger
+from .caching import disable_cache, enable_cache, update_cache
 from .llms import (
-    anthropic_text,
-    together_ai,
     ai21,
-    sagemaker,
-    bedrock,
-    triton,
-    huggingface_restapi,
-    replicate,
     aleph_alpha,
-    nlp_cloud,
+    anthropic_text,
     baseten,
-    vllm,
-    ollama,
-    ollama_chat,
-    cloudflare,
+    bedrock,
     clarifai,
+    cloudflare,
     cohere,
     cohere_chat,
-    petals,
+    gemini,
+    huggingface_restapi,
+    maritalk,
+    nlp_cloud,
+    ollama,
+    ollama_chat,
     oobabooga,
     openrouter,
     palm,
-    gemini,
+    petals,
+    replicate,
+    sagemaker,
+    together_ai,
+    triton,
     vertex_ai,
     vertex_ai_anthropic,
-    maritalk,
+    vllm,
     watsonx,
 )
-from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
-from .llms.azure import AzureChatCompletion
-from .llms.databricks import DatabricksChatCompletion
-from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
+from .llms.azure import AzureChatCompletion
+from .llms.azure_text import AzureTextCompletion
+from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
+from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
+from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
-from .llms.vertex_httpx import VertexLLM
-from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
-    prompt_factory,
     custom_prompt,
     function_call_prompt,
     map_system_message_pt,
+    prompt_factory,
 )
-import tiktoken
-from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, List, Optional, Dict, Union, Mapping
-from .caching import enable_cache, disable_cache, update_cache
+from .llms.triton import TritonChatCompletion
+from .llms.vertex_httpx import VertexLLM
 from .types.llms.openai import HttpxBinaryResponseContent
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
-    get_secret,
+    Choices,
     CustomStreamWrapper,
-    TextCompletionStreamWrapper,
-    ModelResponse,
-    TextCompletionResponse,
-    TextChoices,
     EmbeddingResponse,
     ImageResponse,
-    read_config_args,
-    Choices,
     Message,
+    ModelResponse,
+    TextChoices,
+    TextCompletionResponse,
+    TextCompletionStreamWrapper,
     TranscriptionResponse,
+    get_secret,
+    read_config_args,
 )
 ####### ENVIRONMENT VARIABLES ###################
@@ -1845,43 +1868,7 @@ def completion(
                 )
                 return response
             response = model_response
-        elif custom_llm_provider == "gemini":
-            gemini_api_key = (
-                api_key
-                or get_secret("GEMINI_API_KEY")
-                or get_secret("PALM_API_KEY")  # older palm api key should also work
-                or litellm.api_key
-            )
-            # palm does not support streaming as yet :(
-            model_response = gemini.completion(
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=gemini_api_key,
-                logging_obj=logging,
-                acompletion=acompletion,
-                custom_prompt_dict=custom_prompt_dict,
-            )
-            if (
-                "stream" in optional_params
-                and optional_params["stream"] == True
-                and acompletion == False
-            ):
-                response = CustomStreamWrapper(
-                    iter(model_response),
-                    model,
-                    custom_llm_provider="gemini",
-                    logging_obj=logging,
-                )
-                return response
-            response = model_response
-        elif custom_llm_provider == "vertex_ai_beta":
+        elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
             vertex_ai_project = (
                 optional_params.pop("vertex_project", None)
                 or optional_params.pop("vertex_ai_project", None)
@@ -1899,6 +1886,14 @@ def completion(
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret("VERTEXAI_CREDENTIALS")
             )
+            gemini_api_key = (
+                api_key
+                or get_secret("GEMINI_API_KEY")
+                or get_secret("PALM_API_KEY")  # older palm api key should also work
+                or litellm.api_key
+            )
             new_params = deepcopy(optional_params)
             response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
@@ -1912,9 +1907,11 @@ def completion(
                 vertex_location=vertex_ai_location,
                 vertex_project=vertex_ai_project,
                 vertex_credentials=vertex_credentials,
+                gemini_api_key=gemini_api_key,
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,
+                custom_llm_provider=custom_llm_provider,
             )
         elif custom_llm_provider == "vertex_ai":
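
A note on the credential resolution added above: inside the merged branch, the Google AI Studio key keeps its previous precedence order — an explicitly passed api_key wins, then the GEMINI_API_KEY secret, then the legacy PALM_API_KEY, then litellm.api_key. A standalone sketch of that precedence, with os.environ.get standing in for litellm's get_secret (which can also consult configured secret managers):

import os
from typing import Optional

module_api_key: Optional[str] = None  # stands in for litellm.api_key

def resolve_gemini_api_key(api_key: Optional[str] = None) -> Optional[str]:
    # Mirrors the fallback chain in the diff above.
    return (
        api_key
        or os.environ.get("GEMINI_API_KEY")
        or os.environ.get("PALM_API_KEY")  # older PaLM keys still accepted
        or module_api_key
    )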