feat: Add rerank models and rerank API change (#3831)

# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
- Extend the model type to include rerank models.
- Implement `rerank()` method in inference router.
- Add `rerank_model_list` to `OpenAIMixin` to enable providers to
register and identify rerank models
- Update documentation.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
```
pytest tests/unit/providers/utils/inference/test_openai_mixin.py
```
This commit is contained in:
Jiayi Ni 2025-10-22 12:02:28 -07:00 committed by GitHub
parent f2598d30e6
commit bb1ebb3c6b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 186 additions and 43 deletions

View file

@ -48,6 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
- download_images: If True, downloads images and converts to base64 for providers that require it
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
- construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
- provider_data_api_key_field: Optional field name in provider data to look for API key
- list_provider_model_ids: Method to list available models from the provider
- get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
@ -121,6 +122,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
return {}
def construct_model_from_identifier(self, identifier: str) -> Model:
"""
Construct a Model instance corresponding to the given identifier
Child classes can override this to customize model typing/metadata.
:param identifier: The provider's model identifier
:return: A Model instance
"""
if metadata := self.embedding_model_metadata.get(identifier):
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.embedding,
metadata=metadata,
)
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.llm,
)
async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from the provider.
@ -416,21 +441,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
if self.allowed_models and provider_model_id not in self.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(provider_model_id):
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.llm,
)
model = self.construct_model_from_identifier(provider_model_id)
self._model_cache[provider_model_id] = model
return list(self._model_cache.values())