diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index f416d1437..a56527a59 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -9,7 +9,7 @@ from litellm.utils import (
     convert_to_model_response_object,
     TranscriptionResponse,
 )
-from typing import Callable, Optional, BinaryIO
+from typing import Callable, Optional, BinaryIO, List
 from litellm import OpenAIConfig
 import litellm, json
 import httpx  # type: ignore
@@ -105,6 +105,12 @@ class AzureOpenAIConfig(OpenAIConfig):
                 optional_params["azure_ad_token"] = value
         return optional_params

+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
+        """
+        return ["europe", "sweden", "switzerland", "france", "uk"]
+

 def select_azure_base_url_or_endpoint(azure_client_params: dict):
     # azure_client_params = {
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 08433ba18..d2a83703a 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
                 optional_params[mapped_params[param]] = value
         return optional_params

+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://www.aws-services.info/bedrock.html
+        """
+        return [
+            "eu-west-1",
+            "eu-west-3",
+            "eu-central-1",
+        ]
+

 class AmazonTitanConfig:
     """
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index a61c07df0..d3bb2c78a 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -198,6 +198,23 @@ class VertexAIConfig:
                 optional_params[mapped_params[param]] = value
         return optional_params

+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+        """
+        return [
+            "europe-central2",
+            "europe-north1",
+            "europe-southwest1",
+            "europe-west1",
+            "europe-west2",
+            "europe-west3",
+            "europe-west4",
+            "europe-west6",
+            "europe-west8",
+            "europe-west9",
+        ]
+


 import asyncio
diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py
index 99f2d18ba..a12676fa0 100644
--- a/litellm/llms/watsonx.py
+++ b/litellm/llms/watsonx.py
@@ -149,6 +149,15 @@ class IBMWatsonXAIConfig:
                 optional_params[mapped_params[param]] = value
         return optional_params

+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
+        """
+        return [
+            "eu-de",
+            "eu-gb",
+        ]
+

 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
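Each provider config above now exposes the same get_eu_regions() hook, which _is_region_eu() in litellm/utils.py (further down in this diff) dispatches on. A minimal sketch of that dispatch, pulled out for illustration; the eu_regions_for helper is hypothetical and not part of this diff:

from typing import List

import litellm


def eu_regions_for(custom_llm_provider: str) -> List[str]:
    """Hypothetical helper mirroring the provider -> config dispatch in _is_region_eu()."""
    configs = {
        "azure": litellm.AzureOpenAIConfig,
        "vertex_ai": litellm.VertexAIConfig,
        "bedrock": litellm.AmazonBedrockGlobalConfig,
        "watsonx": litellm.IBMWatsonXAIConfig,
    }
    config_cls = configs.get(custom_llm_provider)
    return config_cls().get_eu_regions() if config_cls is not None else []


print(eu_regions_for("bedrock"))  # ['eu-west-1', 'eu-west-3', 'eu-central-1']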
diff --git a/litellm/router.py b/litellm/router.py
index e0abc2e3b..0b5846db9 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2329,7 +2329,7 @@ class Router:
                 )  # cache for 1 hr
             else:
-                _api_key = api_key
+                _api_key = api_key  # type: ignore
             if _api_key is not None and isinstance(_api_key, str):
                 # only show first 5 chars of api_key
                 _api_key = _api_key[:8] + "*" * 15

@@ -2953,7 +2953,7 @@ class Router:
                 ):  # check if in allowed_model_region
                     if (
-                        _is_region_eu(model_region=_litellm_params["region_name"])
+                        _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
                         == False
                     ):
                         invalid_model_indices.append(idx)

diff --git a/litellm/types/router.py b/litellm/types/router.py
index dbf36f17c..e8f3ff641 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -132,6 +132,8 @@ class GenericLiteLLMParams(BaseModel):
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
     aws_region_name: Optional[str] = None
+    ## IBM WATSONX ##
+    watsonx_region_name: Optional[str] = None
     ## CUSTOM PRICING ##
     input_cost_per_token: Optional[float] = None
     output_cost_per_token: Optional[float] = None
@@ -161,6 +163,8 @@ class GenericLiteLLMParams(BaseModel):
         aws_access_key_id: Optional[str] = None,
         aws_secret_access_key: Optional[str] = None,
         aws_region_name: Optional[str] = None,
+        ## IBM WATSONX ##
+        watsonx_region_name: Optional[str] = None,
         input_cost_per_token: Optional[float] = None,
         output_cost_per_token: Optional[float] = None,
         input_cost_per_second: Optional[float] = None,
diff --git a/litellm/utils.py b/litellm/utils.py
index 1c9c3df92..2704ccbcb 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -107,7 +107,18 @@ try:
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")

-from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
+from typing import (
+    cast,
+    List,
+    Dict,
+    Union,
+    Optional,
+    Literal,
+    Any,
+    BinaryIO,
+    Iterable,
+    Tuple,
+)

 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
@@ -5880,13 +5891,70 @@ def calculate_max_parallel_requests(
     return None


-def _is_region_eu(model_region: str) -> bool:
-    if model_region == "eu":
+def _get_model_region(
+    custom_llm_provider: str, litellm_params: LiteLLM_Params
+) -> Optional[str]:
+    """
+    Return the region for a model, for a given provider
+    """
+    if custom_llm_provider == "vertex_ai":
+        # check 'vertex_location'
+        vertex_ai_location = (
+            litellm_params.vertex_location
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+            or get_secret("VERTEX_LOCATION")
+        )
+        if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
+            return vertex_ai_location
+    elif custom_llm_provider == "bedrock":
+        aws_region_name = litellm_params.aws_region_name
+        if aws_region_name is not None:
+            return aws_region_name
+    elif custom_llm_provider == "watsonx":
+        watsonx_region_name = litellm_params.watsonx_region_name
+        if watsonx_region_name is not None:
+            return watsonx_region_name
+    return litellm_params.region_name
+
+
+def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
+    """
+    Return true/false if a deployment is in the EU
+    """
+    if litellm_params.region_name == "eu":
         return True

-    EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
-    for region in EU_Regions:
-        if "europe" in model_region.lower():
+    ## ELSE ##
+    """
+    - get provider
+    - get provider regions
+    - return true if given region (get_provider_region) in eu region (config.get_eu_regions())
+    """
+    model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+        model=litellm_params.model, litellm_params=litellm_params
+    )
+
+    model_region = _get_model_region(
+        custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
+    )
+
+    if model_region is None:
+        return False
+
+    if custom_llm_provider == "azure":
+        eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
+    elif custom_llm_provider == "vertex_ai":
+        eu_regions = litellm.VertexAIConfig().get_eu_regions()
+    elif custom_llm_provider == "bedrock":
+        eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
+    elif custom_llm_provider == "watsonx":
+        eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
+    else:
+        return False
+
+    for region in eu_regions:
+        if region in model_region.lower():
             return True
     return False

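With the pieces above in place, the EU check can be exercised directly. A small sketch, assuming the definitions stay where this diff puts them; the Bedrock model id and regions are illustrative:

from litellm.types.router import LiteLLM_Params
from litellm.utils import _is_region_eu

# The region is inferred per provider: aws_region_name for bedrock,
# vertex_location for vertex_ai, watsonx_region_name for watsonx.
params_eu = LiteLLM_Params(
    model="bedrock/anthropic.claude-instant-v1",
    aws_region_name="eu-central-1",  # in AmazonBedrockGlobalConfig().get_eu_regions()
)
print(_is_region_eu(litellm_params=params_eu))  # True

params_us = LiteLLM_Params(
    model="bedrock/anthropic.claude-instant-v1",
    aws_region_name="us-east-1",  # not in the Bedrock EU list
)
print(_is_region_eu(litellm_params=params_us))  # False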
@@ -6312,8 +6380,23 @@ def get_llm_provider(
     custom_llm_provider: Optional[str] = None,
     api_base: Optional[str] = None,
     api_key: Optional[str] = None,
-):
+    litellm_params: Optional[LiteLLM_Params] = None,
+) -> Tuple[str, str, Optional[str], Optional[str]]:
+    """
+    Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
+
+    For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
+    """
     try:
+        ## IF LITELLM PARAMS GIVEN ##
+        if litellm_params is not None:
+            assert (
+                custom_llm_provider is None and api_base is None and api_key is None
+            ), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overridden."
+            custom_llm_provider = litellm_params.custom_llm_provider
+            api_base = litellm_params.api_base
+            api_key = litellm_params.api_key
+
         dynamic_api_key = None
         # check if llm provider provided
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
@@ -6374,7 +6457,8 @@
                 api_base
                 or get_secret("MISTRAL_AZURE_API_BASE")  # for Azure AI Mistral
                 or "https://api.mistral.ai/v1"
-            )
+            )  # type: ignore
+            # if api_base does not end with /v1 we add it
             if api_base is not None and not api_base.endswith(
                 "/v1"

@@ -6397,10 +6481,30 @@
                 or get_secret("TOGETHERAI_API_KEY")
                 or get_secret("TOGETHER_AI_TOKEN")
             )
+            if api_base is not None and not isinstance(api_base, str):
+                raise Exception(
+                    "api base needs to be a string. api_base={}".format(api_base)
+                )
+            if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+                raise Exception(
+                    "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                        dynamic_api_key
+                    )
+                )
             return model, custom_llm_provider, dynamic_api_key, api_base
         elif model.split("/", 1)[0] in litellm.provider_list:
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
+            if api_base is not None and not isinstance(api_base, str):
+                raise Exception(
+                    "api base needs to be a string. api_base={}".format(api_base)
+                )
+            if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+                raise Exception(
+                    "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                        dynamic_api_key
+                    )
+                )
             return model, custom_llm_provider, dynamic_api_key, api_base
         # check if api base is a known openai compatible endpoint
         if api_base:
@@ -6424,7 +6528,22 @@
             elif endpoint == "api.deepseek.com/v1":
                 custom_llm_provider = "deepseek"
                 dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
-            return model, custom_llm_provider, dynamic_api_key, api_base
+
+            if api_base is not None and not isinstance(api_base, str):
+                raise Exception(
+                    "api base needs to be a string. api_base={}".format(
+                        api_base
+                    )
+                )
+            if dynamic_api_key is not None and not isinstance(
+                dynamic_api_key, str
+            ):
+                raise Exception(
+                    "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                        dynamic_api_key
+                    )
+                )
+            return model, custom_llm_provider, dynamic_api_key, api_base  # type: ignore

         # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
         ## openai - chatcompletion + text completion
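For reference, the two calling conventions get_llm_provider now supports, sketched with the deployment name from the docstring above:

from litellm import get_llm_provider
from litellm.types.router import LiteLLM_Params

# Direct style: the provider is parsed off the model string.
model, provider, dynamic_api_key, api_base = get_llm_provider(model="azure/chatgpt-v-2")
# model == "chatgpt-v-2", provider == "azure"

# Router style: hand over the whole param object and leave
# custom_llm_provider/api_base/api_key as None, or the assert above fires.
model, provider, _, _ = get_llm_provider(
    model="azure/chatgpt-v-2",
    litellm_params=LiteLLM_Params(model="azure/chatgpt-v-2"),
)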
@@ -6515,6 +6634,16 @@
                 ),
                 llm_provider="",
             )
+        if api_base is not None and not isinstance(api_base, str):
+            raise Exception(
+                "api base needs to be a string. api_base={}".format(api_base)
+            )
+        if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+            raise Exception(
+                "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                    dynamic_api_key
+                )
+            )
         return model, custom_llm_provider, dynamic_api_key, api_base
     except Exception as e:
         if isinstance(e, litellm.exceptions.BadRequestError):
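Taken together with the Router change above, this enables region-aware deployment filtering. A sketch of the deployment side only, assuming enable_pre_call_checks (an existing Router flag) is on; how a request gets tagged as EU-only (e.g. via proxy key metadata setting allowed_model_region) is outside this diff, and the deployment values are illustrative:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",  # illustrative deployment
                "api_key": "...",
                "api_base": "...",
                "region_name": "eu",  # short-circuits _is_region_eu() to True
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "bedrock/anthropic.claude-instant-v1",
                "aws_region_name": "us-east-1",  # not an EU region -> filtered out for EU-only requests
            },
        },
    ],
    enable_pre_call_checks=True,  # required for the _pre_call_checks region filter to run
)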