Fix the authentication case when Gemini (Google AI Studio) is used

This commit is contained in:
Ishaan Jaff 2024-09-10 17:06:45 -07:00
parent 1c6f8b1be2
commit 96fa9d46f5
11 changed files with 47 additions and 13 deletions

View file

@@ -38,6 +38,7 @@ class GCSBucketBase(CustomLogger):
await vertex_chat_completion._ensure_access_token_async( await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json, credentials=self.path_service_account_json,
project_id=None, project_id=None,
custom_llm_provider="vertex_ai",
) )
) )
@@ -66,6 +67,7 @@ class GCSBucketBase(CustomLogger):
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token( _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
credentials=self.path_service_account_json, credentials=self.path_service_account_json,
project_id=None, project_id=None,
custom_llm_provider="vertex_ai",
) )
auth_header, _ = vertex_chat_completion._get_token_and_url( auth_header, _ = vertex_chat_completion._get_token_and_url(

View file

@@ -184,10 +184,10 @@ class VertexFineTuningAPI(VertexLLM):
verbose_logger.debug( verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data "creating fine tuning job, args= %s", create_fine_tuning_job_data
) )
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
) )
auth_header, _ = self._get_token_and_url( auth_header, _ = self._get_token_and_url(
@@ -257,10 +257,10 @@ class VertexFineTuningAPI(VertexLLM):
vertex_credentials: str, vertex_credentials: str,
request_route: str, request_route: str,
): ):
_auth_header, vertex_project = await self._ensure_access_token_async( _auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, credentials=vertex_credentials,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
) )
auth_header, _ = self._get_token_and_url( auth_header, _ = self._get_token_and_url(
model="", model="",

View file

@@ -1079,7 +1079,9 @@ class VertexLLM(VertexBase):
) )
_auth_header, vertex_project = await self._ensure_access_token_async( _auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
) )
auth_header, api_base = self._get_token_and_url( auth_header, api_base = self._get_token_and_url(
@@ -1157,7 +1159,9 @@ class VertexLLM(VertexBase):
) )
_auth_header, vertex_project = await self._ensure_access_token_async( _auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
) )
auth_header, api_base = self._get_token_and_url( auth_header, api_base = self._get_token_and_url(
@@ -1310,7 +1314,9 @@ class VertexLLM(VertexBase):
) )
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
) )
auth_header, url = self._get_token_and_url( auth_header, url = self._get_token_and_url(

View file

@@ -46,6 +46,7 @@ class GoogleBatchEmbeddings(VertexLLM):
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
) )
auth_header, url = self._get_token_and_url( auth_header, url = self._get_token_and_url(

View file

@@ -82,7 +82,9 @@ class VertexImageGeneration(VertexLLM):
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token( auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
) )
optional_params = optional_params or { optional_params = optional_params or {
"sampleCount": 1 "sampleCount": 1
@@ -180,7 +182,9 @@ class VertexImageGeneration(VertexLLM):
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
""" """
auth_header, _ = self._ensure_access_token( auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
) )
optional_params = optional_params or { optional_params = optional_params or {
"sampleCount": 1 "sampleCount": 1

View file

@@ -47,6 +47,7 @@ class VertexMultimodalEmbedding(VertexLLM):
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
) )
auth_header, url = self._get_token_and_url( auth_header, url = self._get_token_and_url(

View file

@@ -65,10 +65,10 @@ class VertexTextToSpeechAPI(VertexLLM):
import base64 import base64
####### Authenticate with Vertex AI ######## ####### Authenticate with Vertex AI ########
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
) )
auth_header, _ = self._get_token_and_url( auth_header, _ = self._get_token_and_url(

View file

@@ -292,7 +292,9 @@ def completion(
vertex_httpx_logic = VertexLLM() vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token( access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
) )
anthropic_chat_completions = AnthropicChatCompletion() anthropic_chat_completions = AnthropicChatCompletion()

View file

@@ -105,7 +105,9 @@ class VertexAIPartnerModels(BaseLLM):
vertex_httpx_logic = VertexLLM() vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token( access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
) )
openai_like_chat_completions = DatabricksChatCompletion() openai_like_chat_completions = DatabricksChatCompletion()

View file

@@ -109,11 +109,18 @@ class VertexBase(BaseLLM):
credentials.refresh(Request()) credentials.refresh(Request())
def _ensure_access_token( def _ensure_access_token(
self, credentials: Optional[str], project_id: Optional[str] self,
credentials: Optional[str],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]: ) -> Tuple[str, str]:
""" """
Returns auth token and project id Returns auth token and project id
""" """
if custom_llm_provider == "gemini":
return "", ""
if self.access_token is not None: if self.access_token is not None:
if project_id is not None: if project_id is not None:
return self.access_token, project_id return self.access_token, project_id
@@ -222,11 +229,18 @@ class VertexBase(BaseLLM):
return auth_header, url return auth_header, url
async def _ensure_access_token_async( async def _ensure_access_token_async(
self, credentials: Optional[str], project_id: Optional[str] self,
credentials: Optional[str],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]: ) -> Tuple[str, str]:
""" """
Async version of _ensure_access_token Async version of _ensure_access_token
""" """
if custom_llm_provider == "gemini":
return "", ""
if self.access_token is not None: if self.access_token is not None:
if project_id is not None: if project_id is not None:
return self.access_token, project_id return self.access_token, project_id

View file

@@ -152,7 +152,9 @@ async def vertex_proxy_route(
_auth_header, vertex_project = ( _auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async( await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
) )
) )