feat(proxy_server.py): add model router to proxy

Author: Krrish Dholakia
Date:   2023-10-18 17:40:01 -07:00
parent 8bb5637c38
commit 0c083e7a5c
3 changed files with 46 additions and 7 deletions

litellm/router.py

@@ -21,11 +21,13 @@ class Router:
         router = Router(model_list=model_list)
     """
     def __init__(self,
-                 model_list: list,
+                 model_list: Optional[list] = None,
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None) -> None:
-        self.model_list = model_list
+        if model_list:
+            self.model_list = model_list
+            self.model_names = [m["model_name"] for m in model_list]
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                 'type': 'redis',
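
To make the new constructor concrete, here is a minimal usage sketch. It is an illustration, not part of the commit: the deployment name, API key, and endpoint are placeholders, while the model_name and litellm_params keys match what __init__ and completion() read in this diff.

from litellm import Router

# Each entry maps a public alias (model_name) to the kwargs that get
# forwarded to litellm (litellm_params). All values here are placeholders.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "my-azure-api-key",
            "api_base": "https://my-endpoint.openai.azure.com",
        },
    }
]

# model_list is now optional, so Router() with no arguments is valid.
# The redis settings only take effect when host, port, and password are
# all provided, matching the all-three check in __init__ above.
router = Router(
    model_list=model_list,
    redis_host="localhost",
    redis_port=6379,
    redis_password="my-redis-password",
)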
@@ -60,6 +62,23 @@ class Router:
         data["messages"] = messages
         # call via litellm.completion()
         return litellm.completion(**data)
+
+    def text_completion(self,
+                        model: str,
+                        prompt: str,
+                        is_retry: Optional[bool] = False,
+                        is_fallback: Optional[bool] = False,
+                        is_async: Optional[bool] = False,
+                        **kwargs):
+
+        messages = [{"role": "user", "content": prompt}]
+        # pick the deployment that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+
+        data = deployment["litellm_params"]
+        data["prompt"] = prompt
+        # call via litellm.text_completion()
+        return litellm.text_completion(**data)
 
     def embedding(self,
                   model: str,
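
The new text_completion method mirrors completion() but forwards a bare prompt instead of a message list; note that the is_retry, is_fallback, and is_async flags are accepted but not yet consumed in the body added here. A hypothetical call, assuming the router constructed in the sketch above:

# Routes the prompt to whichever deployment of the alias is available
# (lowest TPM/RPM) and calls litellm.text_completion() on it.
response = router.text_completion(
    model="gpt-3.5-turbo",
    prompt="What is the capital of France?",
)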
@@ -74,6 +93,12 @@ class Router:
         # call via litellm.embedding()
         return litellm.embedding(**data)
 
+    def set_model_list(self, model_list: list):
+        self.model_list = model_list
+
+    def get_model_names(self):
+        return self.model_names
+
     def deployment_callback(
         self,
         kwargs,  # kwargs to completion
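
One caveat in the last hunk: set_model_list replaces model_list without refreshing model_names, and a Router built without a model_list never defines either attribute, so get_model_names can raise AttributeError. A defensive variant (an illustration of the issue, not what the commit does) could look like:

from typing import Optional

class Router:
    def __init__(self, model_list: Optional[list] = None) -> None:
        # default to an empty list so a bare Router() stays usable
        self.model_list = model_list or []
        self.model_names = [m["model_name"] for m in self.model_list]

    def set_model_list(self, model_list: list):
        self.model_list = model_list
        # keep model_names in sync with the new list
        self.model_names = [m["model_name"] for m in model_list]

    def get_model_names(self):
        return self.model_names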