diff --git a/litellm/proxy/llm.py b/litellm/proxy/llm.py
index 816ea3e85..581c0890d 100644
--- a/litellm/proxy/llm.py
+++ b/litellm/proxy/llm.py
@@ -113,7 +113,8 @@ def litellm_completion(data: Dict,
                        user_max_tokens: Optional[int],
                        user_api_base: Optional[str],
                        user_headers: Optional[dict],
-                       user_debug: bool):
+                       user_debug: bool,
+                       model_router: Optional[litellm.Router]):
     try:
         global debug
         debug = user_debug
@@ -129,9 +130,15 @@ def litellm_completion(data: Dict,
         if user_headers:
             data["headers"] = user_headers
         if type == "completion":
-            response = litellm.text_completion(**data)
+            if data["model"] in model_router.get_model_names():
+                response = model_router.text_completion(**data)
+            else:
+                response = litellm.text_completion(**data)
         elif type == "chat_completion":
-            response = litellm.completion(**data)
+            if data["model"] in model_router.get_model_names():
+                response = model_router.completion(**data)
+            else:
+                response = litellm.completion(**data)
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(data_generator(response), media_type='text/event-stream')
         print_verbose(f"response: {response}")
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ccee7ffd0..c17c5e55a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -101,6 +101,7 @@ user_telemetry = True
 user_config = None
 user_headers = None
 local_logging = True # writes logs to a local api_log.json file for debugging
+model_router = litellm.Router()
 config_filename = "litellm.secrets.toml"
 config_dir = os.getcwd()
 config_dir = appdirs.user_config_dir("litellm")
@@ -213,6 +214,12 @@ def load_config():
         if "model" in user_config:
             if user_model in user_config["model"]:
                 model_config = user_config["model"][user_model]
+            model_list = []
+            for model in user_config["model"]:
+                if "model_list" in user_config["model"][model]:
+                    model_list.extend(user_config["model"][model]["model_list"])
+            if len(model_list) > 0:
+                model_router.set_model_list(model_list=model_list)
 
         print_verbose(f"user_config: {user_config}")
         print_verbose(f"model_config: {model_config}")
@@ -423,7 +430,7 @@ async def completion(request: Request):
     data = await request.json()
     return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
                               user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
-                              user_debug=user_debug)
+                              user_debug=user_debug, model_router=model_router)
 
 
 @router.post("/v1/chat/completions")
@@ -433,7 +440,7 @@ async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
     return litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens,
-                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
+                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router)
 
 
 def print_cost_logs():
diff --git a/litellm/router.py b/litellm/router.py
index 7cf81c508..ded21e98b 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -21,11 +21,13 @@ class Router:
     router = Router(model_list=model_list)
     """
     def __init__(self,
-                 model_list: list,
+                 model_list: Optional[list] = None,
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None) -> None:
-        self.model_list = model_list
+        self.model_list = model_list or []
+        # keep model names in sync so the proxy can do quick membership checks
+        self.model_names = [m["model_name"] for m in self.model_list]
         if redis_host is not None and redis_port is not None and redis_password is not None:
             cache_config = {
                 'type': 'redis',
@@ -60,6 +62,23 @@ class Router:
         data["messages"] = messages
         # call via litellm.completion()
         return litellm.completion(**data)
+
+    def text_completion(self,
+                        model: str,
+                        prompt: str,
+                        is_retry: Optional[bool] = False,
+                        is_fallback: Optional[bool] = False,
+                        is_async: Optional[bool] = False,
+                        **kwargs):
+
+        messages = [{"role": "user", "content": prompt}]
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+
+        data = deployment["litellm_params"]
+        data["prompt"] = prompt
+        # call via litellm.text_completion()
+        return litellm.text_completion(**data)
 
     def embedding(self,
                   model: str,
@@ -74,6 +93,13 @@ class Router:
         # call via litellm.embedding()
         return litellm.embedding(**data)
 
+    def set_model_list(self, model_list: list):
+        self.model_list = model_list
+        self.model_names = [m["model_name"] for m in model_list]
+
+    def get_model_names(self):
+        return self.model_names
+
     def deployment_callback(
         self,
         kwargs,                 # kwargs to completion
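
For reviewers: below is a minimal sketch (not part of the diff) of the flow this change enables, assuming a model_list of the shape the router and load_config() already consume, i.e. each entry carries a "model_name" plus the "litellm_params" forwarded to litellm. The Azure deployment name is a hypothetical placeholder; real credentials would come from the environment or the secrets config.

import litellm

# Two deployments registered under the same public model name, so the router
# can pick whichever currently has the lowest TPM/RPM usage.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",                     # name clients send to the proxy
        "litellm_params": {"model": "azure/chatgpt-v-2"},  # hypothetical Azure deployment
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo"},      # plain OpenAI deployment
    },
]

model_router = litellm.Router()          # proxy_server.py creates an empty router at import time
model_router.set_model_list(model_list)  # load_config() fills it from the user's config

# litellm_completion() only goes through the router when the requested model is one of
# the configured model names; otherwise it falls back to calling litellm directly.
if "gpt-3.5-turbo" in model_router.get_model_names():
    response = model_router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )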