router return get_deployment_by_model_group_name

This commit is contained in:
Ishaan Jaff 2024-07-15 19:27:12 -07:00
parent 865469e43f
commit e65daef572

View file

@ -3684,6 +3684,24 @@ class Router:
raise Exception("Model invalid format - {}".format(type(model)))
return None
def get_deployment_by_model_group_name(
self, model_group_name: str
) -> Optional[Deployment]:
"""
Returns -> Deployment or None
Raise Exception -> if model found in invalid format
"""
for model in self.model_list:
if model["model_name"] == model_group_name:
if isinstance(model, dict):
return Deployment(**model)
elif isinstance(model, Deployment):
return model
else:
raise Exception("Model Name invalid - {}".format(type(model)))
return None
def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
"""
For a given model id, return the model info (max tokens, input cost, output cost, etc.).