Update tests

This commit is contained in:
Jiayi 2025-09-09 16:21:00 -07:00
parent d78e30fe8b
commit f66718be80
2 changed files with 19 additions and 12 deletions

View file

@ -153,10 +153,17 @@ def client_with_models(
metadata={"embedding_dimension": embedding_dimension or 384},
)
if rerank_model_id and rerank_model_id not in model_ids:
rerank_provider = providers[0]
selected_provider = None
for p in providers:
# Currently only NVIDIA inference provider supports reranking
if p.provider_type == "remote::nvidia":
selected_provider = p
break
selected_provider = selected_provider or providers[0]
client.models.register(
model_id=rerank_model_id,
provider_id=rerank_provider.provider_id,
provider_id=selected_provider.provider_id,
model_type="rerank",
)
return client