fix - return supported_openai_params from get_model_info

Ishaan Jaff 2024-05-27 09:00:12 -07:00
parent b5f883ab74
commit 245990597e


@@ -7107,6 +7107,7 @@ def get_model_info(model: str) -> ModelInfo:
             - output_cost_per_token (float): The cost per token for output.
             - litellm_provider (str): The provider of the model (e.g., "openai").
             - mode (str): The mode of the model (e.g., "chat" or "completion").
+            - supported_openai_params (List[str]): A list of supported OpenAI parameters for the model.
 
     Raises:
         Exception: If the model is not mapped yet.
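For orientation, here is a minimal usage sketch of the field this commit adds to the returned dict (the model name and printed values are illustrative, not taken from this diff):

```python
import litellm

# Illustrative only: read the newly returned field off get_model_info().
info = litellm.get_model_info("gpt-4")
print(info["mode"])                     # e.g. "chat"
print(info["supported_openai_params"])  # e.g. ["temperature", "max_tokens", ...]
```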
@@ -7118,9 +7119,11 @@ def get_model_info(model: str) -> ModelInfo:
             "input_cost_per_token": 0.00003,
             "output_cost_per_token": 0.00006,
             "litellm_provider": "openai",
-            "mode": "chat"
+            "mode": "chat",
+            "supported_openai_params": ["temperature", "max_tokens", "top_p", "frequency_penalty", "presence_penalty"]
         }
     """
+    supported_openai_params: Union[List[str], None] = []
 
     def _get_max_position_embeddings(model_name):
         # Construct the URL for the config.json file
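The helper whose first lines appear above looks up a model's `config.json` on the Hugging Face Hub. A hedged sketch of that idea, assuming the usual `https://huggingface.co/<model>/raw/main/config.json` path (the actual implementation later in this file may differ):

```python
import requests

# Sketch, not the litellm implementation: fetch a model's config.json from
# the Hugging Face Hub and return max_position_embeddings if present.
def _get_max_position_embeddings_sketch(model_name: str):
    url = f"https://huggingface.co/{model_name}/raw/main/config.json"
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return response.json().get("max_position_embeddings")
    except requests.RequestException:
        return None
```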
@@ -7148,9 +7151,18 @@ def get_model_info(model: str) -> ModelInfo:
         azure_llms = litellm.azure_llms
         if model in azure_llms:
             model = azure_llms[model]
-        if model in litellm.model_cost:
-            return litellm.model_cost[model]
-        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        ##########################
+        # Get custom_llm_provider
+        split_model, custom_llm_provider = model, ""
+        try:
+            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        except:
+            pass
+        #########################
+
+        supported_openai_params = litellm.get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return {
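The hunk above replaces an eager cost-map lookup with provider detection (falling back to an empty provider string if `get_llm_provider` raises) followed by a call to `litellm.get_supported_openai_params`. A quick illustration of that helper; supported parameter lists vary by provider and litellm version, so treat the output as an example:

```python
import litellm

# Mirrors the call added in the hunk above (values illustrative).
params = litellm.get_supported_openai_params(
    model="gpt-3.5-turbo", custom_llm_provider="openai"
)
print(params)  # e.g. ["temperature", "max_tokens", "top_p", "stream", ...]
```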
@@ -7159,15 +7171,26 @@ def get_model_info(model: str) -> ModelInfo:
                 "output_cost_per_token": 0,
                 "litellm_provider": "huggingface",
                 "mode": "chat",
+                "supported_openai_params": supported_openai_params,
             }
         else:
             """
-            Check if model in model cost map
+            Check if:
+            1. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost
+            2. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost
             """
             if model in litellm.model_cost:
-                return litellm.model_cost[model]
+                _model_info = litellm.model_cost[model]
+                _model_info["supported_openai_params"] = supported_openai_params
+                return _model_info
+            if split_model in litellm.model_cost:
+                _model_info = litellm.model_cost[split_model]
+                _model_info["supported_openai_params"] = supported_openai_params
+                return _model_info
             else:
-                raise Exception()
+                raise ValueError(
+                    "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
+                )
     except:
         raise Exception(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"