forked from phoenix/litellm-mirror
feat(router.py): support region routing for bedrock, vertex ai, watsonx
This commit is contained in:
parent
ebc927f1c8
commit
6714854bb7
7 changed files with 187 additions and 12 deletions
|
@ -9,7 +9,7 @@ from litellm.utils import (
|
||||||
convert_to_model_response_object,
|
convert_to_model_response_object,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
)
|
)
|
||||||
from typing import Callable, Optional, BinaryIO
|
from typing import Callable, Optional, BinaryIO, List
|
||||||
from litellm import OpenAIConfig
|
from litellm import OpenAIConfig
|
||||||
import litellm, json
|
import litellm, json
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
|
@ -105,6 +105,12 @@ class AzureOpenAIConfig(OpenAIConfig):
|
||||||
optional_params["azure_ad_token"] = value
|
optional_params["azure_ad_token"] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||||
|
"""
|
||||||
|
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||||
|
|
||||||
|
|
||||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||||
# azure_client_params = {
|
# azure_client_params = {
|
||||||
|
|
|
@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
|
||||||
optional_params[mapped_params[param]] = value
|
optional_params[mapped_params[param]] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://www.aws-services.info/bedrock.html
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"eu-west-1",
|
||||||
|
"eu-west-3",
|
||||||
|
"eu-central-1",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AmazonTitanConfig:
|
class AmazonTitanConfig:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -198,6 +198,23 @@ class VertexAIConfig:
|
||||||
optional_params[mapped_params[param]] = value
|
optional_params[mapped_params[param]] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"europe-central2",
|
||||||
|
"europe-north1",
|
||||||
|
"europe-southwest1",
|
||||||
|
"europe-west1",
|
||||||
|
"europe-west2",
|
||||||
|
"europe-west3",
|
||||||
|
"europe-west4",
|
||||||
|
"europe-west6",
|
||||||
|
"europe-west8",
|
||||||
|
"europe-west9",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
@ -149,6 +149,15 @@ class IBMWatsonXAIConfig:
|
||||||
optional_params[mapped_params[param]] = value
|
optional_params[mapped_params[param]] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"eu-de",
|
||||||
|
"eu-gb",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||||
# handle anthropic prompts and amazon titan prompts
|
# handle anthropic prompts and amazon titan prompts
|
||||||
|
|
|
@ -2329,7 +2329,7 @@ class Router:
|
||||||
) # cache for 1 hr
|
) # cache for 1 hr
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_api_key = api_key
|
_api_key = api_key # type: ignore
|
||||||
if _api_key is not None and isinstance(_api_key, str):
|
if _api_key is not None and isinstance(_api_key, str):
|
||||||
# only show first 5 chars of api_key
|
# only show first 5 chars of api_key
|
||||||
_api_key = _api_key[:8] + "*" * 15
|
_api_key = _api_key[:8] + "*" * 15
|
||||||
|
@ -2953,7 +2953,7 @@ class Router:
|
||||||
):
|
):
|
||||||
# check if in allowed_model_region
|
# check if in allowed_model_region
|
||||||
if (
|
if (
|
||||||
_is_region_eu(model_region=_litellm_params["region_name"])
|
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
|
||||||
== False
|
== False
|
||||||
):
|
):
|
||||||
invalid_model_indices.append(idx)
|
invalid_model_indices.append(idx)
|
||||||
|
|
|
@ -132,6 +132,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
aws_access_key_id: Optional[str] = None
|
aws_access_key_id: Optional[str] = None
|
||||||
aws_secret_access_key: Optional[str] = None
|
aws_secret_access_key: Optional[str] = None
|
||||||
aws_region_name: Optional[str] = None
|
aws_region_name: Optional[str] = None
|
||||||
|
## IBM WATSONX ##
|
||||||
|
watsonx_region_name: Optional[str] = None
|
||||||
## CUSTOM PRICING ##
|
## CUSTOM PRICING ##
|
||||||
input_cost_per_token: Optional[float] = None
|
input_cost_per_token: Optional[float] = None
|
||||||
output_cost_per_token: Optional[float] = None
|
output_cost_per_token: Optional[float] = None
|
||||||
|
@ -161,6 +163,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
aws_region_name: Optional[str] = None,
|
aws_region_name: Optional[str] = None,
|
||||||
|
## IBM WATSONX ##
|
||||||
|
watsonx_region_name: Optional[str] = None,
|
||||||
input_cost_per_token: Optional[float] = None,
|
input_cost_per_token: Optional[float] = None,
|
||||||
output_cost_per_token: Optional[float] = None,
|
output_cost_per_token: Optional[float] = None,
|
||||||
input_cost_per_second: Optional[float] = None,
|
input_cost_per_second: Optional[float] = None,
|
||||||
|
|
147
litellm/utils.py
147
litellm/utils.py
|
@ -107,7 +107,18 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
|
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
|
||||||
|
|
||||||
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
|
from typing import (
|
||||||
|
cast,
|
||||||
|
List,
|
||||||
|
Dict,
|
||||||
|
Union,
|
||||||
|
Optional,
|
||||||
|
Literal,
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
|
Iterable,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
from .caching import Cache
|
from .caching import Cache
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
@ -5880,13 +5891,70 @@ def calculate_max_parallel_requests(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _is_region_eu(model_region: str) -> bool:
|
def _get_model_region(
|
||||||
if model_region == "eu":
|
custom_llm_provider: str, litellm_params: LiteLLM_Params
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Return the region for a model, for a given provider
|
||||||
|
"""
|
||||||
|
if custom_llm_provider == "vertex_ai":
|
||||||
|
# check 'vertex_location'
|
||||||
|
vertex_ai_location = (
|
||||||
|
litellm_params.vertex_location
|
||||||
|
or litellm.vertex_location
|
||||||
|
or get_secret("VERTEXAI_LOCATION")
|
||||||
|
or get_secret("VERTEX_LOCATION")
|
||||||
|
)
|
||||||
|
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
|
||||||
|
return vertex_ai_location
|
||||||
|
elif custom_llm_provider == "bedrock":
|
||||||
|
aws_region_name = litellm_params.aws_region_name
|
||||||
|
if aws_region_name is not None:
|
||||||
|
return aws_region_name
|
||||||
|
elif custom_llm_provider == "watsonx":
|
||||||
|
watsonx_region_name = litellm_params.watsonx_region_name
|
||||||
|
if watsonx_region_name is not None:
|
||||||
|
return watsonx_region_name
|
||||||
|
return litellm_params.region_name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
|
||||||
|
"""
|
||||||
|
Return true/false if a deployment is in the EU
|
||||||
|
"""
|
||||||
|
if litellm_params.region_name == "eu":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
|
## ELSE ##
|
||||||
for region in EU_Regions:
|
"""
|
||||||
if "europe" in model_region.lower():
|
- get provider
|
||||||
|
- get provider regions
|
||||||
|
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
|
||||||
|
"""
|
||||||
|
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||||
|
model=litellm_params.model, litellm_params=litellm_params
|
||||||
|
)
|
||||||
|
|
||||||
|
model_region = _get_model_region(
|
||||||
|
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_region is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if custom_llm_provider == "azure":
|
||||||
|
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
|
||||||
|
elif custom_llm_provider == "vertex_ai":
|
||||||
|
eu_regions = litellm.VertexAIConfig().get_eu_regions()
|
||||||
|
elif custom_llm_provider == "bedrock":
|
||||||
|
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
|
||||||
|
elif custom_llm_provider == "watsonx":
|
||||||
|
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for region in eu_regions:
|
||||||
|
if region in model_region.lower():
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -6312,8 +6380,23 @@ def get_llm_provider(
|
||||||
custom_llm_provider: Optional[str] = None,
|
custom_llm_provider: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
):
|
litellm_params: Optional[LiteLLM_Params] = None,
|
||||||
|
) -> Tuple[str, str, Optional[str], Optional[str]]:
|
||||||
|
"""
|
||||||
|
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
|
||||||
|
|
||||||
|
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
|
## IF LITELLM PARAMS GIVEN ##
|
||||||
|
if litellm_params is not None:
|
||||||
|
assert (
|
||||||
|
custom_llm_provider is None and api_base is None and api_key is None
|
||||||
|
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
|
||||||
|
custom_llm_provider = litellm_params.custom_llm_provider
|
||||||
|
api_base = litellm_params.api_base
|
||||||
|
api_key = litellm_params.api_key
|
||||||
|
|
||||||
dynamic_api_key = None
|
dynamic_api_key = None
|
||||||
# check if llm provider provided
|
# check if llm provider provided
|
||||||
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
|
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
|
||||||
|
@ -6374,7 +6457,8 @@ def get_llm_provider(
|
||||||
api_base
|
api_base
|
||||||
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
||||||
or "https://api.mistral.ai/v1"
|
or "https://api.mistral.ai/v1"
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
# if api_base does not end with /v1 we add it
|
# if api_base does not end with /v1 we add it
|
||||||
if api_base is not None and not api_base.endswith(
|
if api_base is not None and not api_base.endswith(
|
||||||
"/v1"
|
"/v1"
|
||||||
|
@ -6397,10 +6481,30 @@ def get_llm_provider(
|
||||||
or get_secret("TOGETHERAI_API_KEY")
|
or get_secret("TOGETHERAI_API_KEY")
|
||||||
or get_secret("TOGETHER_AI_TOKEN")
|
or get_secret("TOGETHER_AI_TOKEN")
|
||||||
)
|
)
|
||||||
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
|
raise Exception(
|
||||||
|
"api base needs to be a string. api_base={}".format(api_base)
|
||||||
|
)
|
||||||
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
raise Exception(
|
||||||
|
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||||
|
dynamic_api_key
|
||||||
|
)
|
||||||
|
)
|
||||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||||
elif model.split("/", 1)[0] in litellm.provider_list:
|
elif model.split("/", 1)[0] in litellm.provider_list:
|
||||||
custom_llm_provider = model.split("/", 1)[0]
|
custom_llm_provider = model.split("/", 1)[0]
|
||||||
model = model.split("/", 1)[1]
|
model = model.split("/", 1)[1]
|
||||||
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
|
raise Exception(
|
||||||
|
"api base needs to be a string. api_base={}".format(api_base)
|
||||||
|
)
|
||||||
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
raise Exception(
|
||||||
|
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||||
|
dynamic_api_key
|
||||||
|
)
|
||||||
|
)
|
||||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||||
# check if api base is a known openai compatible endpoint
|
# check if api base is a known openai compatible endpoint
|
||||||
if api_base:
|
if api_base:
|
||||||
|
@ -6424,7 +6528,22 @@ def get_llm_provider(
|
||||||
elif endpoint == "api.deepseek.com/v1":
|
elif endpoint == "api.deepseek.com/v1":
|
||||||
custom_llm_provider = "deepseek"
|
custom_llm_provider = "deepseek"
|
||||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
|
||||||
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
|
raise Exception(
|
||||||
|
"api base needs to be a string. api_base={}".format(
|
||||||
|
api_base
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if dynamic_api_key is not None and not isinstance(
|
||||||
|
dynamic_api_key, str
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||||
|
dynamic_api_key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
|
||||||
|
|
||||||
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
||||||
## openai - chatcompletion + text completion
|
## openai - chatcompletion + text completion
|
||||||
|
@ -6515,6 +6634,16 @@ def get_llm_provider(
|
||||||
),
|
),
|
||||||
llm_provider="",
|
llm_provider="",
|
||||||
)
|
)
|
||||||
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
|
raise Exception(
|
||||||
|
"api base needs to be a string. api_base={}".format(api_base)
|
||||||
|
)
|
||||||
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
raise Exception(
|
||||||
|
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||||
|
dynamic_api_key
|
||||||
|
)
|
||||||
|
)
|
||||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, litellm.exceptions.BadRequestError):
|
if isinstance(e, litellm.exceptions.BadRequestError):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue