fix case when gemini is used

This commit is contained in:
Ishaan Jaff 2024-09-10 17:06:45 -07:00
parent 1c6f8b1be2
commit 96fa9d46f5
11 changed files with 47 additions and 13 deletions

View file

@ -38,6 +38,7 @@ class GCSBucketBase(CustomLogger):
await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
)
@ -66,6 +67,7 @@ class GCSBucketBase(CustomLogger):
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
auth_header, _ = vertex_chat_completion._get_token_and_url(

View file

@ -184,10 +184,10 @@ class VertexFineTuningAPI(VertexLLM):
verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = self._get_token_and_url(
@ -257,10 +257,10 @@ class VertexFineTuningAPI(VertexLLM):
vertex_credentials: str,
request_route: str,
):
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = self._get_token_and_url(
model="",

View file

@ -1079,7 +1079,9 @@ class VertexLLM(VertexBase):
)
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
auth_header, api_base = self._get_token_and_url(
@ -1157,7 +1159,9 @@ class VertexLLM(VertexBase):
)
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
auth_header, api_base = self._get_token_and_url(
@ -1310,7 +1314,9 @@ class VertexLLM(VertexBase):
)
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
auth_header, url = self._get_token_and_url(

View file

@ -46,6 +46,7 @@ class GoogleBatchEmbeddings(VertexLLM):
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
auth_header, url = self._get_token_and_url(

View file

@ -82,7 +82,9 @@ class VertexImageGeneration(VertexLLM):
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
optional_params = optional_params or {
"sampleCount": 1
@ -180,7 +182,9 @@ class VertexImageGeneration(VertexLLM):
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
optional_params = optional_params or {
"sampleCount": 1

View file

@ -47,6 +47,7 @@ class VertexMultimodalEmbedding(VertexLLM):
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
auth_header, url = self._get_token_and_url(

View file

@ -65,10 +65,10 @@ class VertexTextToSpeechAPI(VertexLLM):
import base64
####### Authenticate with Vertex AI ########
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = self._get_token_and_url(

View file

@ -292,7 +292,9 @@ def completion(
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
anthropic_chat_completions = AnthropicChatCompletion()

View file

@ -105,7 +105,9 @@ class VertexAIPartnerModels(BaseLLM):
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
openai_like_chat_completions = DatabricksChatCompletion()

View file

@ -109,11 +109,18 @@ class VertexBase(BaseLLM):
credentials.refresh(Request())
def _ensure_access_token(
self, credentials: Optional[str], project_id: Optional[str]
self,
credentials: Optional[str],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]:
"""
Returns auth token and project id
"""
if custom_llm_provider == "gemini":
return "", ""
if self.access_token is not None:
if project_id is not None:
return self.access_token, project_id
@ -222,11 +229,18 @@ class VertexBase(BaseLLM):
return auth_header, url
async def _ensure_access_token_async(
self, credentials: Optional[str], project_id: Optional[str]
self,
credentials: Optional[str],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
) -> Tuple[str, str]:
"""
Async version of _ensure_access_token
"""
if custom_llm_provider == "gemini":
return "", ""
if self.access_token is not None:
if project_id is not None:
return self.access_token, project_id

View file

@ -152,7 +152,9 @@ async def vertex_proxy_route(
_auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
)