forked from phoenix/litellm-mirror
Merge pull request #3582 from BerriAI/litellm_explicit_region_name_setting
feat(router.py): allow setting model_region in litellm_params
This commit is contained in:
commit
86d0c0ae4e
12 changed files with 405 additions and 95 deletions
|
@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
}'
|
||||
```
|
||||
|
||||
## Advanced - Context Window Fallbacks
|
||||
## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
|
||||
|
||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||
|
||||
|
@ -232,16 +232,16 @@ model_list:
|
|||
- model_name: gpt-3.5-turbo-small
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
model_info:
|
||||
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
|
||||
|
||||
- model_name: gpt-3.5-turbo-large
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: claude-opus
|
||||
litellm_params:
|
||||
|
@ -287,6 +287,69 @@ print(response)
|
|||
</Tabs>
|
||||
|
||||
|
||||
## Advanced - EU-Region Filtering (Pre-Call Checks)
|
||||
|
||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||
|
||||
Set 'region_name' of deployment.
|
||||
|
||||
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||
|
||||
**1. Set Config**
|
||||
|
||||
```yaml
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
region_name: "eu" # 👈 SET EU-REGION
|
||||
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-1106
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: gemini-pro
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-pro-1.5
|
||||
vertex_project: adroit-crow-1234
|
||||
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.with_raw_response.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [{"role": "user", "content": "Who was Alexander?"}]
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
print(f"response.headers.get('x-litellm-model-api-base')")
|
||||
```
|
||||
|
||||
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
|
||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||
```yaml
|
||||
|
|
|
@ -879,13 +879,11 @@ router = Router(model_list: Optional[list] = None,
|
|||
cache_responses=True)
|
||||
```
|
||||
|
||||
## Pre-Call Checks (Context Window)
|
||||
## Pre-Call Checks (Context Window, EU-Regions)
|
||||
|
||||
Enable pre-call checks to filter out:
|
||||
1. deployments with context window limit < messages for a call.
|
||||
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
|
||||
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
|
||||
])`)
|
||||
2. deployments outside of eu-region
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
@ -900,10 +898,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
|
|||
|
||||
**2. Set Model List**
|
||||
|
||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||
For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="same-group" label="Same Group">
|
||||
For 'eu-region' filtering, Set 'region_name' of deployment.
|
||||
|
||||
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
|
@ -914,10 +916,9 @@ model_list = [
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"region_name": "eu" # 👈 SET 'EU' REGION NAME
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
|
@ -926,54 +927,26 @@ model_list = [
|
|||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-pro",
|
||||
"litellm_params: {
|
||||
"model": "vertex_ai/gemini-pro-1.5",
|
||||
"vertex_project": "adroit-crow-1234",
|
||||
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-small", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-large", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-opus",
|
||||
"litellm_params": { call
|
||||
"model": "claude-3-opus-20240229",
|
||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="context-window-check" label="Context Window Check">
|
||||
|
||||
```python
|
||||
"""
|
||||
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
|
||||
|
@ -983,7 +956,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
|
|||
from litellm import Router
|
||||
import os
|
||||
|
||||
try:
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
|
@ -992,6 +964,7 @@ model_list = [
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"base_model": "azure/gpt-35-turbo",
|
||||
},
|
||||
"model_info": {
|
||||
"base_model": "azure/gpt-35-turbo",
|
||||
|
@ -1021,6 +994,59 @@ response = router.completion(
|
|||
print(f"response: {response}")
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="eu-region-check" label="EU Region Check">
|
||||
|
||||
```python
|
||||
"""
|
||||
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
|
||||
- Make a call
|
||||
- Assert it picks the eu-region model
|
||||
"""
|
||||
|
||||
from litellm import Router
|
||||
import os
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"region_name": "eu"
|
||||
},
|
||||
"model_info": {
|
||||
"id": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # model group name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"model_info": {
|
||||
"id": "2"
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
|
||||
response = router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Who was Alexander?"}],
|
||||
)
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
print(f"response id: {response._hidden_params['model_id']}")
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
:::info
|
||||
|
|
|
@ -102,6 +102,9 @@ blocked_user_list: Optional[Union[str, List]] = None
|
|||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||
##################
|
||||
### PREVIEW FEATURES ###
|
||||
enable_preview_features: bool = False
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
|
|
|
@ -10,7 +10,7 @@ from litellm.utils import (
|
|||
TranscriptionResponse,
|
||||
get_secret,
|
||||
)
|
||||
from typing import Callable, Optional, BinaryIO
|
||||
from typing import Callable, Optional, BinaryIO, List
|
||||
from litellm import OpenAIConfig
|
||||
import litellm, json
|
||||
import httpx # type: ignore
|
||||
|
@ -107,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
|
|||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
# azure_client_params = {
|
||||
|
|
|
@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
"""
|
||||
return [
|
||||
"eu-west-1",
|
||||
"eu-west-3",
|
||||
"eu-central-1",
|
||||
]
|
||||
|
||||
|
||||
class AmazonTitanConfig:
|
||||
"""
|
||||
|
|
|
@ -198,6 +198,23 @@ class VertexAIConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
|
||||
"""
|
||||
return [
|
||||
"europe-central2",
|
||||
"europe-north1",
|
||||
"europe-southwest1",
|
||||
"europe-west1",
|
||||
"europe-west2",
|
||||
"europe-west3",
|
||||
"europe-west4",
|
||||
"europe-west6",
|
||||
"europe-west8",
|
||||
"europe-west9",
|
||||
]
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
|
|
|
@ -163,6 +163,15 @@ class IBMWatsonXAIConfig:
|
|||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
|
||||
"""
|
||||
return [
|
||||
"eu-de",
|
||||
"eu-gb",
|
||||
]
|
||||
|
||||
|
||||
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
|
|
|
@ -1,25 +1,13 @@
|
|||
model_list:
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key
|
||||
model: openai/my-fake-model
|
||||
model_name: fake-openai-endpoint
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key-2
|
||||
model: openai/my-fake-model-2
|
||||
model_name: fake-openai-endpoint
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key-3
|
||||
model: openai/my-fake-model-3
|
||||
model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
- litellm_params:
|
||||
model: together_ai/codellama/CodeLlama-13b-Instruct-hf
|
||||
model_name: CodeLlama-13b-Instruct
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
model: azure/azure-embedding-model
|
||||
model_info:
|
||||
base_model: text-embedding-ada-002
|
||||
mode: embedding
|
||||
model_name: text-embedding-ada-002
|
||||
|
||||
router_settings:
|
||||
redis_host: redis
|
||||
|
@ -28,6 +16,7 @@ router_settings:
|
|||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
enable_preview_features: true
|
||||
# service_callback: ["prometheus_system"]
|
||||
# success_callback: ["prometheus"]
|
||||
# failure_callback: ["prometheus"]
|
||||
|
|
|
@ -2339,7 +2339,7 @@ class Router:
|
|||
) # cache for 1 hr
|
||||
|
||||
else:
|
||||
_api_key = api_key
|
||||
_api_key = api_key # type: ignore
|
||||
if _api_key is not None and isinstance(_api_key, str):
|
||||
# only show first 5 chars of api_key
|
||||
_api_key = _api_key[:8] + "*" * 15
|
||||
|
@ -2567,23 +2567,25 @@ class Router:
|
|||
# init OpenAI, Azure clients
|
||||
self.set_client(model=deployment.to_json(exclude_none=True))
|
||||
|
||||
# set region (if azure model)
|
||||
_auto_infer_region = os.environ.get("AUTO_INFER_REGION", False)
|
||||
if _auto_infer_region == True or _auto_infer_region == "True":
|
||||
# set region (if azure model) ## PREVIEW FEATURE ##
|
||||
if litellm.enable_preview_features == True:
|
||||
print("Auto inferring region") # noqa
|
||||
"""
|
||||
Hiding behind a feature flag
|
||||
When there is a large amount of LLM deployments this makes startup times blow up
|
||||
"""
|
||||
try:
|
||||
if "azure" in deployment.litellm_params.model:
|
||||
if (
|
||||
"azure" in deployment.litellm_params.model
|
||||
and deployment.litellm_params.region_name is None
|
||||
):
|
||||
region = litellm.utils.get_model_region(
|
||||
litellm_params=deployment.litellm_params, mode=None
|
||||
)
|
||||
|
||||
deployment.litellm_params.region_name = region
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
verbose_router_logger.debug(
|
||||
"Unable to get the region for azure model - {}, {}".format(
|
||||
deployment.litellm_params.model, str(e)
|
||||
)
|
||||
|
@ -2961,7 +2963,7 @@ class Router:
|
|||
):
|
||||
# check if in allowed_model_region
|
||||
if (
|
||||
_is_region_eu(model_region=_litellm_params["region_name"])
|
||||
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
|
||||
== False
|
||||
):
|
||||
invalid_model_indices.append(idx)
|
||||
|
|
|
@ -687,6 +687,55 @@ def test_router_context_window_check_pre_call_check_out_group():
|
|||
pytest.fail(f"Got unexpected exception on router! - {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allowed_model_region", ["eu", None])
|
||||
def test_router_region_pre_call_check(allowed_model_region):
|
||||
"""
|
||||
If region based routing set
|
||||
- check if only model in allowed region is allowed by '_pre_call_checks'
|
||||
"""
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"base_model": "azure/gpt-35-turbo",
|
||||
"region_name": "eu",
|
||||
},
|
||||
"model_info": {"id": "1"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-large", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"model_info": {"id": "2"},
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||
|
||||
_healthy_deployments = router._pre_call_checks(
|
||||
model="gpt-3.5-turbo",
|
||||
healthy_deployments=model_list,
|
||||
messages=[{"role": "user", "content": "Hey!"}],
|
||||
allowed_model_region=allowed_model_region,
|
||||
)
|
||||
|
||||
if allowed_model_region is None:
|
||||
assert len(_healthy_deployments) == 2
|
||||
else:
|
||||
assert len(_healthy_deployments) == 1, "No models selected as healthy"
|
||||
assert (
|
||||
_healthy_deployments[0]["model_info"]["id"] == "1"
|
||||
), "Incorrect model id picked. Got id={}, expected id=1".format(
|
||||
_healthy_deployments[0]["model_info"]["id"]
|
||||
)
|
||||
|
||||
|
||||
### FUNCTION CALLING
|
||||
|
||||
|
||||
|
|
|
@ -132,6 +132,8 @@ class GenericLiteLLMParams(BaseModel):
|
|||
aws_access_key_id: Optional[str] = None
|
||||
aws_secret_access_key: Optional[str] = None
|
||||
aws_region_name: Optional[str] = None
|
||||
## IBM WATSONX ##
|
||||
watsonx_region_name: Optional[str] = None
|
||||
## CUSTOM PRICING ##
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
@ -161,6 +163,8 @@ class GenericLiteLLMParams(BaseModel):
|
|||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
## IBM WATSONX ##
|
||||
watsonx_region_name: Optional[str] = None,
|
||||
input_cost_per_token: Optional[float] = None,
|
||||
output_cost_per_token: Optional[float] = None,
|
||||
input_cost_per_second: Optional[float] = None,
|
||||
|
|
148
litellm/utils.py
148
litellm/utils.py
|
@ -110,7 +110,18 @@ try:
|
|||
except Exception as e:
|
||||
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
|
||||
|
||||
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
|
||||
from typing import (
|
||||
cast,
|
||||
List,
|
||||
Dict,
|
||||
Union,
|
||||
Optional,
|
||||
Literal,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Iterable,
|
||||
Tuple,
|
||||
)
|
||||
from .caching import Cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
@ -5885,10 +5896,70 @@ def calculate_max_parallel_requests(
|
|||
return None
|
||||
|
||||
|
||||
def _is_region_eu(model_region: str) -> bool:
|
||||
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
for region in EU_Regions:
|
||||
if "europe" in model_region.lower():
|
||||
def _get_model_region(
|
||||
custom_llm_provider: str, litellm_params: LiteLLM_Params
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return the region for a model, for a given provider
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
# check 'vertex_location'
|
||||
vertex_ai_location = (
|
||||
litellm_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret("VERTEX_LOCATION")
|
||||
)
|
||||
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
|
||||
return vertex_ai_location
|
||||
elif custom_llm_provider == "bedrock":
|
||||
aws_region_name = litellm_params.aws_region_name
|
||||
if aws_region_name is not None:
|
||||
return aws_region_name
|
||||
elif custom_llm_provider == "watsonx":
|
||||
watsonx_region_name = litellm_params.watsonx_region_name
|
||||
if watsonx_region_name is not None:
|
||||
return watsonx_region_name
|
||||
return litellm_params.region_name
|
||||
|
||||
|
||||
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
|
||||
"""
|
||||
Return true/false if a deployment is in the EU
|
||||
"""
|
||||
if litellm_params.region_name == "eu":
|
||||
return True
|
||||
|
||||
## ELSE ##
|
||||
"""
|
||||
- get provider
|
||||
- get provider regions
|
||||
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
|
||||
"""
|
||||
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=litellm_params.model, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
model_region = _get_model_region(
|
||||
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if model_region is None:
|
||||
return False
|
||||
|
||||
if custom_llm_provider == "azure":
|
||||
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
eu_regions = litellm.VertexAIConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "bedrock":
|
||||
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "watsonx":
|
||||
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
|
||||
else:
|
||||
return False
|
||||
|
||||
for region in eu_regions:
|
||||
if region in model_region.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -6314,8 +6385,23 @@ def get_llm_provider(
|
|||
custom_llm_provider: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
litellm_params: Optional[LiteLLM_Params] = None,
|
||||
) -> Tuple[str, str, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
|
||||
|
||||
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
|
||||
"""
|
||||
try:
|
||||
## IF LITELLM PARAMS GIVEN ##
|
||||
if litellm_params is not None:
|
||||
assert (
|
||||
custom_llm_provider is None and api_base is None and api_key is None
|
||||
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
|
||||
custom_llm_provider = litellm_params.custom_llm_provider
|
||||
api_base = litellm_params.api_base
|
||||
api_key = litellm_params.api_key
|
||||
|
||||
dynamic_api_key = None
|
||||
# check if llm provider provided
|
||||
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
|
||||
|
@ -6376,7 +6462,8 @@ def get_llm_provider(
|
|||
api_base
|
||||
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
# if api_base does not end with /v1 we add it
|
||||
if api_base is not None and not api_base.endswith(
|
||||
"/v1"
|
||||
|
@ -6399,10 +6486,30 @@ def get_llm_provider(
|
|||
or get_secret("TOGETHERAI_API_KEY")
|
||||
or get_secret("TOGETHER_AI_TOKEN")
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
elif model.split("/", 1)[0] in litellm.provider_list:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
# check if api base is a known openai compatible endpoint
|
||||
if api_base:
|
||||
|
@ -6426,7 +6533,22 @@ def get_llm_provider(
|
|||
elif endpoint == "api.deepseek.com/v1":
|
||||
custom_llm_provider = "deepseek"
|
||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(
|
||||
api_base
|
||||
)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(
|
||||
dynamic_api_key, str
|
||||
):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
|
||||
|
||||
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
||||
## openai - chatcompletion + text completion
|
||||
|
@ -6517,6 +6639,16 @@ def get_llm_provider(
|
|||
),
|
||||
llm_provider="",
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
except Exception as e:
|
||||
if isinstance(e, litellm.exceptions.BadRequestError):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue