Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix - vertex ai cache clients
commit d4d9b098b1 (parent c5c0d0c01d)
2 changed files with 56 additions and 20 deletions
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests  # type: ignore
 import time
-from typing import Callable, Optional, Union, List, Literal
+from typing import Callable, Optional, Union, List, Literal, Any
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect  # type: ignore
@@ -527,6 +527,19 @@ def _gemini_vision_convert_messages(messages: list):
         raise e


+def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
+    _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+    return _cache_key
+
+
+def _get_client_from_cache(client_cache_key: str):
+    return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
+
+
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+    litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
+
+
 def completion(
     model: str,
     messages: list,
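
The three new helpers above form a small in-process client cache keyed on (model, project, location). A minimal, hypothetical walk-through of how they are meant to fit together, assuming litellm.in_memory_llm_clients_cache is a plain module-level dict (as the .get() and item-assignment usage suggests); the model, project, and location values are placeholders, not from the diff:

    from vertexai.preview.generative_models import GenerativeModel

    # Hypothetical usage of the cache helpers added in this commit.
    _cache_key = _get_client_cache_key(
        model="gemini-pro", vertex_project="my-project", vertex_location="us-central1"
    )
    llm_model = _get_client_from_cache(client_cache_key=_cache_key)
    if llm_model is None:
        # Cache miss: construct the client once, then store it for later requests
        # that resolve to the same (model, project, location) key.
        llm_model = GenerativeModel("gemini-pro")
        _set_client_in_cache(client_cache_key=_cache_key, vertex_llm_model=llm_model)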
@@ -580,23 +593,32 @@ def completion(
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
         )
-        if vertex_credentials is not None and isinstance(vertex_credentials, str):
-            import google.oauth2.service_account
-
-            json_obj = json.loads(vertex_credentials)
-
-            creds = google.oauth2.service_account.Credentials.from_service_account_info(
-                json_obj,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-        else:
-            creds, _ = google.auth.default(quota_project_id=vertex_project)
-        print_verbose(
-            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
-        )
-        vertexai.init(
-            project=vertex_project, location=vertex_location, credentials=creds
-        )
+
+        _cache_key = _get_client_cache_key(
+            model=model, vertex_project=vertex_project, vertex_location=vertex_location
+        )
+        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
+
+        if _vertex_llm_model_object is None:
+            if vertex_credentials is not None and isinstance(vertex_credentials, str):
+                import google.oauth2.service_account
+
+                json_obj = json.loads(vertex_credentials)
+
+                creds = (
+                    google.oauth2.service_account.Credentials.from_service_account_info(
+                        json_obj,
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    )
+                )
+            else:
+                creds, _ = google.auth.default(quota_project_id=vertex_project)
+            print_verbose(
+                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+            )
+            vertexai.init(
+                project=vertex_project, location=vertex_location, credentials=creds
+            )

         ## Load Config
         config = litellm.VertexAIConfig.get_config()
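
With this change, credentials are loaded and vertexai.init() is called only on a cache miss; a cache hit skips both steps. The two credential paths the block chooses between can be sketched roughly as below. This is a simplified, hypothetical helper for illustration, not code from the diff; the only assumptions are the calls the diff itself makes (from_service_account_info and google.auth.default):

    import json

    import google.auth
    import google.oauth2.service_account

    def _load_vertex_creds(vertex_credentials, vertex_project):
        # Path 1: a service-account JSON blob passed in as a string.
        if vertex_credentials is not None and isinstance(vertex_credentials, str):
            json_obj = json.loads(vertex_credentials)
            return google.oauth2.service_account.Credentials.from_service_account_info(
                json_obj, scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        # Path 2: application default credentials from the environment.
        creds, _ = google.auth.default(quota_project_id=vertex_project)
        return creds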
@@ -639,23 +661,27 @@ def completion(
             model in litellm.vertex_language_models
             or model in litellm.vertex_vision_models
         ):
-            llm_model = GenerativeModel(model)
+            llm_model = _vertex_llm_model_object or GenerativeModel(model)
             mode = "vision"
             request_str += f"llm_model = GenerativeModel({model})\n"
         elif model in litellm.vertex_chat_models:
-            llm_model = ChatModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
         elif model in litellm.vertex_text_models:
-            llm_model = TextGenerationModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+                model
+            )
             mode = "text"
             request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_text_models:
-            llm_model = CodeGenerationModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+                model
+            )
             mode = "text"
             request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
-            llm_model = CodeChatModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         elif model == "private":
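
Each branch above now falls back to the cached client via Python's `or` operator: if `_vertex_llm_model_object` holds a previously cached client it is reused, otherwise a fresh client is constructed. A tiny, hypothetical illustration of the pattern, using GenerativeModel as the example client and a placeholder model name:

    from vertexai.preview.generative_models import GenerativeModel

    model = "gemini-pro"  # placeholder

    # First request: nothing cached yet, so the right-hand side runs.
    _vertex_llm_model_object = None
    llm_model = _vertex_llm_model_object or GenerativeModel(model)

    # A later request with a warm cache: the cached object is truthy,
    # so the constructor is never called again.
    _vertex_llm_model_object = llm_model
    llm_model = _vertex_llm_model_object or GenerativeModel(model)

One implicit assumption of this pattern is that a cached client is always truthy; a falsy cached value would be treated as a miss.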
@@ -1034,6 +1060,15 @@ async def async_completion(
                 tools=tools,
             )

+            _cache_key = _get_client_cache_key(
+                model=model,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+            )
+            _set_client_in_cache(
+                client_cache_key=_cache_key, vertex_llm_model=llm_model
+            )
+
             if tools is not None and bool(
                 getattr(response.candidates[0].content.parts[0], "function_call", None)
             ):
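
In async_completion, the cache is written right after a successful vision-mode call, using the same module-level store the sync path reads, so later requests for the same (model, project, location) can skip client construction. A small, hypothetical way to inspect that store (assuming, as above, that it is a plain dict; the key format comes from _get_client_cache_key and the example key is illustrative):

    import litellm

    # List whatever Vertex AI clients have been cached so far in this process.
    for key, client in litellm.in_memory_llm_clients_cache.items():
        print(key, type(client).__name__)
    # e.g. "gemini-pro-my-project-us-central1 GenerativeModel"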