From 9a31f3d3d93b21eef8bcae45c0744a47c4c38ed9 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 10 May 2024 07:57:41 -0700
Subject: [PATCH] fix(main.py): support env var 'VERTEX_PROJECT' and 'VERTEX_LOCATION'

---
 litellm/main.py                               |  3 ++
 .../tests/test_amazing_vertex_completion.py   | 43 +++++++++++++++++++
 litellm/utils.py                              |  4 +-
 3 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index aa078d322..6fd4cdaab 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2960,17 +2960,20 @@ def embedding(
                 or optional_params.pop("vertex_ai_project", None)
                 or litellm.vertex_project
                 or get_secret("VERTEXAI_PROJECT")
+                or get_secret("VERTEX_PROJECT")
             )
             vertex_ai_location = (
                 optional_params.pop("vertex_location", None)
                 or optional_params.pop("vertex_ai_location", None)
                 or litellm.vertex_location
                 or get_secret("VERTEXAI_LOCATION")
+                or get_secret("VERTEX_LOCATION")
             )
             vertex_credentials = (
                 optional_params.pop("vertex_credentials", None)
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret("VERTEXAI_CREDENTIALS")
+                or get_secret("VERTEX_CREDENTIALS")
             )
 
             response = vertex_ai.embedding(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 1d79653ea..91fd44474 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -113,6 +113,49 @@ async def get_response():
             ],
         )
         return response
+
+    except litellm.UnprocessableEntityError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An error occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_get_router_response():
+    model = "claude-3-sonnet@20240229"
+    vertex_ai_project = "adroit-crow-413218"
+    vertex_ai_location = "asia-southeast1"
+    json_obj = get_vertex_ai_creds_json()
+    vertex_credentials = json.dumps(json_obj)
+
+    prompt = '\ndef count_nums(arr):\n    """\n    Write a function count_nums which takes an array of integers and returns\n    the number of elements which has a sum of digits > 0.\n    If a number is negative, then its first signed digit will be negative:\n    e.g. -123 has signed digits -1, 2, and 3.\n    >>> count_nums([]) == 0\n    >>> count_nums([-1, 11, -11]) == 1\n    >>> count_nums([1, 1, 2]) == 3\n    """\n'
+    try:
+        router = litellm.Router(
+            model_list=[
+                {
+                    "model_name": "sonnet",
+                    "litellm_params": {
+                        "model": "vertex_ai/claude-3-sonnet@20240229",
+                        "vertex_ai_project": vertex_ai_project,
+                        "vertex_ai_location": vertex_ai_location,
+                        "vertex_credentials": vertex_credentials,
+                    },
+                }
+            ]
+        )
+        response = await router.acompletion(
+            model="sonnet",
+            messages=[
+                {
+                    "role": "system",
+                    "content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
+                },
+                {"role": "user", "content": prompt},
+            ],
+        )
+
+        print(f"\n\nResponse: {response}\n\n")
+
     except litellm.UnprocessableEntityError as e:
         pass
     except Exception as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 206001dbb..838d0fe55 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5769,9 +5769,7 @@ def get_optional_params(
                 extra_body  # openai client supports `extra_body` param
             )
     else:  # assume passing in params for openai/azure openai
-        print_verbose(
-            f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
-        )
+
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider="openai"
        )
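
Usage note (not part of the patch): a minimal sketch of how the new environment variable fallbacks would be exercised. The project id, region, credentials path, and embedding model name below are placeholder assumptions; VERTEX_PROJECT, VERTEX_LOCATION, and VERTEX_CREDENTIALS are only consulted when the explicit parameters, the litellm.vertex_* settings, and the VERTEXAI_* variables are unset.

import os

import litellm

# Placeholder values (assumptions for illustration only); with this change
# they are picked up as fallbacks when the VERTEXAI_* variants are not set.
os.environ["VERTEX_PROJECT"] = "my-gcp-project"
os.environ["VERTEX_LOCATION"] = "us-central1"
os.environ["VERTEX_CREDENTIALS"] = "/path/to/service_account.json"

# embedding() resolves project/location/credentials from the env vars above
# when they are not passed explicitly.
response = litellm.embedding(
    model="vertex_ai/textembedding-gecko",
    input=["hello from vertex ai"],
)
print(response)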