Fix rerank integration test based on client side changes

This commit is contained in:
Jiayi 2025-10-01 10:37:58 -07:00
parent bb2eb33fc3
commit 6b4940806f
8 changed files with 27 additions and 276 deletions

View file

@ -4,8 +4,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models rerank the documents by relevance."
- Rerank models: these models reorder the documents based on their relevance to a query."
sidebar_label: Inference
title: Inference
---

View file

@ -13335,7 +13335,7 @@
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
},
{

View file

@ -9990,13 +9990,16 @@ tags:
description: ''
- name: Inference
description: >-
This API provides the raw interface to the underlying models. Two kinds of models
are supported:
This API provides the raw interface to the underlying models. Three kinds of
models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic
search.
- Rerank models: these models reorder the documents based on their relevance
to a query.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.

View file

@ -8838,7 +8838,8 @@
"type": "string",
"enum": [
"llm",
"embedding"
"embedding",
"rerank"
],
"title": "ModelType",
"description": "Enumeration of supported model types in Llama Stack."
@ -17033,7 +17034,7 @@
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use."
"description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
},
"query": {
"oneOf": [
@ -18456,7 +18457,7 @@
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
},
{

View file

@ -6603,6 +6603,7 @@ components:
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
@ -12693,7 +12694,8 @@ components:
model:
type: string
description: >-
The identifier of the reranking model to use.
The identifier of the reranking model to use. The model must be a reranking
model registered with Llama Stack and available via the /models endpoint.
query:
oneOf:
- type: string
@ -13774,13 +13776,16 @@ tags:
description: ''
- name: Inference
description: >-
This API provides the raw interface to the underlying models. Two kinds of models
are supported:
This API provides the raw interface to the underlying models. Three kinds of
models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic
search.
- Rerank models: these models reorder the documents based on their relevance
to a query.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.

View file

@ -1,257 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
os.environ["NVIDIA_API_KEY"] = "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"
# Option 1: Use default NIM URL (will auto-switch to ai.api.nvidia.com for rerank)
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com"
# Option 2: Use AI Foundation URL directly for rerank models
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com/v1"
os.environ["NVIDIA_BASE_URL"] = "https://integrate.api.nvidia.com"
import base64
import io
from PIL import Image
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
# # response = client.inference.completion(
# # model_id="meta/llama-3.1-8b-instruct",
# # content="Complete the sentence using one word: Roses are red, violets are :",
# # stream=False,
# # sampling_params={
# # "max_tokens": 50,
# # },
# # )
# # print(f"Response: {response.content}")
# response = client.inference.chat_completion(
# model_id="nvidia/nvidia-nemotron-nano-9b-v2",
# messages=[
# {
# "role": "system",
# "content": "/think",
# },
# {
# "role": "user",
# "content": "How are you?",
# },
# ],
# stream=False,
# sampling_params={
# "max_tokens": 1024,
# },
# )
# print(f"Response: {response}")
print(client.models.list())
rerank_response = client.inference.rerank(
model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
query="query",
items=[
"item_1",
"item_2",
"item_3",
]
)
print(rerank_response)
for i, result in enumerate(rerank_response):
print(f"{i+1}. [Index: {result.index}, "
f"Score: {(result.relevance_score):.3f}]")
# # from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
# # tool_definition = ToolDefinition(
# # tool_name="get_weather",
# # description="Get current weather information for a location",
# # parameters={
# # "location": ToolParamDefinition(
# # param_type="string",
# # description="The city and state, e.g. San Francisco, CA",
# # required=True
# # ),
# # "unit": ToolParamDefinition(
# # param_type="string",
# # description="Temperature unit (celsius or fahrenheit)",
# # required=False,
# # default="celsius"
# # )
# # }
# # )
# # # tool_response = client.inference.chat_completion(
# # # model_id="meta-llama/Llama-3.1-8B-Instruct",
# # # messages=[
# # # {"role": "user", "content": "What's the weather like in San Francisco?"}
# # # ],
# # # tools=[tool_definition],
# # # )
# # # print(f"Tool Response: {tool_response.completion_message.content}")
# # # if tool_response.completion_message.tool_calls:
# # # for tool_call in tool_response.completion_message.tool_calls:
# # # print(f"Tool Called: {tool_call.tool_name}")
# # # print(f"Arguments: {tool_call.arguments}")
# # # from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
# # # person_schema = {
# # # "type": "object",
# # # "properties": {
# # # "name": {"type": "string"},
# # # "age": {"type": "integer"},
# # # "occupation": {"type": "string"},
# # # },
# # # "required": ["name", "age", "occupation"]
# # # }
# # # response_format = JsonSchemaResponseFormat(
# # # type=ResponseFormatType.json_schema,
# # # json_schema=person_schema
# # # )
# # # structured_response = client.inference.chat_completion(
# # # model_id="meta-llama/Llama-3.1-8B-Instruct",
# # # messages=[
# # # {
# # # "role": "user",
# # # "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. "
# # # }
# # # ],
# # # response_format=response_format,
# # # )
# # # print(f"Structured Response: {structured_response.completion_message.content}")
# # # print("\n" + "="*50)
# # # print("VISION LANGUAGE MODEL (VLM) EXAMPLE")
# # # print("="*50)
# # def load_image_as_base64(image_path):
# # with open(image_path, "rb") as image_file:
# # img_bytes = image_file.read()
# # return base64.b64encode(img_bytes).decode("utf-8")
# # image_path = "/home/jiayin/llama-stack/docs/dog.jpg"
# # demo_image_b64 = load_image_as_base64(image_path)
# # vlm_response = client.inference.chat_completion(
# # model_id="nvidia/vila",
# # messages=[
# # {
# # "role": "user",
# # "content": [
# # {
# # "type": "image",
# # "image": {
# # "data": demo_image_b64,
# # },
# # },
# # {
# # "type": "text",
# # "text": "Please describe what you see in this image in detail.",
# # },
# # ],
# # }
# # ],
# # )
# # print(f"VLM Response: {vlm_response.completion_message.content}")
# # # print("\n" + "="*50)
# # # print("EMBEDDING EXAMPLE")
# # # print("="*50)
# # # # Embedding example
# # # embedding_response = client.inference.embeddings(
# # # model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
# # # contents=["Hello world", "How are you today?"],
# # # task_type="query"
# # # )
# # # print(f"Number of embeddings: {len(embedding_response.embeddings)}")
# # # print(f"Embedding dimension: {len(embedding_response.embeddings[0])}")
# # # print(f"First few values: {embedding_response.embeddings[0][:5]}")
# # # # from openai import OpenAI
# # # # client = OpenAI(
# # # # base_url = "http://10.176.230.61:8000/v1",
# # # # api_key = "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"
# # # # )
# # # # # completion = client.completions.create(
# # # # # model="meta/llama-3.1-405b-instruct",
# # # # # prompt="How are you?",
# # # # # temperature=0.2,
# # # # # top_p=0.7,
# # # # # max_tokens=1024,
# # # # # stream=False
# # # # # )
# # # # # # completion = client.chat.completions.create(
# # # # # # model="meta/llama-3.1-8b-instruct",
# # # # # # messages=[{"role":"user","content":"hi"}],
# # # # # # temperature=0.2,
# # # # # # top_p=0.7,
# # # # # # max_tokens=1024,
# # # # # # stream=True
# # # # # # )
# # # # # for chunk in completion:
# # # # # if chunk.choices[0].delta.content is not None:
# # # # # print(chunk.choices[0].delta.content, end="")
# # # # # response = client.inference.completion(
# # # # # model_id="meta/llama-3.1-8b-instruct",
# # # # # content="Complete the sentence using one word: Roses are red, violets are :",
# # # # # stream=False,
# # # # # sampling_params={
# # # # # "max_tokens": 50,
# # # # # },
# # # # # )
# # # # # print(f"Response: {response.content}")
# from openai import OpenAI
# client = OpenAI(
# base_url = "https://integrate.api.nvidia.com/v1",
# api_key = "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"
# )
# completion = client.chat.completions.create(
# model="nvidia/nvidia-nemotron-nano-9b-v2",
# messages=[{"role":"system","content":"/think"}],
# temperature=0.6,
# top_p=0.95,
# max_tokens=2048,
# frequency_penalty=0,
# presence_penalty=0,
# stream=True,
# extra_body={
# "min_thinking_tokens": 1024,
# "max_thinking_tokens": 2048
# }
# )
# for chunk in completion:
# reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
# if reasoning:
# print(reasoning, end="")
# if chunk.choices[0].delta.content is not None:
# print(chunk.choices[0].delta.content, end="")

View file

@ -6,7 +6,7 @@
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types import InferenceRerankResponse
from llama_stack_client.types.alpha import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
@ -97,7 +97,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
assert len(response) <= len(items)
@ -129,9 +129,9 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
)
with pytest.raises(error_type):
client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
else:
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
assert len(response) <= len(items)
@ -144,7 +144,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
response = client_with_models.inference.rerank(
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
@ -160,7 +160,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2]
response = client_with_models.inference.rerank(
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
@ -208,7 +208,7 @@ def test_rerank_semantic_correctness(
):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
_validate_rerank_response(response, items)
_validate_semantic_ranking(response, items, expected_first_item)