forked from phoenix/litellm-mirror
fix case when gemini is used
This commit is contained in:
parent
1c6f8b1be2
commit
96fa9d46f5
11 changed files with 47 additions and 13 deletions
|
@ -38,6 +38,7 @@ class GCSBucketBase(CustomLogger):
|
|||
await vertex_chat_completion._ensure_access_token_async(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -66,6 +67,7 @@ class GCSBucketBase(CustomLogger):
|
|||
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||
|
|
|
@ -184,10 +184,10 @@ class VertexFineTuningAPI(VertexLLM):
|
|||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
|
||||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
|
@ -257,10 +257,10 @@ class VertexFineTuningAPI(VertexLLM):
|
|||
vertex_credentials: str,
|
||||
request_route: str,
|
||||
):
|
||||
|
||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
model="",
|
||||
|
|
|
@ -1079,7 +1079,9 @@ class VertexLLM(VertexBase):
|
|||
)
|
||||
|
||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
|
@ -1157,7 +1159,9 @@ class VertexLLM(VertexBase):
|
|||
)
|
||||
|
||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
|
@ -1310,7 +1314,9 @@ class VertexLLM(VertexBase):
|
|||
)
|
||||
|
||||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
|
|
|
@ -46,6 +46,7 @@ class GoogleBatchEmbeddings(VertexLLM):
|
|||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
|
|
|
@ -82,7 +82,9 @@ class VertexImageGeneration(VertexLLM):
|
|||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
|
@ -180,7 +182,9 @@ class VertexImageGeneration(VertexLLM):
|
|||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||
"""
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
|
|
|
@ -47,6 +47,7 @@ class VertexMultimodalEmbedding(VertexLLM):
|
|||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
|
|
|
@ -65,10 +65,10 @@ class VertexTextToSpeechAPI(VertexLLM):
|
|||
import base64
|
||||
|
||||
####### Authenticate with Vertex AI ########
|
||||
|
||||
_auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
|
|
|
@ -292,7 +292,9 @@ def completion(
|
|||
vertex_httpx_logic = VertexLLM()
|
||||
|
||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
|
|
@ -105,7 +105,9 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
vertex_httpx_logic = VertexLLM()
|
||||
|
||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
openai_like_chat_completions = DatabricksChatCompletion()
|
||||
|
|
|
@ -109,11 +109,18 @@ class VertexBase(BaseLLM):
|
|||
credentials.refresh(Request())
|
||||
|
||||
def _ensure_access_token(
|
||||
self, credentials: Optional[str], project_id: Optional[str]
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
project_id: Optional[str],
|
||||
custom_llm_provider: Literal[
|
||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||
], # if it's vertex_ai or gemini (google ai studio)
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns auth token and project id
|
||||
"""
|
||||
if custom_llm_provider == "gemini":
|
||||
return "", ""
|
||||
if self.access_token is not None:
|
||||
if project_id is not None:
|
||||
return self.access_token, project_id
|
||||
|
@ -222,11 +229,18 @@ class VertexBase(BaseLLM):
|
|||
return auth_header, url
|
||||
|
||||
async def _ensure_access_token_async(
|
||||
self, credentials: Optional[str], project_id: Optional[str]
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
project_id: Optional[str],
|
||||
custom_llm_provider: Literal[
|
||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||
], # if it's vertex_ai or gemini (google ai studio)
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Async version of _ensure_access_token
|
||||
"""
|
||||
if custom_llm_provider == "gemini":
|
||||
return "", ""
|
||||
if self.access_token is not None:
|
||||
if project_id is not None:
|
||||
return self.access_token, project_id
|
||||
|
|
|
@ -152,7 +152,9 @@ async def vertex_proxy_route(
|
|||
|
||||
_auth_header, vertex_project = (
|
||||
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue