diff --git a/docs/my-website/docs/set_keys.md b/docs/my-website/docs/set_keys.md
index 4c8cc42fe..7686bf704 100644
--- a/docs/my-website/docs/set_keys.md
+++ b/docs/my-website/docs/set_keys.md
@@ -5,6 +5,9 @@ LiteLLM allows you to specify the following:
 * API Base
 * API Version
 * API Type
+* Project
+* Location
+* Token
 
 Useful Helper functions:
 * [`check_valid_key()`](#check_valid_key)
@@ -43,6 +46,49 @@ os.environ['AZURE_API_TYPE'] = "azure" # [OPTIONAL]
 os.environ['OPENAI_API_BASE'] = "https://openai-gpt-4-test2-v-12.openai.azure.com/"
 ```
 
+### Setting Project, Location, Token
+
+For the following cloud providers, you might need to set additional parameters:
+- Azure
+- Bedrock
+- GCP (Vertex AI)
+- Watsonx
+
+LiteLLM provides a common set of params that we map across all providers:
+
+|         | LiteLLM param | Watsonx                | Vertex AI       | Azure          | Bedrock         |
+|---------|---------------|------------------------|-----------------|----------------|-----------------|
+| Project | project       | watsonx_project        | vertex_project  | n/a            | n/a             |
+| Region  | region_name   | watsonx_region_name    | vertex_location | n/a            | aws_region_name |
+| Token   | token         | watsonx_token or token | n/a             | azure_ad_token | n/a             |
+
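+For example, here is roughly what a Vertex AI call looks like with the unified params (a sketch: the model name and project values below are placeholders, and valid GCP credentials are assumed):
+
+```python
+from litellm import completion
+
+response = completion(
+    model="vertex_ai/gemini-pro",  # placeholder model name
+    messages=[{"role": "user", "content": "Hey!"}],
+    project="my-gcp-project",  # mapped to vertex_project
+    region_name="us-central1",  # mapped to vertex_location
+)
+```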
+
+You can also pass these params by their provider-specific names instead.
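+
+For instance, the Azure AD token from the table above can be passed as `azure_ad_token` directly (again a sketch, assuming a valid token and deployment):
+
+```python
+from litellm import completion
+
+response = completion(
+    model="azure/chatgpt-v-2",  # placeholder deployment name
+    messages=[{"role": "user", "content": "Hey!"}],
+    azure_ad_token="my-azure-ad-token",  # same as passing token="my-azure-ad-token"
+)
+```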
+
 ## litellm variables
 
 ### litellm.api_key
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 5f23ae33e..6a89506c9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -58,6 +58,7 @@ max_tokens = 256  # OpenAI Defaults
 drop_params = False
 modify_params = False
 retry = True
+### AUTH ###
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None
@@ -76,6 +77,10 @@ cloudflare_api_key: Optional[str] = None
 baseten_key: Optional[str] = None
 aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
+common_cloud_provider_auth_params: dict = {
+    "params": ["project", "region_name", "token"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+}
 use_client: bool = False
 ssl_verify: bool = True
 disable_streaming_logging: bool = False
@@ -654,6 +659,7 @@ from .llms.bedrock import (
     AmazonLlamaConfig,
     AmazonStabilityConfig,
     AmazonMistralConfig,
+    AmazonBedrockGlobalConfig,
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 7f268c25a..0fe5c4e7e 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -96,6 +96,15 @@ class AzureOpenAIConfig(OpenAIConfig):
             top_p,
         )
 
+    def get_mapped_special_auth_params(self) -> dict:
+        return {"token": "azure_ad_token"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "token":
+                optional_params["azure_ad_token"] = value
+        return optional_params
+
 
 def select_azure_base_url_or_endpoint(azure_client_params: dict):
     # azure_client_params = {
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 149b68472..894114559 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -29,6 +29,24 @@ class BedrockError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+class AmazonBedrockGlobalConfig:
+    def __init__(self):
+        pass
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Mapping of common auth params across bedrock/vertex/azure/watsonx
+        """
+        return {"region_name": "aws_region_name"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 class AmazonTitanConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index e8e9ca582..1dbb1cc7a 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -184,6 +184,20 @@ class VertexAIConfig:
                 pass
         return optional_params
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 import asyncio
 
@@ -529,7 +543,7 @@ def completion(
                 "instances": instances,
                 "vertex_location": vertex_location,
                 "vertex_project": vertex_project,
-                "safety_settings":safety_settings,
+                "safety_settings": safety_settings,
                 **optional_params,
             }
             if optional_params.get("stream", False) is True:
diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 28061919e..ac38a2a8f 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -131,6 +131,24 @@ class IBMWatsonXAIConfig:
             "stream",  # equivalent to stream
         ]
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {
+            "project": "watsonx_project",
+            "region_name": "watsonx_region_name",
+            "token": "watsonx_token",
+        }
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 531710841..14be9592b 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2654,6 +2654,7 @@ def test_completion_palm_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
@@ -2671,10 +2672,57 @@ def test_completion_watsonx():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+@pytest.mark.parametrize(
+    "provider, model, project, region_name, token",
+    [
+        ("azure", "chatgpt-v-2", None, None, "test-token"),
+        ("vertex_ai", "anthropic-claude-3", "adroit-crow-1", "us-east1", None),
+        ("watsonx", "ibm/granite", "96946574", "dallas", "1234"),
+        ("bedrock", "anthropic.claude-3", None, "us-east-1", None),
+    ],
+)
+def test_unified_auth_params(provider, model, project, region_name, token):
+    """
+    Check if params = ["project", "region_name", "token"]
+    are correctly translated for = ["azure", "vertex_ai", "watsonx", "aws"]
+
+    tests get_optional_params
+    """
+    data = {
+        "project": project,
+        "region_name": region_name,
+        "token": token,
+        "custom_llm_provider": provider,
+        "model": model,
+    }
+
+    translated_optional_params = litellm.utils.get_optional_params(**data)
+
+    if provider == "azure":
+        special_auth_params = (
+            litellm.AzureOpenAIConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "bedrock":
+        special_auth_params = (
+            litellm.AmazonBedrockGlobalConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "vertex_ai":
+        special_auth_params = litellm.VertexAIConfig().get_mapped_special_auth_params()
+    elif provider == "watsonx":
+        special_auth_params = (
+            litellm.IBMWatsonXAIConfig().get_mapped_special_auth_params()
+        )
+
+    for param, value in special_auth_params.items():
+        assert param in data
+        assert value in translated_optional_params
+
+
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
-    model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID")
+    model_name = "watsonx/deployment/" + os.getenv("WATSONX_DEPLOYMENT_ID")
     print("testing watsonx")
     try:
         response = await litellm.acompletion(
diff --git a/litellm/utils.py b/litellm/utils.py
index 9f176c194..ccd6bd3dc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4619,7 +4619,36 @@ def get_optional_params(
             k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
         ):  # allow dynamically setting vertex ai init logic
             continue
+        passed_params[k] = v
+
+    optional_params = {}
+
+    common_auth_dict = litellm.common_cloud_provider_auth_params
+    if custom_llm_provider in common_auth_dict["providers"]:
+        """
+        Check if params = ["project", "region_name", "token"]
+        and correctly translate for = ["azure", "vertex_ai", "watsonx", "aws"]
+        """
+        if custom_llm_provider == "azure":
+            optional_params = litellm.AzureOpenAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "bedrock":
+            optional_params = (
+                litellm.AmazonBedrockGlobalConfig().map_special_auth_params(
+                    non_default_params=passed_params, optional_params=optional_params
+                )
+            )
+        elif custom_llm_provider == "vertex_ai":
+            optional_params = litellm.VertexAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "watsonx":
+            optional_params = litellm.IBMWatsonXAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+
     default_params = {
         "functions": None,
         "function_call": None,
@@ -4655,7 +4684,7 @@ def get_optional_params(
             and v != default_params[k]
         )
     }
-    optional_params = {}
+
     ## raise exception if function calling passed in for a provider that doesn't support it
     if (
         "functions" in non_default_params
@@ -5446,17 +5475,21 @@ def get_optional_params(
             optional_params["random_seed"] = seed
         if stop is not None:
            optional_params["stop_sequences"] = stop
-            
+
         # WatsonX-only parameters
         extra_body = {}
         if "decoding_method" in passed_params:
             extra_body["decoding_method"] = passed_params.pop("decoding_method")
-        if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
-            extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
+        if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
+            extra_body["min_new_tokens"] = passed_params.pop(
+                "min_tokens", passed_params.pop("min_new_tokens", None)
+            )
         if "top_k" in passed_params:
             extra_body["top_k"] = passed_params.pop("top_k")
         if "truncate_input_tokens" in passed_params:
-            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
+            extra_body["truncate_input_tokens"] = passed_params.pop(
+                "truncate_input_tokens"
+            )
         if "length_penalty" in passed_params:
             extra_body["length_penalty"] = passed_params.pop("length_penalty")
         if "time_limit" in passed_params:
@@ -5464,7 +5497,7 @@ def get_optional_params(
         if "return_options" in passed_params:
             extra_body["return_options"] = passed_params.pop("return_options")
         optional_params["extra_body"] = (
-            extra_body # openai client supports `extra_body` param
+            extra_body  # openai client supports `extra_body` param
         )
     else:  # assume passing in params for openai/azure openai
         print_verbose(
@@ -9793,7 +9826,7 @@ class CustomStreamWrapper:
             "is_finished": chunk["is_finished"],
             "finish_reason": finish_reason,
         }
-        
+
     def handle_watsonx_stream(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -9801,19 +9834,21 @@ class CustomStreamWrapper:
                 parsed_response = chunk
             elif isinstance(chunk, (str, bytes)):
                 if isinstance(chunk, bytes):
                     chunk = chunk.decode("utf-8")
-                if 'generated_text' in chunk:
-                    response = chunk.replace('data: ', '').strip()
+                if "generated_text" in chunk:
+                    response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
                     return {"text": "", "is_finished": False}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
             results = parsed_response.get("results", [])
             if len(results) > 0:
                 text = results[0].get("generated_text", "")
                 finish_reason = results[0].get("stop_reason")
-                is_finished = finish_reason != 'not_finished'
+                is_finished = finish_reason != "not_finished"
                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -10085,14 +10120,19 @@ class CustomStreamWrapper:
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
                     if response_obj.get("prompt_tokens") is not None:
-                        prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                        model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
+                        prompt_token_count = getattr(
+                            model_response.usage, "prompt_tokens", 0
+                        )
+                        model_response.usage.prompt_tokens = (
+                            prompt_token_count + response_obj["prompt_tokens"]
+                        )
                     if response_obj.get("completion_tokens") is not None:
-                        model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                        model_response.usage.total_tokens = (
-                            getattr(model_response.usage, "prompt_tokens", 0)
-                            + getattr(model_response.usage, "completion_tokens", 0)
-                        )
+                        model_response.usage.completion_tokens = response_obj[
+                            "completion_tokens"
+                        ]
+                        model_response.usage.total_tokens = getattr(
+                            model_response.usage, "prompt_tokens", 0
+                        ) + getattr(model_response.usage, "completion_tokens", 0)
                     if response_obj["is_finished"]:
                         self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":