forked from phoenix/litellm-mirror
fix(main.py): accept vertex service account credentials as json string
allows us to dynamically set vertex ai credentials
This commit is contained in:
parent
c769c3f39b
commit
50081479f9
3 changed files with 71 additions and 15 deletions
|
@ -140,6 +140,7 @@ def completion(
|
||||||
logging_obj,
|
logging_obj,
|
||||||
vertex_project=None,
|
vertex_project=None,
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
|
vertex_credentials=None,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -217,11 +218,24 @@ def completion(
|
||||||
## Completion Call
|
## Completion Call
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
||||||
)
|
)
|
||||||
if client is None:
|
if client is None:
|
||||||
|
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||||
|
import google.oauth2.service_account
|
||||||
|
|
||||||
|
json_obj = json.loads(vertex_credentials)
|
||||||
|
|
||||||
|
creds = (
|
||||||
|
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||||
|
json_obj
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
vertexai.init(credentials=creds)
|
||||||
vertex_ai_client = AnthropicVertex(
|
vertex_ai_client = AnthropicVertex(
|
||||||
project_id=vertex_project, region=vertex_location
|
project_id=vertex_project,
|
||||||
|
region=vertex_location,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vertex_ai_client = client
|
vertex_ai_client = client
|
||||||
|
|
|
@ -1678,6 +1678,11 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
if "claude-3" in model:
|
if "claude-3" in model:
|
||||||
|
vertex_credentials = (
|
||||||
|
optional_params.pop("vertex_credentials", None)
|
||||||
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
|
)
|
||||||
model_response = vertex_ai_anthropic.completion(
|
model_response = vertex_ai_anthropic.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1689,6 +1694,7 @@ def completion(
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
vertex_location=vertex_ai_location,
|
vertex_location=vertex_ai_location,
|
||||||
vertex_project=vertex_ai_project,
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,6 +23,40 @@ user_message = "Write a short poem about the sky"
|
||||||
messages = [{"content": user_message, "role": "user"}]
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
|
||||||
|
|
||||||
|
def get_vertex_ai_creds_json() -> dict:
|
||||||
|
# Define the path to the vertex_key.json file
|
||||||
|
print("loading vertex ai credentials")
|
||||||
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
vertex_key_path = filepath + "/vertex_key.json"
|
||||||
|
|
||||||
|
# Read the existing content of the file or create an empty dictionary
|
||||||
|
try:
|
||||||
|
with open(vertex_key_path, "r") as file:
|
||||||
|
# Read the file content
|
||||||
|
print("Read vertexai file path")
|
||||||
|
content = file.read()
|
||||||
|
|
||||||
|
# If the file is empty or not valid JSON, create an empty dictionary
|
||||||
|
if not content or not content.strip():
|
||||||
|
service_account_key_data = {}
|
||||||
|
else:
|
||||||
|
# Attempt to load the existing JSON content
|
||||||
|
file.seek(0)
|
||||||
|
service_account_key_data = json.load(file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
# If the file doesn't exist, create an empty dictionary
|
||||||
|
service_account_key_data = {}
|
||||||
|
|
||||||
|
# Update the service_account_key_data with environment variables
|
||||||
|
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
|
||||||
|
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
|
||||||
|
private_key = private_key.replace("\\n", "\n")
|
||||||
|
service_account_key_data["private_key_id"] = private_key_id
|
||||||
|
service_account_key_data["private_key"] = private_key
|
||||||
|
|
||||||
|
return service_account_key_data
|
||||||
|
|
||||||
|
|
||||||
def load_vertex_ai_credentials():
|
def load_vertex_ai_credentials():
|
||||||
# Define the path to the vertex_key.json file
|
# Define the path to the vertex_key.json file
|
||||||
print("loading vertex ai credentials")
|
print("loading vertex ai credentials")
|
||||||
|
@ -85,9 +119,9 @@ async def get_response():
|
||||||
pytest.fail(f"An error occurred - {str(e)}")
|
pytest.fail(f"An error occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
# @pytest.mark.skip(
|
||||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||||
)
|
# )
|
||||||
def test_vertex_ai_anthropic():
|
def test_vertex_ai_anthropic():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
|
||||||
|
@ -95,6 +129,8 @@ def test_vertex_ai_anthropic():
|
||||||
|
|
||||||
vertex_ai_project = "adroit-crow-413218"
|
vertex_ai_project = "adroit-crow-413218"
|
||||||
vertex_ai_location = "asia-southeast1"
|
vertex_ai_location = "asia-southeast1"
|
||||||
|
json_obj = get_vertex_ai_creds_json()
|
||||||
|
vertex_credentials = json.dumps(json_obj)
|
||||||
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model="vertex_ai/" + model,
|
model="vertex_ai/" + model,
|
||||||
|
@ -102,13 +138,14 @@ def test_vertex_ai_anthropic():
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
vertex_ai_project=vertex_ai_project,
|
vertex_ai_project=vertex_ai_project,
|
||||||
vertex_ai_location=vertex_ai_location,
|
vertex_ai_location=vertex_ai_location,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
)
|
)
|
||||||
print("\nModel Response", response)
|
print("\nModel Response", response)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
# @pytest.mark.skip(
|
||||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||||
)
|
# )
|
||||||
def test_vertex_ai_anthropic_streaming():
|
def test_vertex_ai_anthropic_streaming():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
|
||||||
|
@ -137,9 +174,9 @@ def test_vertex_ai_anthropic_streaming():
|
||||||
# test_vertex_ai_anthropic_streaming()
|
# test_vertex_ai_anthropic_streaming()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
# @pytest.mark.skip(
|
||||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||||
)
|
# )
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vertex_ai_anthropic_async():
|
async def test_vertex_ai_anthropic_async():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
@ -162,9 +199,9 @@ async def test_vertex_ai_anthropic_async():
|
||||||
# asyncio.run(test_vertex_ai_anthropic_async())
|
# asyncio.run(test_vertex_ai_anthropic_async())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
# @pytest.mark.skip(
|
||||||
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||||
)
|
# )
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vertex_ai_anthropic_async_streaming():
|
async def test_vertex_ai_anthropic_async_streaming():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
@ -180,7 +217,6 @@ async def test_vertex_ai_anthropic_async_streaming():
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
vertex_ai_project=vertex_ai_project,
|
vertex_ai_project=vertex_ai_project,
|
||||||
vertex_ai_location=vertex_ai_location,
|
vertex_ai_location=vertex_ai_location,
|
||||||
stream=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue