Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Merge pull request #2788 from BerriAI/litellm_support_-_models
[Feat] Allow using model = * on proxy config.yaml

Commit c2b9799e42
3 changed files with 114 additions and 1 deletion
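For orientation, a minimal proxy config.yaml sketch of the wildcard routing this PR enables; the provider, model names, and the os.environ reference are illustrative assumptions, not taken from this diff:

model_list:
  - model_name: "*"                       # catch-all: any requested model matches this entry
    litellm_params:
      model: openai/*                     # forward the caller's model name to OpenAI
      api_key: os.environ/OPENAI_API_KEY  # resolved from the environment by the proxy

With an entry like this, a request for e.g. "gpt-4-turbo-preview" no longer needs its own model_list block; it falls through to the "*" deployment.
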
@@ -3100,6 +3100,12 @@ async def completion(
             response = await llm_router.atext_completion(
                 **data, specific_deployment=True
             )
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.atext_completion(**data)
         else:

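The same six-line branch is inserted into each proxy endpoint below (chat completions, embeddings, image generation, audio transcriptions, moderations). A condensed, hypothetical sketch of the resulting dispatch order for the text-completion route; the first branch and the final error path are assumed from surrounding proxy code that is not part of this diff:

import litellm

async def dispatch_text_completion(llm_router, user_model, router_model_names, data):
    # Sketch only; not the literal proxy_server.py code.
    if llm_router is not None and data["model"] in router_model_names:
        # model group configured on the router
        return await llm_router.atext_completion(**data)
    elif llm_router is not None and data["model"] in llm_router.deployment_names:
        # caller addressed one specific deployment
        return await llm_router.atext_completion(**data, specific_deployment=True)
    elif (
        llm_router is not None
        and data["model"] not in router_model_names
        and llm_router.default_deployment is not None
    ):
        # new in this PR: unknown model names fall through to the "*" deployment
        return await llm_router.atext_completion(**data)
    elif user_model is not None:
        # proxy was started with `litellm --model <your-model-name>`
        return await litellm.atext_completion(**data)
    raise ValueError(f"model={data['model']} not found")  # assumed error path
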
@@ -3317,6 +3323,12 @@ async def chat_completion(
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             tasks.append(llm_router.acompletion(**data, specific_deployment=True))
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            tasks.append(llm_router.acompletion(**data))
         elif user_model is not None:  # `litellm --model <your-model-name>`
             tasks.append(litellm.acompletion(**data))
         else:

@@ -3531,6 +3543,12 @@ async def embeddings(
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             response = await llm_router.aembedding(**data, specific_deployment=True)
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aembedding(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.aembedding(**data)
         else:

@@ -3676,6 +3694,12 @@ async def image_generation(
             response = await llm_router.aimage_generation(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.aimage_generation(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.aimage_generation(**data)
         else:

@@ -3830,6 +3854,12 @@ async def audio_transcriptions(
             response = await llm_router.atranscription(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atranscription(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.atranscription(**data)
         else:

@@ -3983,6 +4013,12 @@ async def moderations(
             response = await llm_router.amoderation(
                 **data
             )  # ensure this goes the llm_router, router will do the correct alias mapping
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.amoderation(**data)
         elif user_model is not None:  # `litellm --model <your-model-name>`
             response = await litellm.amoderation(**data)
         else:

@@ -191,6 +191,8 @@ class Router:
             redis_cache=redis_cache, in_memory_cache=InMemoryCache()
         )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.

+        self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
+
         if model_list:
             model_list = copy.deepcopy(model_list)
             self.set_model_list(model_list)

@@ -252,7 +254,6 @@ class Router:
                 }
             }
         """
-
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
             self.leastbusy_logger = LeastBusyLoggingHandler(

@@ -2078,6 +2079,11 @@ class Router:
                 ),
             )

+            # Check if user is trying to use model_name == "*"
+            # this is a catch all model for their specific api key
+            if model["model_name"] == "*":
+                self.default_deployment = model
+
             # Azure GPT-Vision Enhancements, users can pass os.environ/
             data_sources = model.get("litellm_params", {}).get("dataSources", [])

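A small, hypothetical repro of the detection loop above, just to make the behaviour explicit: every entry whose model_name is "*" overwrites default_deployment, so if several wildcard entries are configured, the last one in the list wins.

# Hypothetical standalone repro of the "*" detection; the entries are illustrative.
default_deployment = None
model_list = [
    {"model_name": "*", "litellm_params": {"model": "openai/*"}},
    {"model_name": "*", "litellm_params": {"model": "anthropic/*"}},  # last one wins
]
for model in model_list:
    if model["model_name"] == "*":
        default_deployment = model

print(default_deployment["litellm_params"]["model"])  # -> anthropic/*
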
@@ -2260,6 +2266,13 @@ class Router:
             )
             model = self.model_group_alias[model]

+        if model not in self.model_names and self.default_deployment is not None:
+            updated_deployment = copy.deepcopy(
+                self.default_deployment
+            )  # self.default_deployment
+            updated_deployment["litellm_params"]["model"] = model
+            return updated_deployment
+
         ## get healthy deployments
         ### get all deployments
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]

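When the requested model is unknown and a "*" deployment exists, get_available_deployment deep-copies that deployment and overwrites only litellm_params["model"] with the caller's model name; every other parameter (api_key, api_base, etc.) is reused as configured. A hypothetical standalone illustration of that substitution:

import copy

# Illustrative wildcard deployment; the values are placeholders, not from the diff.
default_deployment = {
    "model_name": "*",
    "litellm_params": {"model": "openai/*", "api_key": "sk-placeholder"},
}

requested_model = "gpt-4-turbo-preview"
updated_deployment = copy.deepcopy(default_deployment)
updated_deployment["litellm_params"]["model"] = requested_model

print(updated_deployment["litellm_params"]["model"])    # -> gpt-4-turbo-preview
print(updated_deployment["litellm_params"]["api_key"])  # unchanged: sk-placeholder
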
@@ -448,3 +448,67 @@ def test_usage_based_routing():
         ), f"Assertion failed: 'chatgpt-high-tpm' does not have about 80% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_wildcard_openai_routing():
+    """
+    Initialize router with *, all models go through * and use OPENAI_API_KEY
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100,
+            },
+        ]
+
+        router = Router(
+            model_list=model_list,
+        )
+
+        messages = [
+            {"content": "Tell me a joke.", "role": "user"},
+        ]
+
+        selection_counts = defaultdict(int)
+        for _ in range(25):
+            response = await router.acompletion(
+                model="gpt-4",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response1", response)
+
+            selection_counts[response["model"]] += 1
+
+            response = await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response2", response)
+
+            selection_counts[response["model"]] += 1
+
+            response = await router.acompletion(
+                model="gpt-4-turbo-preview",
+                messages=messages,
+                mock_response="good morning",
+            )
+            # print("response3", response)
+
+            # print("response", response)
+
+            selection_counts[response["model"]] += 1
+
+        assert selection_counts["gpt-4"] == 25
+        assert selection_counts["gpt-3.5-turbo"] == 25
+        assert selection_counts["gpt-4-turbo-preview"] == 25
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

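Because the fallback in get_available_deployment only triggers when the requested model is not in self.model_names, explicit entries still take priority over the "*" catch-all. A hedged usage sketch mixing both (model names and environment variables are illustrative; OPENAI_API_KEY is assumed to be set):

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",  # explicit entry: matched before the wildcard
            "litellm_params": {
                "model": "openai/gpt-4",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "*",  # catch-all added by this PR
            "litellm_params": {
                "model": "openai/*",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ]
)
# Requests for "gpt-4" use the explicit deployment; any other model name
# (e.g. "gpt-3.5-turbo") falls through to the "*" deployment.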