forked from phoenix/litellm-mirror
fix(main.py): accept vertex service account credentials as json string
allows us to dynamically set vertex ai credentials
This commit is contained in:
parent
c769c3f39b
commit
50081479f9
3 changed files with 71 additions and 15 deletions
|
@ -140,6 +140,7 @@ def completion(
|
|||
logging_obj,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -217,11 +218,24 @@ def completion(
|
|||
## Completion Call
|
||||
|
||||
print_verbose(
|
||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
||||
)
|
||||
if client is None:
|
||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||
import google.oauth2.service_account
|
||||
|
||||
json_obj = json.loads(vertex_credentials)
|
||||
|
||||
creds = (
|
||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||
json_obj
|
||||
)
|
||||
)
|
||||
|
||||
vertexai.init(credentials=creds)
|
||||
vertex_ai_client = AnthropicVertex(
|
||||
project_id=vertex_project, region=vertex_location
|
||||
project_id=vertex_project,
|
||||
region=vertex_location,
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
|
|
@ -1678,6 +1678,11 @@ def completion(
|
|||
)
|
||||
|
||||
if "claude-3" in model:
|
||||
vertex_credentials = (
|
||||
optional_params.pop("vertex_credentials", None)
|
||||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
model_response = vertex_ai_anthropic.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1689,6 +1694,7 @@ def completion(
|
|||
encoding=encoding,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_credentials=vertex_credentials,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
|
|
|
@ -23,6 +23,40 @@ user_message = "Write a short poem about the sky"
|
|||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
|
||||
def get_vertex_ai_creds_json() -> dict:
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
vertex_key_path = filepath + "/vertex_key.json"
|
||||
|
||||
# Read the existing content of the file or create an empty dictionary
|
||||
try:
|
||||
with open(vertex_key_path, "r") as file:
|
||||
# Read the file content
|
||||
print("Read vertexai file path")
|
||||
content = file.read()
|
||||
|
||||
# If the file is empty or not valid JSON, create an empty dictionary
|
||||
if not content or not content.strip():
|
||||
service_account_key_data = {}
|
||||
else:
|
||||
# Attempt to load the existing JSON content
|
||||
file.seek(0)
|
||||
service_account_key_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
# If the file doesn't exist, create an empty dictionary
|
||||
service_account_key_data = {}
|
||||
|
||||
# Update the service_account_key_data with environment variables
|
||||
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
|
||||
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
|
||||
private_key = private_key.replace("\\n", "\n")
|
||||
service_account_key_data["private_key_id"] = private_key_id
|
||||
service_account_key_data["private_key"] = private_key
|
||||
|
||||
return service_account_key_data
|
||||
|
||||
|
||||
def load_vertex_ai_credentials():
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
|
@ -85,9 +119,9 @@ async def get_response():
|
|||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
)
|
||||
# @pytest.mark.skip(
|
||||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
# )
|
||||
def test_vertex_ai_anthropic():
|
||||
load_vertex_ai_credentials()
|
||||
|
||||
|
@ -95,6 +129,8 @@ def test_vertex_ai_anthropic():
|
|||
|
||||
vertex_ai_project = "adroit-crow-413218"
|
||||
vertex_ai_location = "asia-southeast1"
|
||||
json_obj = get_vertex_ai_creds_json()
|
||||
vertex_credentials = json.dumps(json_obj)
|
||||
|
||||
response = completion(
|
||||
model="vertex_ai/" + model,
|
||||
|
@ -102,13 +138,14 @@ def test_vertex_ai_anthropic():
|
|||
temperature=0.7,
|
||||
vertex_ai_project=vertex_ai_project,
|
||||
vertex_ai_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
print("\nModel Response", response)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
)
|
||||
# @pytest.mark.skip(
|
||||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
# )
|
||||
def test_vertex_ai_anthropic_streaming():
|
||||
load_vertex_ai_credentials()
|
||||
|
||||
|
@ -137,9 +174,9 @@ def test_vertex_ai_anthropic_streaming():
|
|||
# test_vertex_ai_anthropic_streaming()
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
)
|
||||
# @pytest.mark.skip(
|
||||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
# )
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_ai_anthropic_async():
|
||||
load_vertex_ai_credentials()
|
||||
|
@ -162,9 +199,9 @@ async def test_vertex_ai_anthropic_async():
|
|||
# asyncio.run(test_vertex_ai_anthropic_async())
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
)
|
||||
# @pytest.mark.skip(
|
||||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
# )
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_ai_anthropic_async_streaming():
|
||||
load_vertex_ai_credentials()
|
||||
|
@ -180,7 +217,6 @@ async def test_vertex_ai_anthropic_async_streaming():
|
|||
temperature=0.7,
|
||||
vertex_ai_project=vertex_ai_project,
|
||||
vertex_ai_location=vertex_ai_location,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in response:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue