Merge pull request #3582 from BerriAI/litellm_explicit_region_name_setting

feat(router.py): allow setting model_region in litellm_params
This commit is contained in:
Krish Dholakia 2024-05-11 11:36:22 -07:00 committed by GitHub
commit 86d0c0ae4e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 405 additions and 95 deletions

View file

@ -110,7 +110,18 @@ try:
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
from typing import (
cast,
List,
Dict,
Union,
Optional,
Literal,
Any,
BinaryIO,
Iterable,
Tuple,
)
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@ -5885,10 +5896,70 @@ def calculate_max_parallel_requests(
return None
def _is_region_eu(model_region: str) -> bool:
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
for region in EU_Regions:
if "europe" in model_region.lower():
def _get_model_region(
custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]:
"""
Return the region for a model, for a given provider
"""
if custom_llm_provider == "vertex_ai":
# check 'vertex_location'
vertex_ai_location = (
litellm_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
)
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
return vertex_ai_location
elif custom_llm_provider == "bedrock":
aws_region_name = litellm_params.aws_region_name
if aws_region_name is not None:
return aws_region_name
elif custom_llm_provider == "watsonx":
watsonx_region_name = litellm_params.watsonx_region_name
if watsonx_region_name is not None:
return watsonx_region_name
return litellm_params.region_name
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
"""
Return true/false if a deployment is in the EU
"""
if litellm_params.region_name == "eu":
return True
## ELSE ##
"""
- get provider
- get provider regions
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
"""
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model, litellm_params=litellm_params
)
model_region = _get_model_region(
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
)
if model_region is None:
return False
if custom_llm_provider == "azure":
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
elif custom_llm_provider == "vertex_ai":
eu_regions = litellm.VertexAIConfig().get_eu_regions()
elif custom_llm_provider == "bedrock":
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
elif custom_llm_provider == "watsonx":
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
else:
return False
for region in eu_regions:
if region in model_region.lower():
return True
return False
@ -6314,8 +6385,23 @@ def get_llm_provider(
custom_llm_provider: Optional[str] = None,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
):
litellm_params: Optional[LiteLLM_Params] = None,
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
"""
try:
## IF LITELLM PARAMS GIVEN ##
if litellm_params is not None:
assert (
custom_llm_provider is None and api_base is None and api_key is None
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
custom_llm_provider = litellm_params.custom_llm_provider
api_base = litellm_params.api_base
api_key = litellm_params.api_key
dynamic_api_key = None
# check if llm provider provided
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
@ -6376,7 +6462,8 @@ def get_llm_provider(
api_base
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
or "https://api.mistral.ai/v1"
)
) # type: ignore
# if api_base does not end with /v1 we add it
if api_base is not None and not api_base.endswith(
"/v1"
@ -6399,10 +6486,30 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN")
)
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
elif model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
# check if api base is a known openai compatible endpoint
if api_base:
@ -6426,7 +6533,22 @@ def get_llm_provider(
elif endpoint == "api.deepseek.com/v1":
custom_llm_provider = "deepseek"
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
return model, custom_llm_provider, dynamic_api_key, api_base
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(
api_base
)
)
if dynamic_api_key is not None and not isinstance(
dynamic_api_key, str
):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
## openai - chatcompletion + text completion
@ -6517,6 +6639,16 @@ def get_llm_provider(
),
llm_provider="",
)
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
except Exception as e:
if isinstance(e, litellm.exceptions.BadRequestError):