feat(utils.py): unify common auth params across azure/vertex_ai/bedrock/watsonx

Krrish Dholakia 2024-04-27 11:06:18 -07:00
parent c9d7437d16
commit 48f19cf839
8 changed files with 194 additions and 20 deletions


@@ -5,6 +5,9 @@ LiteLLM allows you to specify the following:
 * API Base
 * API Version
 * API Type
+* Project
+* Location
+* Token
 
 Useful Helper functions:
 * [`check_valid_key()`](#check_valid_key)
@@ -43,6 +46,24 @@ os.environ['AZURE_API_TYPE'] = "azure" # [OPTIONAL]
 os.environ['OPENAI_API_BASE'] = "https://openai-gpt-4-test2-v-12.openai.azure.com/"
 ```
 
+### Setting Project, Location, Token
+
+For cloud providers:
+- Azure
+- Bedrock
+- GCP
+- Watsonx
+
+you might need to set additional parameters. LiteLLM provides a common set of params that we map across all providers.
+
+|         | LiteLLM param | Watsonx                | Vertex AI       | Azure          | Bedrock           |
+|---------|---------------|------------------------|-----------------|----------------|-------------------|
+| Project | `project`     | `watsonx_project`      | `vertex_project` | n/a           | n/a               |
+| Region  | `region_name` | `watsonx_region_name`  | `vertex_location` | n/a          | `aws_region_name` |
+| Token   | `token`       | `watsonx_token` or `token` | n/a         | `azure_ad_token` | n/a             |
+
+If you want, you can call them by their provider-specific params as well.
+
 ## litellm variables
 ### litellm.api_key
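For illustration, the docs section added above implies usage like the following sketch (the model name is taken from the tests later in this diff; the project/region/token values are placeholders):

```python
import litellm

# Unified params: LiteLLM translates project/region_name/token into each
# provider's native equivalents (see the mapping table above).
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
    project="my-project-id",   # -> watsonx_project
    region_name="us-south",    # -> watsonx_region_name
    token="my-iam-token",      # -> watsonx_token
)
```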


@@ -58,6 +58,7 @@ max_tokens = 256 # OpenAI Defaults
 drop_params = False
 modify_params = False
 retry = True
+### AUTH ###
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None
@@ -76,6 +77,10 @@ cloudflare_api_key: Optional[str] = None
 baseten_key: Optional[str] = None
 aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
+common_cloud_provider_auth_params: dict = {
+    "params": ["project", "region_name", "token"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+}
 use_client: bool = False
 ssl_verify: bool = True
 disable_streaming_logging: bool = False
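The new dict acts as a small registry: only providers listed under `"providers"` get the unified auth params translated. A minimal sketch of consulting it (the actual gating happens in `get_optional_params`, shown in a later hunk):

```python
import litellm

registry = litellm.common_cloud_provider_auth_params
if "bedrock" in registry["providers"]:
    # These are the common params that get mapped for this provider.
    print(registry["params"])  # ['project', 'region_name', 'token']
```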
@@ -654,6 +659,7 @@ from .llms.bedrock import (
     AmazonLlamaConfig,
     AmazonStabilityConfig,
     AmazonMistralConfig,
+    AmazonBedrockGlobalConfig,
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError


@@ -96,6 +96,15 @@ class AzureOpenAIConfig(OpenAIConfig):
             top_p,
         )
 
+    def get_mapped_special_auth_params(self) -> dict:
+        return {"token": "azure_ad_token"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "token":
+                optional_params["azure_ad_token"] = value
+        return optional_params
+
 
 def select_azure_base_url_or_endpoint(azure_client_params: dict):
     # azure_client_params = {
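Based on the method just added, a quick sketch of the Azure mapping (the token value is a placeholder): the generic `token` param becomes `azure_ad_token`.

```python
import litellm

optional_params = litellm.AzureOpenAIConfig().map_special_auth_params(
    non_default_params={"token": "<azure-ad-token>"},  # placeholder value
    optional_params={},
)
assert optional_params == {"azure_ad_token": "<azure-ad-token>"}
```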


@@ -29,6 +29,24 @@ class BedrockError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+class AmazonBedrockGlobalConfig:
+    def __init__(self):
+        pass
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Mapping of common auth params across bedrock/vertex/azure/watsonx
+        """
+        return {"region_name": "aws_region_name"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 class AmazonTitanConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1


@@ -184,6 +184,20 @@ class VertexAIConfig:
                 pass
         return optional_params
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 import asyncio
@@ -529,7 +543,7 @@ def completion(
         "instances": instances,
         "vertex_location": vertex_location,
         "vertex_project": vertex_project,
-        "safety_settings":safety_settings,
+        "safety_settings": safety_settings,
         **optional_params,
     }
     if optional_params.get("stream", False) is True:
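A sketch of the Vertex AI mapping added above, using the values from the test later in this diff. Note that `token` has no Vertex AI equivalent (n/a in the docs table), so `map_special_auth_params` leaves it unmapped:

```python
import litellm

optional_params = litellm.VertexAIConfig().map_special_auth_params(
    non_default_params={"project": "adroit-crow-1", "region_name": "us-east1"},
    optional_params={},
)
# -> {'vertex_project': 'adroit-crow-1', 'vertex_location': 'us-east1'}
```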


@@ -131,6 +131,24 @@ class IBMWatsonXAIConfig:
             "stream",  # equivalent to stream
         ]
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {
+            "project": "watsonx_project",
+            "region_name": "watsonx_region_name",
+            "token": "watsonx_token",
+        }
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
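watsonx is the one provider that maps all three unified params; a sketch, with values borrowed from the test later in this diff:

```python
import litellm

optional_params = litellm.IBMWatsonXAIConfig().map_special_auth_params(
    non_default_params={
        "project": "96946574",   # illustrative values from the test
        "region_name": "dallas",
        "token": "1234",
    },
    optional_params={},
)
# -> {'watsonx_project': '96946574', 'watsonx_region_name': 'dallas',
#     'watsonx_token': '1234'}
```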


@@ -2654,6 +2654,7 @@ def test_completion_palm_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -2671,10 +2672,57 @@ def test_completion_watsonx():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+@pytest.mark.parametrize(
+    "provider, model, project, region_name, token",
+    [
+        ("azure", "chatgpt-v-2", None, None, "test-token"),
+        ("vertex_ai", "anthropic-claude-3", "adroit-crow-1", "us-east1", None),
+        ("watsonx", "ibm/granite", "96946574", "dallas", "1234"),
+        ("bedrock", "anthropic.claude-3", None, "us-east-1", None),
+    ],
+)
+def test_unified_auth_params(provider, model, project, region_name, token):
+    """
+    Check if params = ["project", "region_name", "token"]
+    are correctly translated for providers = ["azure", "vertex_ai", "watsonx", "bedrock"]
+
+    Tests get_optional_params
+    """
+    data = {
+        "project": project,
+        "region_name": region_name,
+        "token": token,
+        "custom_llm_provider": provider,
+        "model": model,
+    }
+
+    translated_optional_params = litellm.utils.get_optional_params(**data)
+
+    if provider == "azure":
+        special_auth_params = (
+            litellm.AzureOpenAIConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "bedrock":
+        special_auth_params = (
+            litellm.AmazonBedrockGlobalConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "vertex_ai":
+        special_auth_params = litellm.VertexAIConfig().get_mapped_special_auth_params()
+    elif provider == "watsonx":
+        special_auth_params = (
+            litellm.IBMWatsonXAIConfig().get_mapped_special_auth_params()
+        )
+
+    for param, value in special_auth_params.items():
+        assert param in data
+        assert value in translated_optional_params
+
+
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
-    model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID")
+    model_name = "watsonx/deployment/" + os.getenv("WATSONX_DEPLOYMENT_ID")
     print("testing watsonx")
     try:
         response = await litellm.acompletion(


@@ -4619,7 +4619,36 @@ def get_optional_params(
             k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
         ):  # allow dynamically setting vertex ai init logic
             continue
         passed_params[k] = v
+
+    optional_params = {}
+
+    common_auth_dict = litellm.common_cloud_provider_auth_params
+    if custom_llm_provider in common_auth_dict["providers"]:
+        """
+        Check if params = ["project", "region_name", "token"]
+        and correctly translate them for providers = ["azure", "vertex_ai", "watsonx", "bedrock"]
+        """
+        if custom_llm_provider == "azure":
+            optional_params = litellm.AzureOpenAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "bedrock":
+            optional_params = (
+                litellm.AmazonBedrockGlobalConfig().map_special_auth_params(
+                    non_default_params=passed_params, optional_params=optional_params
+                )
+            )
+        elif custom_llm_provider == "vertex_ai":
+            optional_params = litellm.VertexAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "watsonx":
+            optional_params = litellm.IBMWatsonXAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+
     default_params = {
         "functions": None,
         "function_call": None,
@@ -4655,7 +4684,7 @@ def get_optional_params(
             and v != default_params[k]
         )
     }
-    optional_params = {}
+
     ## raise exception if function calling passed in for a provider that doesn't support it
     if (
         "functions" in non_default_params
@@ -5451,12 +5480,16 @@ def get_optional_params(
         extra_body = {}
         if "decoding_method" in passed_params:
             extra_body["decoding_method"] = passed_params.pop("decoding_method")
         if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
-            extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
+            extra_body["min_new_tokens"] = passed_params.pop(
+                "min_tokens", passed_params.pop("min_new_tokens")
+            )
         if "top_k" in passed_params:
             extra_body["top_k"] = passed_params.pop("top_k")
         if "truncate_input_tokens" in passed_params:
-            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
+            extra_body["truncate_input_tokens"] = passed_params.pop(
+                "truncate_input_tokens"
+            )
         if "length_penalty" in passed_params:
             extra_body["length_penalty"] = passed_params.pop("length_penalty")
         if "time_limit" in passed_params:
@@ -5464,7 +5497,7 @@ def get_optional_params(
         if "return_options" in passed_params:
             extra_body["return_options"] = passed_params.pop("return_options")
         optional_params["extra_body"] = (
             extra_body  # openai client supports `extra_body` param
         )
     else:  # assume passing in params for openai/azure openai
         print_verbose(
@@ -9801,19 +9834,21 @@ class CustomStreamWrapper:
             elif isinstance(chunk, (str, bytes)):
                 if isinstance(chunk, bytes):
                     chunk = chunk.decode("utf-8")
-                if 'generated_text' in chunk:
-                    response = chunk.replace('data: ', '').strip()
+                if "generated_text" in chunk:
+                    response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
                     return {"text": "", "is_finished": False}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
             results = parsed_response.get("results", [])
             if len(results) > 0:
                 text = results[0].get("generated_text", "")
                 finish_reason = results[0].get("stop_reason")
-                is_finished = finish_reason != 'not_finished'
+                is_finished = finish_reason != "not_finished"
             return {
                 "text": text,
                 "is_finished": is_finished,
@@ -10085,14 +10120,19 @@ class CustomStreamWrapper:
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
                     if response_obj.get("prompt_tokens") is not None:
-                        prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                        model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
+                        prompt_token_count = getattr(
+                            model_response.usage, "prompt_tokens", 0
+                        )
+                        model_response.usage.prompt_tokens = (
+                            prompt_token_count + response_obj["prompt_tokens"]
+                        )
                     if response_obj.get("completion_tokens") is not None:
-                        model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                        model_response.usage.total_tokens = (
-                            getattr(model_response.usage, "prompt_tokens", 0)
-                            + getattr(model_response.usage, "completion_tokens", 0)
-                        )
+                        model_response.usage.completion_tokens = response_obj[
+                            "completion_tokens"
+                        ]
+                        model_response.usage.total_tokens = getattr(
+                            model_response.usage, "prompt_tokens", 0
+                        ) + getattr(model_response.usage, "completion_tokens", 0)
                     if response_obj["is_finished"]:
                         self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":