fix(proxy_cli-and-utils.py): fixing how config file is read + inferring llm_provider for known openai endpoints

Krrish Dholakia 2023-10-10 20:53:02 -07:00
parent d99d6a99f1
commit d280a8c434
9 changed files with 170 additions and 29 deletions

@@ -1358,7 +1358,7 @@ def get_optional_params( # use the openai defaults
             optional_params[k] = passed_params[k]
     return optional_params

-def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
+def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None):
     try:
         # check if llm provider provided
         if custom_llm_provider:
@@ -1370,6 +1370,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
             model = model.split("/", 1)[1]
             return model, custom_llm_provider
+        # check if api base is a known openai compatible endpoint
+        if api_base:
+            for endpoint in litellm.openai_compatible_endpoints:
+                if endpoint in api_base:
+                    custom_llm_provider = "openai"
+                    return model, custom_llm_provider
         # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
         ## openai - chatcompletion + text completion
         if model in litellm.open_ai_chat_completion_models:
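In effect, a caller that points at a known OpenAI-compatible endpoint no longer has to pass custom_llm_provider explicitly: a substring match on api_base is enough to route the request as "openai". A minimal sketch of that inference, using a hypothetical stand-in list since the exact contents of litellm.openai_compatible_endpoints are not shown in this diff:

```python
from typing import Optional, Tuple

# Hypothetical stand-in for litellm.openai_compatible_endpoints; the real list
# is maintained inside litellm and its exact entries are not part of this diff.
OPENAI_COMPATIBLE_ENDPOINTS = ["api.perplexity.ai", "api.deepinfra.com/v1/openai"]

def infer_provider(
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Tuple[str, Optional[str]]:
    # an explicitly supplied provider always wins
    if custom_llm_provider:
        return model, custom_llm_provider
    # otherwise, a substring match against known OpenAI-compatible endpoints
    # infers "openai" so the request goes through the OpenAI client path
    if api_base:
        for endpoint in OPENAI_COMPATIBLE_ENDPOINTS:
            if endpoint in api_base:
                return model, "openai"
    return model, None

# e.g. infer_provider("mistral-7b", api_base="https://api.perplexity.ai")
# -> ("mistral-7b", "openai")
```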
@@ -1429,6 +1436,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     except Exception as e:
         raise e

 def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     api_key = (dynamic_api_key or litellm.api_key)
     # openai
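For context, get_api_key resolves the key with a simple precedence: a per-call dynamic_api_key wins over the module-level litellm.api_key, which in turn falls back to provider-specific secrets further down the function. A rough sketch of that precedence, with illustrative names rather than litellm's real helpers:

```python
import os
from typing import Optional

def resolve_api_key(dynamic_api_key: Optional[str], configured_key: Optional[str]) -> Optional[str]:
    # per-request key first, then the globally configured key, then (as in the
    # hunk below) provider-specific environment secrets such as TOGETHER_AI_TOKEN
    return dynamic_api_key or configured_key or os.environ.get("TOGETHER_AI_TOKEN")
```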
@@ -1503,6 +1511,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
             get_secret("TOGETHER_AI_TOKEN")
         )
     return api_key

 def get_max_tokens(model: str):
     try:
         return litellm.model_cost[model]
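get_max_tokens, shown above purely as context, is a straight lookup into litellm.model_cost, a dict keyed by model name. A small illustrative sketch of that lookup (the sample entry is made up for illustration, not the actual cost map):

```python
# Illustrative shape only; the real values live in litellm.model_cost.
model_cost = {
    "gpt-3.5-turbo": {"max_tokens": 4097, "input_cost_per_token": 0.0000015},
}

def get_max_tokens(model: str) -> dict:
    try:
        return model_cost[model]
    except KeyError:
        raise Exception(f"Model {model} isn't mapped yet.") from None
```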
@@ -2183,6 +2192,7 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str
     )
     ```
     """
+    model, _ = get_llm_provider(model=model)
     litellm.custom_prompt_dict[model] = {
         "roles": roles,
         "initial_prompt_value": initial_prompt_value,