mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(utils.py): unify common auth params across azure/vertex_ai/bedrock/watsonx
This commit is contained in:
parent
c9d7437d16
commit
48f19cf839
8 changed files with 194 additions and 20 deletions
|
@ -2654,6 +2654,7 @@ def test_completion_palm_stream():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_watsonx():
|
||||
litellm.set_verbose = True
|
||||
model_name = "watsonx/ibm/granite-13b-chat-v2"
|
||||
|
@ -2671,10 +2672,57 @@ def test_completion_watsonx():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider, model, project, region_name, token",
|
||||
[
|
||||
("azure", "chatgpt-v-2", None, None, "test-token"),
|
||||
("vertex_ai", "anthropic-claude-3", "adroit-crow-1", "us-east1", None),
|
||||
("watsonx", "ibm/granite", "96946574", "dallas", "1234"),
|
||||
("bedrock", "anthropic.claude-3", None, "us-east-1", None),
|
||||
],
|
||||
)
|
||||
def test_unified_auth_params(provider, model, project, region_name, token):
|
||||
"""
|
||||
Check if params = ["project", "region_name", "token"]
|
||||
are correctly translated for = ["azure", "vertex_ai", "watsonx", "aws"]
|
||||
|
||||
tests get_optional_params
|
||||
"""
|
||||
data = {
|
||||
"project": project,
|
||||
"region_name": region_name,
|
||||
"token": token,
|
||||
"custom_llm_provider": provider,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
translated_optional_params = litellm.utils.get_optional_params(**data)
|
||||
|
||||
if provider == "azure":
|
||||
special_auth_params = (
|
||||
litellm.AzureOpenAIConfig().get_mapped_special_auth_params()
|
||||
)
|
||||
elif provider == "bedrock":
|
||||
special_auth_params = (
|
||||
litellm.AmazonBedrockGlobalConfig().get_mapped_special_auth_params()
|
||||
)
|
||||
elif provider == "vertex_ai":
|
||||
special_auth_params = litellm.VertexAIConfig().get_mapped_special_auth_params()
|
||||
elif provider == "watsonx":
|
||||
special_auth_params = (
|
||||
litellm.IBMWatsonXAIConfig().get_mapped_special_auth_params()
|
||||
)
|
||||
|
||||
for param, value in special_auth_params.items():
|
||||
assert param in data
|
||||
assert value in translated_optional_params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_watsonx():
|
||||
litellm.set_verbose = True
|
||||
model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID")
|
||||
model_name = "watsonx/deployment/" + os.getenv("WATSONX_DEPLOYMENT_ID")
|
||||
print("testing watsonx")
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue