forked from phoenix/litellm-mirror
feat(proxy_server.py): add model router to proxy
This commit is contained in:
parent c8f8686d7c
commit 3a8c8f56d6
3 changed files with 46 additions and 7 deletions
@@ -113,7 +113,8 @@ def litellm_completion(data: Dict,
                        user_max_tokens: Optional[int],
                        user_api_base: Optional[str],
                        user_headers: Optional[dict],
-                       user_debug: bool):
+                       user_debug: bool,
+                       model_router: Optional[litellm.Router]):
     try:
         global debug
         debug = user_debug
@@ -129,9 +130,15 @@ def litellm_completion(data: Dict,
         if user_headers:
             data["headers"] = user_headers
         if type == "completion":
-            response = litellm.text_completion(**data)
+            if data["model"] in model_router.get_model_names():
+                model_router.text_completion(**data)
+            else:
+                response = litellm.text_completion(**data)
         elif type == "chat_completion":
-            response = litellm.completion(**data)
+            if data["model"] in model_router.get_model_names():
+                model_router.completion(**data)
+            else:
+                response = litellm.completion(**data)
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(data_generator(response), media_type='text/event-stream')
         print_verbose(f"response: {response}")
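The dispatch above can be exercised in isolation. A minimal sketch, assuming a Router built from a model_list shaped the way this commit consumes it (each entry carries a model_name alias plus litellm_params); the alias "my-gpt" and the deployment parameters are made up for illustration.

import litellm

# Hypothetical deployment list; the entry shape follows how Router reads it
# in this commit (m["model_name"] and deployment["litellm_params"]).
model_list = [
    {"model_name": "my-gpt", "litellm_params": {"model": "gpt-3.5-turbo"}},
]
model_router = litellm.Router(model_list=model_list)

data = {"model": "my-gpt", "messages": [{"role": "user", "content": "hi"}]}

if data["model"] in model_router.get_model_names():
    # alias is managed by the router: it picks the deployment with the
    # lowest TPM/RPM usage and forwards its litellm_params
    response = model_router.completion(**data)
else:
    # unknown alias: fall back to a direct litellm call
    response = litellm.completion(**data)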
@@ -101,6 +101,7 @@ user_telemetry = True
 user_config = None
 user_headers = None
 local_logging = True # writes logs to a local api_log.json file for debugging
+model_router = litellm.Router()
 config_filename = "litellm.secrets.toml"
 config_dir = os.getcwd()
 config_dir = appdirs.user_config_dir("litellm")
@@ -213,6 +214,12 @@ def load_config():
         if "model" in user_config:
             if user_model in user_config["model"]:
                 model_config = user_config["model"][user_model]
+            model_list = []
+            for model in user_config["model"]:
+                if "model_list" in user_config["model"][model]:
+                    model_list.extend(user_config["model"][model]["model_list"])
+            if len(model_list) > 0:
+                model_router.set_model_list(model_list=model_list)

             print_verbose(f"user_config: {user_config}")
             print_verbose(f"model_config: {model_config}")
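For reference, a sketch of the parsed config shape this loop expects once litellm.secrets.toml has been read into user_config. The table names and deployments are illustrative, and the entry shape (model_name / litellm_params) is inferred from how Router consumes the list in this commit.

# Hypothetical parsed config; only tables that carry a "model_list"
# contribute deployments to the router.
user_config = {
    "model": {
        "my-gpt": {
            "model_list": [
                {"model_name": "my-gpt", "litellm_params": {"model": "gpt-3.5-turbo"}},
                {"model_name": "my-gpt", "litellm_params": {"model": "azure/my-deployment"}},
            ],
        },
        "local-llama": {},  # no model_list: skipped by the loop above
    },
}

model_list = []
for model in user_config["model"]:
    if "model_list" in user_config["model"][model]:
        model_list.extend(user_config["model"][model]["model_list"])
# model_list now holds both "my-gpt" deployments, ready for
# model_router.set_model_list(model_list=model_list)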
@@ -423,7 +430,7 @@ async def completion(request: Request):
     data = await request.json()
     return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
                               user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
-                              user_debug=user_debug)
+                              user_debug=user_debug, model_router=model_router)


 @router.post("/v1/chat/completions")
@@ -433,7 +440,7 @@ async def chat_completion(request: Request):
     print_verbose(f"data passed in: {data}")
     return litellm_completion(data, type="chat_completion", user_model=user_model,
                               user_temperature=user_temperature, user_max_tokens=user_max_tokens,
-                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
+                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router)


 def print_cost_logs():
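End to end, a client only has to send a model name that matches one of the router's aliases. A minimal sketch, assuming the proxy is running locally; the host, port, and alias are illustrative.

import requests

# Hypothetical local proxy address; "my-gpt" must match a model_name in the
# router's model_list for the request to be routed instead of sent directly.
resp = requests.post(
    "http://0.0.0.0:8000/v1/chat/completions",
    json={
        "model": "my-gpt",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json())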
@@ -21,11 +21,13 @@ class Router:
     router = Router(model_list=model_list)
     """
     def __init__(self,
-                 model_list: list,
+                 model_list: Optional[list]=None,
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None) -> None:
-        self.model_list = model_list
+        if model_list:
+            self.model_list = model_list
+            self.model_names = [m["model_name"] for m in model_list]
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                 'type': 'redis',
@@ -61,6 +63,23 @@ class Router:
         # call via litellm.completion()
         return litellm.completion(**data)

+    def text_completion(self,
+                        model: str,
+                        prompt: str,
+                        is_retry: Optional[bool] = False,
+                        is_fallback: Optional[bool] = False,
+                        is_async: Optional[bool] = False,
+                        **kwargs):
+
+        messages=[{"role": "user", "content": prompt}]
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+
+        data = deployment["litellm_params"]
+        data["prompt"] = prompt
+        # call via litellm.completion()
+        return litellm.text_completion(**data)
+
     def embedding(self,
                   model: str,
                   input: Union[str, List],
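A usage sketch for the new text_completion path; the alias and underlying model are illustrative. Internally the method wraps the prompt as a single user message, asks get_available_deployment() for the deployment with headroom, and forwards its litellm_params plus the prompt to litellm.text_completion().

import litellm

# Hypothetical alias mapped to a text-completion-capable deployment.
router = litellm.Router(model_list=[
    {"model_name": "my-gpt", "litellm_params": {"model": "gpt-3.5-turbo-instruct"}},
])

response = router.text_completion(model="my-gpt", prompt="Say hi in one word.")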
@@ -74,6 +93,12 @@ class Router:
         # call via litellm.embedding()
         return litellm.embedding(**data)

+    def set_model_list(self, model_list: list):
+        self.model_list = model_list
+
+    def get_model_names(self):
+        return self.model_names
+
     def deployment_callback(
         self,
         kwargs, # kwargs to completion
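A short usage sketch of the new helpers; the alias and deployment below are illustrative. Since model_names is populated in __init__, the sketch passes the list at construction time, and uses set_model_list for the proxy-style deferred case.

import litellm

# Illustrative deployment list; Router reads m["model_name"] and
# deployment["litellm_params"] from each entry.
model_list = [
    {"model_name": "my-gpt", "litellm_params": {"model": "gpt-3.5-turbo"}},
]

router = litellm.Router(model_list=model_list)
print(router.get_model_names())   # -> ["my-gpt"]

# Deferred configuration, mirroring what load_config() does in the proxy:
proxy_router = litellm.Router()
proxy_router.set_model_list(model_list=model_list)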