From 47c1aa27c74ff1c30e28cbaa4d59e3481001b73f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 2 Dec 2023 21:24:28 -0800
Subject: [PATCH 1/4] fix(proxy_server.py): add litellm model cost map info to /model/info

---
 litellm/proxy/proxy_server.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f1824c95b..393d3d7bb 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1086,8 +1086,21 @@ async def model_info(request: Request):
     all_models = config['model_list']
 
     for model in all_models:
+        # get the model cost map info
+        ## make an api call
+        data = copy.deepcopy(model["litellm_params"])
+        data["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
+        data["max_tokens"] = 10
+        response = await litellm.acompletion(**data)
+        litellm_model_info = litellm.model_cost.get(response["model"], {})
+        model_info = model.get("model_info", {})
+        for k, v in litellm_model_info.items():
+            if k not in model_info:
+                model_info[k] = v
+        model["model_info"] = model_info
         # don't return the api key
         model["litellm_params"].pop("api_key", None)
+    # all_models = list(set([m["model_name"] for m in llm_model_list]))
 
     print_verbose(f"all_models: {all_models}")
     return dict(

From ecddb852a256bef86dee7cb77511ad65587f833d Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Sat, 2 Dec 2023 21:30:33 -0800
Subject: [PATCH 2/4] (fix) proxy: pydantic error / warning message

---
 litellm/proxy/proxy_server.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 393d3d7bb..690c24da4 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -201,6 +201,8 @@ class ModelParams(BaseModel):
     model_name: str
     litellm_params: dict
     model_info: Optional[dict]
+    class Config:
+        protected_namespaces = ()
 
 class GenerateKeyRequest(BaseModel):
     duration: str = "1h"

From add4dfc5287aa81d36b72b23aa021d736ed0246c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 2 Dec 2023 21:33:47 -0800
Subject: [PATCH 3/4] fix(proxy_server.py): support model info augmenting for azure models

---
 litellm/proxy/proxy_server.py | 5 ++++-
 litellm/utils.py              | 7 +++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 690c24da4..c9ba4c215 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1093,8 +1093,11 @@ async def model_info(request: Request):
         data = copy.deepcopy(model["litellm_params"])
         data["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
         data["max_tokens"] = 10
+        print(f"data going to litellm acompletion: {data}")
         response = await litellm.acompletion(**data)
-        litellm_model_info = litellm.model_cost.get(response["model"], {})
+        response_model = response["model"]
+        print(f"response model: {response_model}; response - {response}")
+        litellm_model_info = litellm.get_model_info(response_model)
         model_info = model.get("model_info", {})
         for k, v in litellm_model_info.items():
             if k not in model_info:
diff --git a/litellm/utils.py b/litellm/utils.py
index bed6d1cda..280a6342f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2703,6 +2703,13 @@ def get_model_info(model: str):
     except requests.exceptions.RequestException as e:
         return None
     try:
+        azure_llms = {
+            "gpt-35-turbo": "azure/gpt-3.5-turbo",
+            "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k",
+            "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct"
+        }
+        if model in azure_llms:
+            model = azure_llms[model]
         if model in litellm.model_cost:
             return litellm.model_cost[model]
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)

From 69a449755050ed19eb274cb3d65a3c35ebfc548e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 2 Dec 2023 21:49:23 -0800
Subject: [PATCH 4/4] fix(main.py): accept user in embedding()

---
 litellm/main.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/litellm/main.py b/litellm/main.py
index ca7b902e2..5c421a351 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1736,6 +1736,7 @@ def embedding(
     api_key: Optional[str] = None,
     api_type: Optional[str] = None,
     caching: bool=False,
+    user: Optional[str]=None,
     custom_llm_provider=None,
     litellm_call_id=None,
     litellm_logging_obj=None,
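
Taken together, patch 3 lets litellm.get_model_info() resolve Azure deployment
names (e.g. "gpt-35-turbo") to their cost-map entries, and patch 4 lets
embedding() accept an end-user identifier the way completion() already does.
A minimal usage sketch follows; it assumes the patched litellm package with
valid API credentials configured and network access for the cost-map fetch,
and the model names and user id are illustrative:

    import litellm

    # Patch 3: an Azure deployment name is aliased to its cost-map entry
    # instead of missing the lookup entirely.
    info = litellm.get_model_info("gpt-35-turbo")
    print(info)  # cost-map fields (e.g. max_tokens, per-token prices), or None

    # Patch 4: embedding() now accepts a `user` parameter.
    response = litellm.embedding(
        model="text-embedding-ada-002",
        input=["Hey, how's it going?"],
        user="end-user-123",  # hypothetical end-user id
    )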