forked from phoenix/litellm-mirror
fix case when gemini is used
This commit is contained in:
parent
1c6f8b1be2
commit
96fa9d46f5
11 changed files with 47 additions and 13 deletions
|
@ -38,6 +38,7 @@ class GCSBucketBase(CustomLogger):
|
||||||
await vertex_chat_completion._ensure_access_token_async(
|
await vertex_chat_completion._ensure_access_token_async(
|
||||||
credentials=self.path_service_account_json,
|
credentials=self.path_service_account_json,
|
||||||
project_id=None,
|
project_id=None,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -66,6 +67,7 @@ class GCSBucketBase(CustomLogger):
|
||||||
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
|
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
|
||||||
credentials=self.path_service_account_json,
|
credentials=self.path_service_account_json,
|
||||||
project_id=None,
|
project_id=None,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||||
|
|
|
@ -184,10 +184,10 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||||
)
|
)
|
||||||
|
|
||||||
_auth_header, vertex_project = self._ensure_access_token(
|
_auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, _ = self._get_token_and_url(
|
auth_header, _ = self._get_token_and_url(
|
||||||
|
@ -257,10 +257,10 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
vertex_credentials: str,
|
vertex_credentials: str,
|
||||||
request_route: str,
|
request_route: str,
|
||||||
):
|
):
|
||||||
|
|
||||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
)
|
)
|
||||||
auth_header, _ = self._get_token_and_url(
|
auth_header, _ = self._get_token_and_url(
|
||||||
model="",
|
model="",
|
||||||
|
|
|
@ -1079,7 +1079,9 @@ class VertexLLM(VertexBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, api_base = self._get_token_and_url(
|
auth_header, api_base = self._get_token_and_url(
|
||||||
|
@ -1157,7 +1159,9 @@ class VertexLLM(VertexBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
_auth_header, vertex_project = await self._ensure_access_token_async(
|
_auth_header, vertex_project = await self._ensure_access_token_async(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, api_base = self._get_token_and_url(
|
auth_header, api_base = self._get_token_and_url(
|
||||||
|
@ -1310,7 +1314,9 @@ class VertexLLM(VertexBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
_auth_header, vertex_project = self._ensure_access_token(
|
_auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, url = self._get_token_and_url(
|
auth_header, url = self._get_token_and_url(
|
||||||
|
|
|
@ -46,6 +46,7 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||||
_auth_header, vertex_project = self._ensure_access_token(
|
_auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, url = self._get_token_and_url(
|
auth_header, url = self._get_token_and_url(
|
||||||
|
|
|
@ -82,7 +82,9 @@ class VertexImageGeneration(VertexLLM):
|
||||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||||
|
|
||||||
auth_header, _ = self._ensure_access_token(
|
auth_header, _ = self._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
optional_params = optional_params or {
|
optional_params = optional_params or {
|
||||||
"sampleCount": 1
|
"sampleCount": 1
|
||||||
|
@ -180,7 +182,9 @@ class VertexImageGeneration(VertexLLM):
|
||||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||||
"""
|
"""
|
||||||
auth_header, _ = self._ensure_access_token(
|
auth_header, _ = self._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
optional_params = optional_params or {
|
optional_params = optional_params or {
|
||||||
"sampleCount": 1
|
"sampleCount": 1
|
||||||
|
|
|
@ -47,6 +47,7 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
_auth_header, vertex_project = self._ensure_access_token(
|
_auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, url = self._get_token_and_url(
|
auth_header, url = self._get_token_and_url(
|
||||||
|
|
|
@ -65,10 +65,10 @@ class VertexTextToSpeechAPI(VertexLLM):
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
####### Authenticate with Vertex AI ########
|
####### Authenticate with Vertex AI ########
|
||||||
|
|
||||||
_auth_header, vertex_project = self._ensure_access_token(
|
_auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_header, _ = self._get_token_and_url(
|
auth_header, _ = self._get_token_and_url(
|
||||||
|
|
|
@ -292,7 +292,9 @@ def completion(
|
||||||
vertex_httpx_logic = VertexLLM()
|
vertex_httpx_logic = VertexLLM()
|
||||||
|
|
||||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
|
|
||||||
anthropic_chat_completions = AnthropicChatCompletion()
|
anthropic_chat_completions = AnthropicChatCompletion()
|
||||||
|
|
|
@ -105,7 +105,9 @@ class VertexAIPartnerModels(BaseLLM):
|
||||||
vertex_httpx_logic = VertexLLM()
|
vertex_httpx_logic = VertexLLM()
|
||||||
|
|
||||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_like_chat_completions = DatabricksChatCompletion()
|
openai_like_chat_completions = DatabricksChatCompletion()
|
||||||
|
|
|
@ -109,11 +109,18 @@ class VertexBase(BaseLLM):
|
||||||
credentials.refresh(Request())
|
credentials.refresh(Request())
|
||||||
|
|
||||||
def _ensure_access_token(
|
def _ensure_access_token(
|
||||||
self, credentials: Optional[str], project_id: Optional[str]
|
self,
|
||||||
|
credentials: Optional[str],
|
||||||
|
project_id: Optional[str],
|
||||||
|
custom_llm_provider: Literal[
|
||||||
|
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||||
|
], # if it's vertex_ai or gemini (google ai studio)
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Returns auth token and project id
|
Returns auth token and project id
|
||||||
"""
|
"""
|
||||||
|
if custom_llm_provider == "gemini":
|
||||||
|
return "", ""
|
||||||
if self.access_token is not None:
|
if self.access_token is not None:
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
return self.access_token, project_id
|
return self.access_token, project_id
|
||||||
|
@ -222,11 +229,18 @@ class VertexBase(BaseLLM):
|
||||||
return auth_header, url
|
return auth_header, url
|
||||||
|
|
||||||
async def _ensure_access_token_async(
|
async def _ensure_access_token_async(
|
||||||
self, credentials: Optional[str], project_id: Optional[str]
|
self,
|
||||||
|
credentials: Optional[str],
|
||||||
|
project_id: Optional[str],
|
||||||
|
custom_llm_provider: Literal[
|
||||||
|
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||||
|
], # if it's vertex_ai or gemini (google ai studio)
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Async version of _ensure_access_token
|
Async version of _ensure_access_token
|
||||||
"""
|
"""
|
||||||
|
if custom_llm_provider == "gemini":
|
||||||
|
return "", ""
|
||||||
if self.access_token is not None:
|
if self.access_token is not None:
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
return self.access_token, project_id
|
return self.access_token, project_id
|
||||||
|
|
|
@ -152,7 +152,9 @@ async def vertex_proxy_route(
|
||||||
|
|
||||||
_auth_header, vertex_project = (
|
_auth_header, vertex_project = (
|
||||||
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue