use correct vtx ai21 pricing

This commit is contained in:
Ishaan Jaff 2024-08-29 19:04:05 -07:00
parent 310e17c78b
commit dbdbf3d9a2
2 changed files with 7 additions and 1 deletions

View file

@ -358,6 +358,7 @@ vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
vertex_llama3_models: List = []
vertex_ai_ai21_models: List = []
vertex_mistral_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
@ -408,6 +409,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-mistral_models":
key = key.replace("vertex_ai/", "")
vertex_mistral_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-ai21_models":
key = key.replace("vertex_ai/", "")
vertex_ai_ai21_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":

View file

@ -3267,7 +3267,7 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "vertex_ai" and model in litellm.ai21_models:
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
@ -5182,6 +5182,8 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
model = "meta/" + model
elif model + "@latest" in litellm.vertex_mistral_models:
model = model + "@latest"
elif model + "@latest" in litellm.vertex_ai_ai21_models:
model = model + "@latest"
##########################
if custom_llm_provider is None:
# Get custom_llm_provider