Merge pull request #706 from mc-marcocheng/router/fix_dict

Fix Router.set_model_list & Avoid overwriting litellm_params
This commit is contained in:
Krish Dholakia 2023-10-29 21:11:26 -07:00 committed by GitHub
commit 6e454a6ce9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,16 +1,17 @@
from typing import Union, List, Dict, Optional
from datetime import datetime
from typing import Dict, List, Optional, Union
import litellm
class Router:
class Router:
"""
Example usage:
from litellm import Router
model_list = [{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/<your-deployment-name>",
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/<your-deployment-name>",
"api_key": <your-api-key>,
"api_version": <your-api-version>,
"api_base": <your-api-base>
@ -23,16 +24,15 @@ class Router:
"""
model_names: List = []
cache_responses: bool = False
def __init__(self,
model_list: Optional[list]=None,
def __init__(self,
model_list: Optional[list] = None,
redis_host: Optional[str] = None,
redis_port: Optional[int] = None,
redis_password: Optional[str] = None,
redis_password: Optional[str] = None,
cache_responses: bool = False) -> None:
if model_list:
self.model_list = model_list
self.model_names = [m["model_name"] for m in model_list]
if redis_host is not None and redis_port is not None and redis_password is not None:
self.set_model_list(model_list)
if redis_host is not None and redis_port is not None and redis_password is not None:
cache_config = {
'type': 'redis',
'host': redis_host,
@ -45,61 +45,55 @@ class Router:
}
self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing
if cache_responses:
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
self.cache_responses = cache_responses
litellm.success_callback = [self.deployment_callback]
def completion(self,
model: str,
messages: List[Dict[str, str]],
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
**kwargs):
**kwargs):
"""
Example usage:
Example usage:
response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}]
"""
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
data["messages"] = messages
data["caching"] = self.cache_responses
# call via litellm.completion()
return litellm.completion(**{**data, **kwargs})
# call via litellm.completion()
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
async def acompletion(self,
model: str,
messages: List[Dict[str, str]],
async def acompletion(self,
model: str,
messages: List[Dict[str, str]],
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
**kwargs):
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
data["messages"] = messages
data["caching"] = self.cache_responses
return await litellm.acompletion(**{**data, **kwargs})
def text_completion(self,
model: str,
prompt: str,
return await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
def text_completion(self,
model: str,
prompt: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs):
messages=[{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
data["prompt"] = prompt
data["caching"] = self.cache_responses
# call via litellm.completion()
return litellm.text_completion(**{**data, **kwargs})
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})
def embedding(self,
def embedding(self,
model: str,
input: Union[str, List],
is_async: Optional[bool] = False,
@ -108,10 +102,8 @@ class Router:
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
data["input"] = input
data["caching"] = self.cache_responses
# call via litellm.embedding()
return litellm.embedding(**{**data, **kwargs})
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
async def aembedding(self,
model: str,
@ -122,14 +114,13 @@ class Router:
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
data["input"] = input
data["caching"] = self.cache_responses
return await litellm.aembedding(**{**data, **kwargs})
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
def set_model_list(self, model_list: list):
self.model_list = model_list
self.model_names = [m["model_name"] for m in model_list]
def get_model_names(self):
def get_model_names(self):
return self.model_names
def deployment_callback(
@ -146,21 +137,21 @@ class Router:
total_tokens = completion_response['usage']['total_tokens']
self._set_deployment_usage(model_name, total_tokens)
def get_available_deployment(self,
model: str,
messages: Optional[List[Dict[str, str]]]=None,
input: Optional[Union[str, List]]=None):
def get_available_deployment(self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
potential_deployments = []
for item in self.model_list:
if item["model_name"] == model:
# get list of potential deployments
potential_deployments = []
for item in self.model_list:
if item["model_name"] == model:
potential_deployments.append(item)
# set first model as current model
deployment = potential_deployments[0]
deployment = potential_deployments[0]
# get model tpm, rpm limits
@ -170,7 +161,7 @@ class Router:
# get deployment current usage
current_tpm, current_rpm = self._get_deployment_usage(deployment_name=deployment["litellm_params"]["model"])
# get encoding
# get encoding
if messages:
token_count = litellm.token_counter(model=deployment["model_name"], messages=messages)
elif input:
@ -179,9 +170,9 @@ class Router:
else:
input_text = input
token_count = litellm.token_counter(model=deployment["model_name"], text=input_text)
# if at model limit, return lowest used
if current_tpm + token_count > tpm or current_rpm + 1 >= rpm:
if current_tpm + token_count > tpm or current_rpm + 1 >= rpm:
# -----------------------
# Find lowest used model
# ----------------------
@ -194,17 +185,17 @@ class Router:
if item_tpm == 0:
return item
elif item_tpm + token_count > item["tpm"] or item_rpm + 1 >= item["rpm"]:
elif item_tpm + token_count > item["tpm"] or item_rpm + 1 >= item["rpm"]:
continue
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
deployment = item
# if none, raise exception
if deployment is None:
# if none, raise exception
if deployment is None:
raise ValueError(f"No models available.")
# return model
# return model
return deployment
def _get_deployment_usage(
@ -224,24 +215,24 @@ class Router:
tpm = self.cache.get_cache(tpm_key)
rpm = self.cache.get_cache(rpm_key)
if tpm is None:
if tpm is None:
tpm = 0
if rpm is None:
if rpm is None:
rpm = 0
return int(tpm), int(rpm)
def increment(self, key: str, increment_value: int):
# get value
def increment(self, key: str, increment_value: int):
# get value
cached_value = self.cache.get_cache(key)
# update value
# update value
try:
cached_value = cached_value + increment_value
except:
except:
cached_value = increment_value
# save updated value
self.cache.add_cache(result=cached_value, cache_key=key)
def _set_deployment_usage(
self,
model_name: str,