fix - vertex ai cache clients

Ishaan Jaff 2024-05-30 21:22:32 -07:00
parent c5c0d0c01d
commit d4d9b098b1
2 changed files with 56 additions and 20 deletions


@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests  # type: ignore
 import time
-from typing import Callable, Optional, Union, List, Literal
+from typing import Callable, Optional, Union, List, Literal, Any
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect  # type: ignore
@@ -527,6 +527,19 @@ def _gemini_vision_convert_messages(messages: list):
         raise e


+def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
+    _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+    return _cache_key
+
+
+def _get_client_from_cache(client_cache_key: str):
+    return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
+
+
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+    litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
+
+
 def completion(
     model: str,
     messages: list,
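
Taken together, these three helpers give the module a process-local client cache keyed on (model, project, location). A minimal sketch of the pattern, assuming the cache is a plain dict (the names below are hypothetical stand-ins, not litellm's real ones):

    # Minimal sketch of the caching pattern the helpers above implement.
    # client_cache is a hypothetical stand-in for litellm's module-level
    # in_memory_llm_clients_cache.
    client_cache: dict = {}

    def get_or_create_client(model: str, vertex_project: str, vertex_location: str):
        cache_key = f"{model}-{vertex_project}-{vertex_location}"
        client = client_cache.get(cache_key, None)
        if client is None:
            client = object()  # stand-in for an expensive client/model build
            client_cache[cache_key] = client
        return client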
@@ -580,23 +593,32 @@ def completion(
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
         )
-        if vertex_credentials is not None and isinstance(vertex_credentials, str):
-            import google.oauth2.service_account
-
-            json_obj = json.loads(vertex_credentials)
-
-            creds = google.oauth2.service_account.Credentials.from_service_account_info(
-                json_obj,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-        else:
-            creds, _ = google.auth.default(quota_project_id=vertex_project)
-        print_verbose(
-            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
-        )
-        vertexai.init(
-            project=vertex_project, location=vertex_location, credentials=creds
-        )
+        _cache_key = _get_client_cache_key(
+            model=model, vertex_project=vertex_project, vertex_location=vertex_location
+        )
+        _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
+
+        if _vertex_llm_model_object is None:
+            if vertex_credentials is not None and isinstance(vertex_credentials, str):
+                import google.oauth2.service_account
+
+                json_obj = json.loads(vertex_credentials)
+
+                creds = (
+                    google.oauth2.service_account.Credentials.from_service_account_info(
+                        json_obj,
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    )
+                )
+            else:
+                creds, _ = google.auth.default(quota_project_id=vertex_project)
+            print_verbose(
+                f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+            )
+            vertexai.init(
+                project=vertex_project, location=vertex_location, credentials=creds
+            )
+

         ## Load Config
         config = litellm.VertexAIConfig.get_config()
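
The new guard means credential resolution and vertexai.init run only on a cache miss; on a hit, the stored client is reused as-is. A rough sketch of the resulting control flow (a hypothetical driver function, not litellm's actual entry point):

    # Rough sketch of the miss/hit flow introduced above (hypothetical driver).
    def get_vertex_client(model, vertex_project, vertex_location):
        _cache_key = _get_client_cache_key(
            model=model, vertex_project=vertex_project, vertex_location=vertex_location
        )
        cached = _get_client_from_cache(client_cache_key=_cache_key)
        if cached is not None:
            return cached  # hit: skip credentials and vertexai.init entirely
        # miss: resolve creds, call vertexai.init(...), build the model object
        ...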
@@ -639,23 +661,27 @@ def completion(
             model in litellm.vertex_language_models
             or model in litellm.vertex_vision_models
         ):
-            llm_model = GenerativeModel(model)
+            llm_model = _vertex_llm_model_object or GenerativeModel(model)
             mode = "vision"
             request_str += f"llm_model = GenerativeModel({model})\n"
         elif model in litellm.vertex_chat_models:
-            llm_model = ChatModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
         elif model in litellm.vertex_text_models:
-            llm_model = TextGenerationModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+                model
+            )
             mode = "text"
             request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_text_models:
-            llm_model = CodeGenerationModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+                model
+            )
             mode = "text"
             request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
-            llm_model = CodeChatModel.from_pretrained(model)
+            llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         elif model == "private":
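
Each branch now reuses the cached object via Python's short-circuiting `or`: when `_vertex_llm_model_object` is truthy, the constructor on the right-hand side is never evaluated. A tiny illustration:

    # Short-circuit reuse: the right-hand constructor runs only on a miss.
    cached = None
    llm_model = cached or "fresh-client"    # miss -> "fresh-client" is built

    cached = "cached-client"
    llm_model = cached or "fresh-client"    # hit -> cached object, no rebuild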
@@ -1034,6 +1060,15 @@ async def async_completion(
                 tools=tools,
             )

+            _cache_key = _get_client_cache_key(
+                model=model,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+            )
+            _set_client_in_cache(
+                client_cache_key=_cache_key, vertex_llm_model=llm_model
+            )
+
             if tools is not None and bool(
                 getattr(response.candidates[0].content.parts[0], "function_call", None)
             ):
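
In this diff the cache is populated only here, on the async path after a successful response, so the first call per (model, project, location) pays the full setup cost and later calls reuse the stored client. A hedged illustration of that lifecycle (the model/project values below are made up):

    # Hypothetical illustration of the intended lifecycle, using made-up values.
    key = _get_client_cache_key(
        model="gemini-pro", vertex_project="my-proj", vertex_location="us-central1"
    )
    assert _get_client_from_cache(client_cache_key=key) is None  # nothing cached yet
    # ...first request builds the client, then _set_client_in_cache(...) stores it;
    # a second request with the same key gets the cached client back.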