Merge pull request #3582 from BerriAI/litellm_explicit_region_name_setting
feat(router.py): allow setting model_region in litellm_params
Commit 86d0c0ae4e: 12 changed files with 405 additions and 95 deletions.
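What the feature looks like from the caller's side: a deployment in the router's model_list can now carry an explicit region in its litellm_params, which the helpers added below read when deciding where a deployment lives. A minimal sketch, assuming the standard Router config shape; only region_name (and the provider-specific vertex_location / aws_region_name / watsonx_region_name fields) are grounded in this diff, the endpoint and key are placeholders.

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "sk-...",  # placeholder
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "region_name": "eu",  # explicit region, read by _is_region_eu below
            },
        }
    ]
)
```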
litellm/utils.py (148 changes)
```diff
@@ -110,7 +110,18 @@ try:
 except Exception as e:
     verbose_logger.debug(f"Exception import enterprise features {str(e)}")
 
-from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
+from typing import (
+    cast,
+    List,
+    Dict,
+    Union,
+    Optional,
+    Literal,
+    Any,
+    BinaryIO,
+    Iterable,
+    Tuple,
+)
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
 
```
```diff
@@ -5885,10 +5896,70 @@ def calculate_max_parallel_requests(
     return None
 
 
-def _is_region_eu(model_region: str) -> bool:
-    EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
-    for region in EU_Regions:
-        if "europe" in model_region.lower():
-            return True
-    return False
+def _get_model_region(
+    custom_llm_provider: str, litellm_params: LiteLLM_Params
+) -> Optional[str]:
+    """
+    Return the region for a model, for a given provider
+    """
+    if custom_llm_provider == "vertex_ai":
+        # check 'vertex_location'
+        vertex_ai_location = (
+            litellm_params.vertex_location
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+            or get_secret("VERTEX_LOCATION")
+        )
+        if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
+            return vertex_ai_location
+    elif custom_llm_provider == "bedrock":
+        aws_region_name = litellm_params.aws_region_name
+        if aws_region_name is not None:
+            return aws_region_name
+    elif custom_llm_provider == "watsonx":
+        watsonx_region_name = litellm_params.watsonx_region_name
+        if watsonx_region_name is not None:
+            return watsonx_region_name
+    return litellm_params.region_name
+
+
+def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
+    """
+    Return true/false if a deployment is in the EU
+    """
+    if litellm_params.region_name == "eu":
+        return True
+
+    ## ELSE ##
+    """
+    - get provider
+    - get provider regions
+    - return true if given region (get_provider_region) in eu region (config.get_eu_regions())
+    """
+    model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+        model=litellm_params.model, litellm_params=litellm_params
+    )
+
+    model_region = _get_model_region(
+        custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
+    )
+
+    if model_region is None:
+        return False
+
+    if custom_llm_provider == "azure":
+        eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
+    elif custom_llm_provider == "vertex_ai":
+        eu_regions = litellm.VertexAIConfig().get_eu_regions()
+    elif custom_llm_provider == "bedrock":
+        eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
+    elif custom_llm_provider == "watsonx":
+        eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
+    else:
+        return False
+
+    for region in eu_regions:
+        if region in model_region.lower():
+            return True
+    return False
 
```
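Two things are worth noting in this hunk. The removed helper had a bug: the loop compared the literal string "europe" on every iteration instead of the current region, so "sweden", "switzerland", "france", and "uk" never matched. The replacement resolves a region in a fixed order: an explicit region_name == "eu" short-circuits; otherwise the provider-specific field (vertex_location, aws_region_name, watsonx_region_name) is tried before falling back to region_name, and the result is substring-matched against the provider's EU region list. A self-contained sketch of that order, with a stand-in dataclass instead of LiteLLM_Params and an illustrative region list rather than the values the real config classes return:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:  # stand-in for LiteLLM_Params
    region_name: Optional[str] = None
    vertex_location: Optional[str] = None
    aws_region_name: Optional[str] = None

EU_REGIONS = {"bedrock": ["eu-west-1", "eu-west-3", "eu-central-1"]}  # illustrative

def is_region_eu(provider: str, p: Params) -> bool:
    if p.region_name == "eu":  # explicit opt-in short-circuits everything
        return True
    # provider-specific field first, generic region_name as the fallback
    if provider == "vertex_ai":
        region = p.vertex_location or p.region_name
    elif provider == "bedrock":
        region = p.aws_region_name or p.region_name
    else:
        region = p.region_name
    if region is None:
        return False
    # substring match, same as the loop at the end of _is_region_eu
    return any(r in region.lower() for r in EU_REGIONS.get(provider, []))

assert is_region_eu("bedrock", Params(aws_region_name="eu-central-1"))
assert not is_region_eu("bedrock", Params(aws_region_name="us-east-1"))
```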
```diff
@@ -6314,8 +6385,23 @@ def get_llm_provider(
     custom_llm_provider: Optional[str] = None,
     api_base: Optional[str] = None,
     api_key: Optional[str] = None,
-):
+    litellm_params: Optional[LiteLLM_Params] = None,
+) -> Tuple[str, str, Optional[str], Optional[str]]:
+    """
+    Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
+
+    For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
+    """
     try:
+        ## IF LITELLM PARAMS GIVEN ##
+        if litellm_params is not None:
+            assert (
+                custom_llm_provider is None and api_base is None and api_key is None
+            ), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
+            custom_llm_provider = litellm_params.custom_llm_provider
+            api_base = litellm_params.api_base
+            api_key = litellm_params.api_key
+
         dynamic_api_key = None
         # check if llm provider provided
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
```
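The signature change gives callers two mutually exclusive styles; passing litellm_params alongside custom_llm_provider / api_base / api_key trips the new assertion. A sketch of both (the LiteLLM_Params import path is an assumption here, and the model name is an example):

```python
import litellm
from litellm.types.router import LiteLLM_Params  # import path assumed

# 1) classic explicit-argument style
model, provider, key, base = litellm.get_llm_provider(model="azure/chatgpt-v-2")

# 2) router style: hand the whole params object over and let the
#    function extract custom_llm_provider / api_base / api_key itself
params = LiteLLM_Params(model="azure/chatgpt-v-2")
model, provider, key, base = litellm.get_llm_provider(
    model="azure/chatgpt-v-2", litellm_params=params
)
```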
```diff
@@ -6376,7 +6462,8 @@ def get_llm_provider(
                 api_base
                 or get_secret("MISTRAL_AZURE_API_BASE")  # for Azure AI Mistral
                 or "https://api.mistral.ai/v1"
-            )
+            )  # type: ignore
+
             # if api_base does not end with /v1 we add it
             if api_base is not None and not api_base.endswith(
                 "/v1"
```
```diff
@@ -6399,10 +6486,30 @@ def get_llm_provider(
                 or get_secret("TOGETHERAI_API_KEY")
                 or get_secret("TOGETHER_AI_TOKEN")
             )
+            if api_base is not None and not isinstance(api_base, str):
+                raise Exception(
+                    "api base needs to be a string. api_base={}".format(api_base)
+                )
+            if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+                raise Exception(
+                    "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                        dynamic_api_key
+                    )
+                )
             return model, custom_llm_provider, dynamic_api_key, api_base
         elif model.split("/", 1)[0] in litellm.provider_list:
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
+            if api_base is not None and not isinstance(api_base, str):
+                raise Exception(
+                    "api base needs to be a string. api_base={}".format(api_base)
+                )
+            if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+                raise Exception(
+                    "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                        dynamic_api_key
+                    )
+                )
             return model, custom_llm_provider, dynamic_api_key, api_base
         # check if api base is a known openai compatible endpoint
         if api_base:
```
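The same two isinstance guards now appear verbatim before every return in this hunk and the two below. Not part of the PR, but a helper along these lines is one way the pattern could be expressed once; a sketch with a hypothetical name:

```python
from typing import Optional, Tuple

def _validated_return(
    model: str,
    custom_llm_provider: str,
    dynamic_api_key: Optional[str],
    api_base: Optional[str],
) -> Tuple[str, str, Optional[str], Optional[str]]:
    # hypothetical helper: mirrors the guards pasted before each return
    if api_base is not None and not isinstance(api_base, str):
        raise Exception("api base needs to be a string. api_base={}".format(api_base))
    if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
        raise Exception(
            "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
                dynamic_api_key
            )
        )
    return model, custom_llm_provider, dynamic_api_key, api_base
```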
```diff
@@ -6426,7 +6533,22 @@ def get_llm_provider(
                     elif endpoint == "api.deepseek.com/v1":
                         custom_llm_provider = "deepseek"
                         dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
-                    return model, custom_llm_provider, dynamic_api_key, api_base
+
+                    if api_base is not None and not isinstance(api_base, str):
+                        raise Exception(
+                            "api base needs to be a string. api_base={}".format(
+                                api_base
+                            )
+                        )
+                    if dynamic_api_key is not None and not isinstance(
+                        dynamic_api_key, str
+                    ):
+                        raise Exception(
+                            "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                                dynamic_api_key
+                            )
+                        )
+                    return model, custom_llm_provider, dynamic_api_key, api_base  # type: ignore
 
         # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
         ## openai - chatcompletion + text completion
```
```diff
@@ -6517,6 +6639,16 @@ def get_llm_provider(
                 ),
                 llm_provider="",
             )
+        if api_base is not None and not isinstance(api_base, str):
+            raise Exception(
+                "api base needs to be a string. api_base={}".format(api_base)
+            )
+        if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
+            raise Exception(
+                "dynamic_api_key needs to be a string. dynamic_api_key={}".format(
+                    dynamic_api_key
+                )
+            )
         return model, custom_llm_provider, dynamic_api_key, api_base
     except Exception as e:
         if isinstance(e, litellm.exceptions.BadRequestError):
```
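With the annotation in place, every successful path now returns the declared 4-tuple. A quick smoke test of the shape (model name is an example; assuming the plain provider-prefix branch, which needs no keys or env vars):

```python
import litellm

model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="azure/chatgpt-v-2"
)
assert (model, provider) == ("chatgpt-v-2", "azure")
assert dynamic_api_key is None and api_base is None
```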