fix(main.py): support env var 'VERTEX_PROJECT' and 'VERTEX_LOCATION'

This commit is contained in:
Krrish Dholakia 2024-05-10 07:57:41 -07:00
parent a671046b45
commit 9a31f3d3d9
3 changed files with 47 additions and 3 deletions

View file

@ -2960,17 +2960,20 @@ def embedding(
or optional_params.pop("vertex_ai_project", None) or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT") or get_secret("VERTEXAI_PROJECT")
or get_secret("VERTEX_PROJECT")
) )
vertex_ai_location = ( vertex_ai_location = (
optional_params.pop("vertex_location", None) optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None) or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
) )
vertex_credentials = ( vertex_credentials = (
optional_params.pop("vertex_credentials", None) optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None) or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS") or get_secret("VERTEXAI_CREDENTIALS")
or get_secret("VERTEX_CREDENTIALS")
) )
response = vertex_ai.embedding( response = vertex_ai.embedding(

View file

@ -113,6 +113,49 @@ async def get_response():
], ],
) )
return response return response
except litellm.UnprocessableEntityError as e:
pass
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")
@pytest.mark.asyncio
async def test_get_router_response():
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
json_obj = get_vertex_ai_creds_json()
vertex_credentials = json.dumps(json_obj)
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
try:
router = litellm.Router(
model_list=[
{
"model_name": "sonnet",
"litellm_params": {
"model": "vertex_ai/claude-3-sonnet@20240229",
"vertex_ai_project": vertex_ai_project,
"vertex_ai_location": vertex_ai_location,
"vertex_credentials": vertex_credentials,
},
}
]
)
response = await router.acompletion(
model="sonnet",
messages=[
{
"role": "system",
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
},
{"role": "user", "content": prompt},
],
)
print(f"\n\nResponse: {response}\n\n")
except litellm.UnprocessableEntityError as e: except litellm.UnprocessableEntityError as e:
pass pass
except Exception as e: except Exception as e:

View file

@ -5769,9 +5769,7 @@ def get_optional_params(
extra_body # openai client supports `extra_body` param extra_body # openai client supports `extra_body` param
) )
else: # assume passing in params for openai/azure openai else: # assume passing in params for openai/azure openai
print_verbose(
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
)
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider="openai" model=model, custom_llm_provider="openai"
) )