mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: Add rerank API for NVIDIA Inference Provider (#3329)
# What does this PR do?

Add rerank API for NVIDIA Inference Provider.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #3278

## Test Plan

Unit test:
```
pytest tests/unit/providers/nvidia/test_rerank_inference.py
```

Integration test:
```
pytest -s -v tests/integration/inference/test_rerank.py --stack-config="inference=nvidia" --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3 --env NVIDIA_API_KEY="" --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
This commit is contained in:
parent
c396de57a4
commit
fa7699d2c3
8 changed files with 622 additions and 1 deletions
|
|
@@ -153,6 +153,7 @@ def client_with_models(
|
|||
vision_model_id,
|
||||
embedding_model_id,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
|
@@ -170,6 +171,9 @@ def client_with_models(
|
|||
|
||||
if embedding_model_id and embedding_model_id not in model_ids:
|
||||
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
|
||||
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
|
||||
return client
|
||||
|
||||
|
||||
|
|
@@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = [
|
||||
"text_model_id",
|
||||
"vision_model_id",
|
||||
"embedding_model_id",
|
||||
"judge_model_id",
|
||||
"shield_id",
|
||||
"rerank_model_id",
|
||||
]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue