mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(router.py): fix set_client init to check if custom_llm_provider is azure not if in model name
Fixes an issue where 'azure_ai/' deployments were being initialized as an AzureOpenAI client instead of an OpenAI-compatible client.
This commit is contained in:
parent
c0540e764d
commit
06b297a6e8
2 changed files with 27 additions and 2 deletions
|
@ -3090,7 +3090,7 @@ class Router:
|
||||||
if not, add it - https://github.com/BerriAI/litellm/issues/2279
|
if not, add it - https://github.com/BerriAI/litellm/issues/2279
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
is_azure_ai_studio_model == True
|
is_azure_ai_studio_model is True
|
||||||
and api_base is not None
|
and api_base is not None
|
||||||
and isinstance(api_base, str)
|
and isinstance(api_base, str)
|
||||||
and not api_base.endswith("/v1/")
|
and not api_base.endswith("/v1/")
|
||||||
|
@ -3174,7 +3174,7 @@ class Router:
|
||||||
organization = litellm.get_secret(organization_env_name)
|
organization = litellm.get_secret(organization_env_name)
|
||||||
litellm_params["organization"] = organization
|
litellm_params["organization"] = organization
|
||||||
|
|
||||||
if "azure" in model_name:
|
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
|
||||||
if api_base is None or not isinstance(api_base, str):
|
if api_base is None or not isinstance(api_base, str):
|
||||||
filtered_litellm_params = {
|
filtered_litellm_params = {
|
||||||
k: v
|
k: v
|
||||||
|
|
|
@ -70,6 +70,31 @@ def test_router_specific_model_via_id():
|
||||||
router.completion(model="1234", messages=[{"role": "user", "content": "Hey!"}])
|
router.completion(model="1234", messages=[{"role": "user", "content": "Hey!"}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_azure_ai_client_init():
    """Regression test: 'azure_ai/' models must get an OpenAI-compatible client.

    The router previously matched on ``"azure" in model_name`` when choosing a
    client class, so ``azure_ai/`` deployments were wrongly initialized as
    AzureOpenAI clients. After the fix the provider check is explicit
    (``custom_llm_provider == "azure"`` / ``"azure_text"``), so an ``azure_ai/``
    deployment must yield a plain ``AsyncOpenAI`` client.
    """
    _deployment = {
        "model_name": "meta-llama-3-70b",
        "litellm_params": {
            "model": "azure_ai/Meta-Llama-3-70B-instruct",
            "api_base": "my-fake-route",
            "api_key": "my-fake-key",
        },
        "model_info": {"id": "1234"},
    }
    router = Router(model_list=[_deployment])

    # Fetch the async client the router built for this deployment.
    _client = router._get_client(
        deployment=_deployment,
        client_type="async",
        kwargs={"stream": False},
    )
    print(_client)
    from openai import AsyncAzureOpenAI, AsyncOpenAI

    # NOTE: AsyncAzureOpenAI subclasses AsyncOpenAI in the openai SDK, so the
    # second assertion is the one that actually catches the regression.
    assert isinstance(_client, AsyncOpenAI)
    assert not isinstance(_client, AsyncAzureOpenAI)
|
|
||||||
def test_router_sensitive_keys():
|
def test_router_sensitive_keys():
|
||||||
try:
|
try:
|
||||||
router = Router(
|
router = Router(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue