Merge pull request #2788 from BerriAI/litellm_support_-_models

[Feat] Allow using model = * on proxy config.yaml
This commit is contained in:
Ishaan Jaff 2024-04-01 19:46:50 -07:00 committed by GitHub
commit c2b9799e42
3 changed files with 114 additions and 1 deletions

View file

@ -3100,6 +3100,12 @@ async def completion(
response = await llm_router.atext_completion(
**data, specific_deployment=True
)
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atext_completion(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atext_completion(**data)
else:
@ -3317,6 +3323,12 @@ async def chat_completion(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data, specific_deployment=True))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.acompletion(**data))
else:
@ -3531,6 +3543,12 @@ async def embeddings(
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aembedding(**data, specific_deployment=True)
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aembedding(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aembedding(**data)
else:
@ -3676,6 +3694,12 @@ async def image_generation(
response = await llm_router.aimage_generation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aimage_generation(**data)
else:
@ -3830,6 +3854,12 @@ async def audio_transcriptions(
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
@ -3983,6 +4013,12 @@ async def moderations(
response = await llm_router.amoderation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.amoderation(**data)
else:

View file

@ -191,6 +191,8 @@ class Router:
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
@ -252,7 +254,6 @@ class Router:
}
}
"""
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(
@ -2078,6 +2079,11 @@ class Router:
),
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if model["model_name"] == "*":
self.default_deployment = model
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = model.get("litellm_params", {}).get("dataSources", [])
@ -2260,6 +2266,13 @@ class Router:
)
model = self.model_group_alias[model]
if model not in self.model_names and self.default_deployment is not None:
updated_deployment = copy.deepcopy(
self.default_deployment
) # self.default_deployment
updated_deployment["litellm_params"]["model"] = model
return updated_deployment
## get healthy deployments
### get all deployments
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]

View file

@ -448,3 +448,67 @@ def test_usage_based_routing():
), f"Assertion failed: 'chatgpt-high-tpm' does not have about 80% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_wildcard_openai_routing():
"""
Initialize router with *, all models go through * and use OPENAI_API_KEY
"""
try:
model_list = [
{
"model_name": "*",
"litellm_params": {
"model": "openai/*",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100,
},
]
router = Router(
model_list=model_list,
)
messages = [
{"content": "Tell me a joke.", "role": "user"},
]
selection_counts = defaultdict(int)
for _ in range(25):
response = await router.acompletion(
model="gpt-4",
messages=messages,
mock_response="good morning",
)
# print("response1", response)
selection_counts[response["model"]] += 1
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=messages,
mock_response="good morning",
)
# print("response2", response)
selection_counts[response["model"]] += 1
response = await router.acompletion(
model="gpt-4-turbo-preview",
messages=messages,
mock_response="good morning",
)
# print("response3", response)
# print("response", response)
selection_counts[response["model"]] += 1
assert selection_counts["gpt-4"] == 25
assert selection_counts["gpt-3.5-turbo"] == 25
assert selection_counts["gpt-4-turbo-preview"] == 25
except Exception as e:
pytest.fail(f"Error occurred: {e}")