add VertexAITextEmbeddingConfig

This commit is contained in:
Ishaan Jaff 2024-08-30 12:53:43 -07:00
parent ea12519b98
commit 9f9b6e81a7
2 changed files with 13 additions and 7 deletions

View file

@ -78,15 +78,18 @@ class VertexAITextEmbeddingConfig(BaseModel):
} }
def get_supported_openai_params(self): def get_supported_openai_params(self):
return ["dimensions", "input_type"] return ["dimensions"]
def map_openai_params(self, non_default_params: dict, optional_params: dict): def map_openai_params(
self, non_default_params: dict, optional_params: dict, kwargs: dict
):
for param, value in non_default_params.items(): for param, value in non_default_params.items():
if param == "dimensions": if param == "dimensions":
optional_params["output_dimensionality"] = value optional_params["output_dimensionality"] = value
if param == "input_type":
optional_params["task_type"] = value if "input_type" in kwargs:
return optional_params optional_params["task_type"] = kwargs.pop("input_type")
return optional_params, kwargs
def get_mapped_special_auth_params(self) -> dict: def get_mapped_special_auth_params(self) -> dict:
""" """

View file

@ -2621,8 +2621,11 @@ def get_optional_params_embeddings(
request_type="embeddings", request_type="embeddings",
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params( (
non_default_params=non_default_params, optional_params={} optional_params,
kwargs,
) = litellm.VertexAITextEmbeddingConfig().map_openai_params(
non_default_params=non_default_params, optional_params={}, kwargs=kwargs
) )
final_params = {**optional_params, **kwargs} final_params = {**optional_params, **kwargs}
return final_params return final_params