fix InfinityEmbeddingConfig

This commit is contained in:
Ishaan Jaff 2025-04-21 18:55:02 -07:00
parent 8c385a94aa
commit cf09a6770a

View file

@ -18,9 +18,17 @@ class InfinityEmbeddingConfig(BaseEmbeddingConfig):
"""
def __init__(self) -> None:
pass
pass
def get_complete_url(self, api_base: str | None, api_key: str | None, model: str, optional_params: dict, litellm_params: dict, stream: bool | None = None) -> str:
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
if api_base is None:
raise ValueError("api_base is required for Infinity embeddings")
# Remove trailing slashes and ensure clean base URL
@ -28,7 +36,7 @@ class InfinityEmbeddingConfig(BaseEmbeddingConfig):
if not api_base.endswith("/embeddings"):
api_base = f"{api_base}/embeddings"
return api_base
def validate_environment(
self,
headers: dict,
@ -51,17 +59,17 @@ class InfinityEmbeddingConfig(BaseEmbeddingConfig):
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def get_supported_openai_params(self, model: str) -> list:
return [
"encoding_format",
"modality",
"dimensions",
]
def map_openai_params(
self,
non_default_params: dict,
@ -81,7 +89,7 @@ class InfinityEmbeddingConfig(BaseEmbeddingConfig):
if "dimensions" in non_default_params:
optional_params["output_dimension"] = non_default_params["dimensions"]
return optional_params
def transform_embedding_request(
self,
model: str,
@ -130,4 +138,4 @@ class InfinityEmbeddingConfig(BaseEmbeddingConfig):
) -> BaseLLMException:
return InfinityError(
message=error_message, status_code=status_code, headers=headers
)
)