Merge pull request #3582 from BerriAI/litellm_explicit_region_name_setting

feat(router.py): allow setting model_region in litellm_params
This commit is contained in:
Krish Dholakia 2024-05-11 11:36:22 -07:00 committed by GitHub
commit 86d0c0ae4e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 405 additions and 95 deletions

View file

@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}' }'
``` ```
## Advanced - Context Window Fallbacks ## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**. **Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
@ -232,16 +232,16 @@ model_list:
- model_name: gpt-3.5-turbo-small - model_name: gpt-3.5-turbo-small
litellm_params: litellm_params:
model: azure/chatgpt-v-2 model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview" api_version: "2023-07-01-preview"
model_info: model_info:
base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL base_model: azure/gpt-4-1106-preview # 2. 👈 (azure-only) SET BASE MODEL
- model_name: gpt-3.5-turbo-large - model_name: gpt-3.5-turbo-large
litellm_params: litellm_params:
model: gpt-3.5-turbo-1106 model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
- model_name: claude-opus - model_name: claude-opus
litellm_params: litellm_params:
@ -287,6 +287,69 @@ print(response)
</Tabs> </Tabs>
## Advanced - EU-Region Filtering (Pre-Call Checks)
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
Set 'region_name' of deployment.
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
**1. Set Config**
```yaml
router_settings:
enable_pre_call_checks: true # 1. Enable pre-call checks
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-07-01-preview"
region_name: "eu" # 👈 SET EU-REGION
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
- model_name: gemini-pro
litellm_params:
model: vertex_ai/gemini-pro-1.5
vertex_project: adroit-crow-1234
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
**3. Test it!**
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.with_raw_response.create(
model="gpt-3.5-turbo",
messages = [{"role": "user", "content": "Who was Alexander?"}]
)
print(response)
print(f"response.headers.get('x-litellm-model-api-base')")
```
## Advanced - Custom Timeouts, Stream Timeouts - Per Model ## Advanced - Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params` For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml ```yaml

View file

@ -879,13 +879,11 @@ router = Router(model_list: Optional[list] = None,
cache_responses=True) cache_responses=True)
``` ```
## Pre-Call Checks (Context Window) ## Pre-Call Checks (Context Window, EU-Regions)
Enable pre-call checks to filter out: Enable pre-call checks to filter out:
1. deployments with context window limit < messages for a call. 1. deployments with context window limit < messages for a call.
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[ 2. deployments outside of eu-region
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
])`)
<Tabs> <Tabs>
<TabItem value="sdk" label="SDK"> <TabItem value="sdk" label="SDK">
@ -900,10 +898,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
**2. Set Model List** **2. Set Model List**
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`. For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
<Tabs> For 'eu-region' filtering, Set 'region_name' of deployment.
<TabItem value="same-group" label="Same Group">
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
```python ```python
model_list = [ model_list = [
@ -914,10 +916,9 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
}, "region_name": "eu" # 👈 SET 'EU' REGION NAME
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL "base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
} },
}, },
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -926,54 +927,26 @@ model_list = [
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
}, },
{
"model_name": "gemini-pro",
"litellm_params: {
"model": "vertex_ai/gemini-pro-1.5",
"vertex_project": "adroit-crow-1234",
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
}
}
] ]
router = Router(model_list=model_list, enable_pre_call_checks=True) router = Router(model_list=model_list, enable_pre_call_checks=True)
``` ```
</TabItem>
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
```python
model_list = [
{
"model_name": "gpt-3.5-turbo-small", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"model_info": {
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
}
},
{
"model_name": "gpt-3.5-turbo-large", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-opus",
"litellm_params": { call
"model": "claude-3-opus-20240229",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
```
</TabItem>
</Tabs>
**3. Test it!** **3. Test it!**
<Tabs>
<TabItem value="context-window-check" label="Context Window Check">
```python ```python
""" """
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k) - Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
@ -983,7 +956,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
from litellm import Router from litellm import Router
import os import os
try:
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # model group name "model_name": "gpt-3.5-turbo", # model group name
@ -992,6 +964,7 @@ model_list = [
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
}, },
"model_info": { "model_info": {
"base_model": "azure/gpt-35-turbo", "base_model": "azure/gpt-35-turbo",
@ -1021,6 +994,59 @@ response = router.completion(
print(f"response: {response}") print(f"response: {response}")
``` ```
</TabItem> </TabItem>
<TabItem value="eu-region-check" label="EU Region Check">
```python
"""
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
- Make a call
- Assert it picks the eu-region model
"""
from litellm import Router
import os
model_list = [
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"region_name": "eu"
},
"model_info": {
"id": "1"
}
},
{
"model_name": "gpt-3.5-turbo", # model group name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {
"id": "2"
}
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Who was Alexander?"}],
)
print(f"response: {response}")
print(f"response id: {response._hidden_params['model_id']}")
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="proxy" label="Proxy"> <TabItem value="proxy" label="Proxy">
:::info :::info

View file

@ -102,6 +102,9 @@ blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
################## ##################
### PREVIEW FEATURES ###
enable_preview_features: bool = False
##################
logging: bool = True logging: bool = True
caching: bool = ( caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648

View file

@ -10,7 +10,7 @@ from litellm.utils import (
TranscriptionResponse, TranscriptionResponse,
get_secret, get_secret,
) )
from typing import Callable, Optional, BinaryIO from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx # type: ignore import httpx # type: ignore
@ -107,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
optional_params["azure_ad_token"] = value optional_params["azure_ad_token"] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
"""
return ["europe", "sweden", "switzerland", "france", "uk"]
def select_azure_base_url_or_endpoint(azure_client_params: dict): def select_azure_base_url_or_endpoint(azure_client_params: dict):
# azure_client_params = { # azure_client_params = {

View file

@ -52,6 +52,16 @@ class AmazonBedrockGlobalConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.aws-services.info/bedrock.html
"""
return [
"eu-west-1",
"eu-west-3",
"eu-central-1",
]
class AmazonTitanConfig: class AmazonTitanConfig:
""" """

View file

@ -198,6 +198,23 @@ class VertexAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
"""
return [
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
]
import asyncio import asyncio

View file

@ -163,6 +163,15 @@ class IBMWatsonXAIConfig:
optional_params[mapped_params[param]] = value optional_params[mapped_params[param]] = value
return optional_params return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
"""
return [
"eu-de",
"eu-gb",
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts

View file

@ -1,25 +1,13 @@
model_list: model_list:
- litellm_params: - litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ api_base: os.environ/AZURE_API_BASE
api_key: my-fake-key api_key: os.environ/AZURE_API_KEY
model: openai/my-fake-model api_version: 2023-07-01-preview
model_name: fake-openai-endpoint model: azure/azure-embedding-model
- litellm_params: model_info:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ base_model: text-embedding-ada-002
api_key: my-fake-key-2 mode: embedding
model: openai/my-fake-model-2 model_name: text-embedding-ada-002
model_name: fake-openai-endpoint
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key-3
model: openai/my-fake-model-3
model_name: fake-openai-endpoint
- model_name: gpt-4
litellm_params:
model: gpt-3.5-turbo
- litellm_params:
model: together_ai/codellama/CodeLlama-13b-Instruct-hf
model_name: CodeLlama-13b-Instruct
router_settings: router_settings:
redis_host: redis redis_host: redis
@ -28,6 +16,7 @@ router_settings:
litellm_settings: litellm_settings:
set_verbose: True set_verbose: True
enable_preview_features: true
# service_callback: ["prometheus_system"] # service_callback: ["prometheus_system"]
# success_callback: ["prometheus"] # success_callback: ["prometheus"]
# failure_callback: ["prometheus"] # failure_callback: ["prometheus"]

View file

@ -2339,7 +2339,7 @@ class Router:
) # cache for 1 hr ) # cache for 1 hr
else: else:
_api_key = api_key _api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str): if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key # only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15 _api_key = _api_key[:8] + "*" * 15
@ -2567,23 +2567,25 @@ class Router:
# init OpenAI, Azure clients # init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True)) self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model) # set region (if azure model) ## PREVIEW FEATURE ##
_auto_infer_region = os.environ.get("AUTO_INFER_REGION", False) if litellm.enable_preview_features == True:
if _auto_infer_region == True or _auto_infer_region == "True":
print("Auto inferring region") # noqa print("Auto inferring region") # noqa
""" """
Hiding behind a feature flag Hiding behind a feature flag
When there is a large amount of LLM deployments this makes startup times blow up When there is a large amount of LLM deployments this makes startup times blow up
""" """
try: try:
if "azure" in deployment.litellm_params.model: if (
"azure" in deployment.litellm_params.model
and deployment.litellm_params.region_name is None
):
region = litellm.utils.get_model_region( region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None litellm_params=deployment.litellm_params, mode=None
) )
deployment.litellm_params.region_name = region deployment.litellm_params.region_name = region
except Exception as e: except Exception as e:
verbose_router_logger.error( verbose_router_logger.debug(
"Unable to get the region for azure model - {}, {}".format( "Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e) deployment.litellm_params.model, str(e)
) )
@ -2961,7 +2963,7 @@ class Router:
): ):
# check if in allowed_model_region # check if in allowed_model_region
if ( if (
_is_region_eu(model_region=_litellm_params["region_name"]) _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False == False
): ):
invalid_model_indices.append(idx) invalid_model_indices.append(idx)

View file

@ -687,6 +687,55 @@ def test_router_context_window_check_pre_call_check_out_group():
pytest.fail(f"Got unexpected exception on router! - {str(e)}") pytest.fail(f"Got unexpected exception on router! - {str(e)}")
@pytest.mark.parametrize("allowed_model_region", ["eu", None])
def test_router_region_pre_call_check(allowed_model_region):
"""
If region based routing set
- check if only model in allowed region is allowed by '_pre_call_checks'
"""
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
"region_name": "eu",
},
"model_info": {"id": "1"},
},
{
"model_name": "gpt-3.5-turbo-large", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"model_info": {"id": "2"},
},
]
router = Router(model_list=model_list, enable_pre_call_checks=True)
_healthy_deployments = router._pre_call_checks(
model="gpt-3.5-turbo",
healthy_deployments=model_list,
messages=[{"role": "user", "content": "Hey!"}],
allowed_model_region=allowed_model_region,
)
if allowed_model_region is None:
assert len(_healthy_deployments) == 2
else:
assert len(_healthy_deployments) == 1, "No models selected as healthy"
assert (
_healthy_deployments[0]["model_info"]["id"] == "1"
), "Incorrect model id picked. Got id={}, expected id=1".format(
_healthy_deployments[0]["model_info"]["id"]
)
### FUNCTION CALLING ### FUNCTION CALLING

View file

@ -132,6 +132,8 @@ class GenericLiteLLMParams(BaseModel):
aws_access_key_id: Optional[str] = None aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None aws_secret_access_key: Optional[str] = None
aws_region_name: Optional[str] = None aws_region_name: Optional[str] = None
## IBM WATSONX ##
watsonx_region_name: Optional[str] = None
## CUSTOM PRICING ## ## CUSTOM PRICING ##
input_cost_per_token: Optional[float] = None input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None output_cost_per_token: Optional[float] = None
@ -161,6 +163,8 @@ class GenericLiteLLMParams(BaseModel):
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None, aws_region_name: Optional[str] = None,
## IBM WATSONX ##
watsonx_region_name: Optional[str] = None,
input_cost_per_token: Optional[float] = None, input_cost_per_token: Optional[float] = None,
output_cost_per_token: Optional[float] = None, output_cost_per_token: Optional[float] = None,
input_cost_per_second: Optional[float] = None, input_cost_per_second: Optional[float] = None,

View file

@ -110,7 +110,18 @@ try:
except Exception as e: except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}") verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable from typing import (
cast,
List,
Dict,
Union,
Optional,
Literal,
Any,
BinaryIO,
Iterable,
Tuple,
)
from .caching import Cache from .caching import Cache
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -5885,10 +5896,70 @@ def calculate_max_parallel_requests(
return None return None
def _is_region_eu(model_region: str) -> bool: def _get_model_region(
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"] custom_llm_provider: str, litellm_params: LiteLLM_Params
for region in EU_Regions: ) -> Optional[str]:
if "europe" in model_region.lower(): """
Return the region for a model, for a given provider
"""
if custom_llm_provider == "vertex_ai":
# check 'vertex_location'
vertex_ai_location = (
litellm_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
)
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
return vertex_ai_location
elif custom_llm_provider == "bedrock":
aws_region_name = litellm_params.aws_region_name
if aws_region_name is not None:
return aws_region_name
elif custom_llm_provider == "watsonx":
watsonx_region_name = litellm_params.watsonx_region_name
if watsonx_region_name is not None:
return watsonx_region_name
return litellm_params.region_name
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
"""
Return true/false if a deployment is in the EU
"""
if litellm_params.region_name == "eu":
return True
## ELSE ##
"""
- get provider
- get provider regions
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
"""
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model, litellm_params=litellm_params
)
model_region = _get_model_region(
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
)
if model_region is None:
return False
if custom_llm_provider == "azure":
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
elif custom_llm_provider == "vertex_ai":
eu_regions = litellm.VertexAIConfig().get_eu_regions()
elif custom_llm_provider == "bedrock":
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
elif custom_llm_provider == "watsonx":
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
else:
return False
for region in eu_regions:
if region in model_region.lower():
return True return True
return False return False
@ -6314,8 +6385,23 @@ def get_llm_provider(
custom_llm_provider: Optional[str] = None, custom_llm_provider: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
): litellm_params: Optional[LiteLLM_Params] = None,
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
"""
try: try:
## IF LITELLM PARAMS GIVEN ##
if litellm_params is not None:
assert (
custom_llm_provider is None and api_base is None and api_key is None
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
custom_llm_provider = litellm_params.custom_llm_provider
api_base = litellm_params.api_base
api_key = litellm_params.api_key
dynamic_api_key = None dynamic_api_key = None
# check if llm provider provided # check if llm provider provided
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
@ -6376,7 +6462,8 @@ def get_llm_provider(
api_base api_base
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
or "https://api.mistral.ai/v1" or "https://api.mistral.ai/v1"
) ) # type: ignore
# if api_base does not end with /v1 we add it # if api_base does not end with /v1 we add it
if api_base is not None and not api_base.endswith( if api_base is not None and not api_base.endswith(
"/v1" "/v1"
@ -6399,10 +6486,30 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY") or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN") or get_secret("TOGETHER_AI_TOKEN")
) )
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
elif model.split("/", 1)[0] in litellm.provider_list: elif model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0] custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1] model = model.split("/", 1)[1]
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
# check if api base is a known openai compatible endpoint # check if api base is a known openai compatible endpoint
if api_base: if api_base:
@ -6426,7 +6533,22 @@ def get_llm_provider(
elif endpoint == "api.deepseek.com/v1": elif endpoint == "api.deepseek.com/v1":
custom_llm_provider = "deepseek" custom_llm_provider = "deepseek"
dynamic_api_key = get_secret("DEEPSEEK_API_KEY") dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
return model, custom_llm_provider, dynamic_api_key, api_base
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(
api_base
)
)
if dynamic_api_key is not None and not isinstance(
dynamic_api_key, str
):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
## openai - chatcompletion + text completion ## openai - chatcompletion + text completion
@ -6517,6 +6639,16 @@ def get_llm_provider(
), ),
llm_provider="", llm_provider="",
) )
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
except Exception as e: except Exception as e:
if isinstance(e, litellm.exceptions.BadRequestError): if isinstance(e, litellm.exceptions.BadRequestError):