From 6b4940806f2aa6b411f14e28f9c1414df9d059e5 Mon Sep 17 00:00:00 2001 From: Jiayi Date: Wed, 1 Oct 2025 10:37:58 -0700 Subject: [PATCH] Fix rerank integration test based on client side changes --- docs/docs/providers/agents/index.mdx | 2 +- docs/docs/providers/inference/index.mdx | 3 +- docs/static/deprecated-llama-stack-spec.html | 2 +- docs/static/deprecated-llama-stack-spec.yaml | 7 +- docs/static/stainless-llama-stack-spec.html | 7 +- docs/static/stainless-llama-stack-spec.yaml | 11 +- example.py | 257 ------------------- tests/integration/inference/test_rerank.py | 14 +- 8 files changed, 27 insertions(+), 276 deletions(-) delete mode 100644 example.py diff --git a/docs/docs/providers/agents/index.mdx b/docs/docs/providers/agents/index.mdx index 200d0119f..06eb104af 100644 --- a/docs/docs/providers/agents/index.mdx +++ b/docs/docs/providers/agents/index.mdx @@ -14,4 +14,4 @@ Agents APIs for creating and interacting with agentic systems. -This section contains documentation for all available providers for the **agents** API. \ No newline at end of file +This section contains documentation for all available providers for the **agents** API. diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index 065f620df..63741f202 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -4,8 +4,7 @@ description: "Llama Stack Inference API for generating completions, chat complet This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. - - Rerank models: these models rerank the documents by relevance." - + - Rerank models: these models reorder the documents based on their relevance to a query." sidebar_label: Inference title: Inference --- diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 7edfe3f5d..f0dd903a6 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -13335,7 +13335,7 @@ }, { "name": "Inference", - "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." }, { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index ca832d46b..48863025f 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -9990,13 +9990,16 @@ tags: description: '' - name: Inference description: >- - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. 
Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: >- Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7ec48ef74..6bc67536d 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -8838,7 +8838,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -17033,7 +17034,7 @@ "properties": { "model": { "type": "string", - "description": "The identifier of the reranking model to use." + "description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint." }, "query": { "oneOf": [ @@ -18456,7 +18457,7 @@ }, { "name": "Inference", - "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." }, { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3bede159b..8fc70a5cd 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6603,6 +6603,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -12693,7 +12694,8 @@ components: model: type: string description: >- - The identifier of the reranking model to use. + The identifier of the reranking model to use. The model must be a reranking + model registered with Llama Stack and available via the /models endpoint. query: oneOf: - type: string @@ -13774,13 +13776,16 @@ tags: description: '' - name: Inference description: >- - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: >- Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/example.py b/example.py deleted file mode 100644 index 7e968e24a..000000000 --- a/example.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os - -os.environ["NVIDIA_API_KEY"] = "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" -# Option 1: Use default NIM URL (will auto-switch to ai.api.nvidia.com for rerank) -# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com" -# Option 2: Use AI Foundation URL directly for rerank models -# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com/v1" -os.environ["NVIDIA_BASE_URL"] = "https://integrate.api.nvidia.com" - -import base64 -import io -from PIL import Image - -from llama_stack.core.library_client import LlamaStackAsLibraryClient - -client = LlamaStackAsLibraryClient("nvidia") -client.initialize() - -# # response = client.inference.completion( -# # model_id="meta/llama-3.1-8b-instruct", -# # content="Complete the sentence using one word: Roses are red, violets are :", -# # stream=False, -# # sampling_params={ -# # "max_tokens": 50, -# # }, -# # ) -# # print(f"Response: {response.content}") - - -# response = client.inference.chat_completion( -# model_id="nvidia/nvidia-nemotron-nano-9b-v2", -# messages=[ -# { -# "role": "system", -# "content": "/think", -# }, -# { -# "role": "user", -# "content": "How are you?", -# }, -# ], -# stream=False, -# sampling_params={ -# "max_tokens": 1024, -# }, -# ) -# print(f"Response: {response}") - - -print(client.models.list()) -rerank_response = client.inference.rerank( - model="nvidia/llama-3.2-nv-rerankqa-1b-v2", - query="query", - items=[ - "item_1", - "item_2", - "item_3", - ] -) - -print(rerank_response) -for i, result in enumerate(rerank_response): - print(f"{i+1}. [Index: {result.index}, " - f"Score: {(result.relevance_score):.3f}]") - -# # from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition - -# # tool_definition = ToolDefinition( -# # tool_name="get_weather", -# # description="Get current weather information for a location", -# # parameters={ -# # "location": ToolParamDefinition( -# # param_type="string", -# # description="The city and state, e.g. 
San Francisco, CA", -# # required=True -# # ), -# # "unit": ToolParamDefinition( -# # param_type="string", -# # description="Temperature unit (celsius or fahrenheit)", -# # required=False, -# # default="celsius" -# # ) -# # } -# # ) - -# # # tool_response = client.inference.chat_completion( -# # # model_id="meta-llama/Llama-3.1-8B-Instruct", -# # # messages=[ -# # # {"role": "user", "content": "What's the weather like in San Francisco?"} -# # # ], -# # # tools=[tool_definition], -# # # ) - -# # # print(f"Tool Response: {tool_response.completion_message.content}") -# # # if tool_response.completion_message.tool_calls: -# # # for tool_call in tool_response.completion_message.tool_calls: -# # # print(f"Tool Called: {tool_call.tool_name}") -# # # print(f"Arguments: {tool_call.arguments}") - - -# # # from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType - -# # # person_schema = { -# # # "type": "object", -# # # "properties": { -# # # "name": {"type": "string"}, -# # # "age": {"type": "integer"}, -# # # "occupation": {"type": "string"}, -# # # }, -# # # "required": ["name", "age", "occupation"] -# # # } - -# # # response_format = JsonSchemaResponseFormat( -# # # type=ResponseFormatType.json_schema, -# # # json_schema=person_schema -# # # ) - -# # # structured_response = client.inference.chat_completion( -# # # model_id="meta-llama/Llama-3.1-8B-Instruct", -# # # messages=[ -# # # { -# # # "role": "user", -# # # "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. " -# # # } -# # # ], -# # # response_format=response_format, -# # # ) - -# # # print(f"Structured Response: {structured_response.completion_message.content}") - -# # # print("\n" + "="*50) -# # # print("VISION LANGUAGE MODEL (VLM) EXAMPLE") -# # # print("="*50) - -# # def load_image_as_base64(image_path): -# # with open(image_path, "rb") as image_file: -# # img_bytes = image_file.read() -# # return base64.b64encode(img_bytes).decode("utf-8") - -# # image_path = "/home/jiayin/llama-stack/docs/dog.jpg" -# # demo_image_b64 = load_image_as_base64(image_path) - -# # vlm_response = client.inference.chat_completion( -# # model_id="nvidia/vila", -# # messages=[ -# # { -# # "role": "user", -# # "content": [ -# # { -# # "type": "image", -# # "image": { -# # "data": demo_image_b64, -# # }, -# # }, -# # { -# # "type": "text", -# # "text": "Please describe what you see in this image in detail.", -# # }, -# # ], -# # } -# # ], -# # ) - -# # print(f"VLM Response: {vlm_response.completion_message.content}") - -# # # print("\n" + "="*50) -# # # print("EMBEDDING EXAMPLE") -# # # print("="*50) - -# # # # Embedding example -# # # embedding_response = client.inference.embeddings( -# # # model_id="nvidia/llama-3.2-nv-embedqa-1b-v2", -# # # contents=["Hello world", "How are you today?"], -# # # task_type="query" -# # # ) - -# # # print(f"Number of embeddings: {len(embedding_response.embeddings)}") -# # # print(f"Embedding dimension: {len(embedding_response.embeddings[0])}") -# # # print(f"First few values: {embedding_response.embeddings[0][:5]}") - -# # # # from openai import OpenAI - -# # # # client = OpenAI( -# # # # base_url = "http://10.176.230.61:8000/v1", -# # # # api_key = "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" -# # # # ) - -# # # # # completion = client.completions.create( -# # # # # model="meta/llama-3.1-405b-instruct", -# # # # # prompt="How are you?", -# # # # # temperature=0.2, -# # # # # top_p=0.7, -# # # # # max_tokens=1024, 
-# # # # # stream=False -# # # # # ) - -# # # # # # completion = client.chat.completions.create( -# # # # # # model="meta/llama-3.1-8b-instruct", -# # # # # # messages=[{"role":"user","content":"hi"}], -# # # # # # temperature=0.2, -# # # # # # top_p=0.7, -# # # # # # max_tokens=1024, -# # # # # # stream=True -# # # # # # ) - -# # # # # for chunk in completion: -# # # # # if chunk.choices[0].delta.content is not None: -# # # # # print(chunk.choices[0].delta.content, end="") - - -# # # # # response = client.inference.completion( -# # # # # model_id="meta/llama-3.1-8b-instruct", -# # # # # content="Complete the sentence using one word: Roses are red, violets are :", -# # # # # stream=False, -# # # # # sampling_params={ -# # # # # "max_tokens": 50, -# # # # # }, -# # # # # ) -# # # # # print(f"Response: {response.content}") - - - - -# from openai import OpenAI - -# client = OpenAI( -# base_url = "https://integrate.api.nvidia.com/v1", -# api_key = "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" -# ) - -# completion = client.chat.completions.create( -# model="nvidia/nvidia-nemotron-nano-9b-v2", -# messages=[{"role":"system","content":"/think"}], -# temperature=0.6, -# top_p=0.95, -# max_tokens=2048, -# frequency_penalty=0, -# presence_penalty=0, -# stream=True, -# extra_body={ -# "min_thinking_tokens": 1024, -# "max_thinking_tokens": 2048 -# } -# ) - -# for chunk in completion: -# reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None) -# if reasoning: -# print(reasoning, end="") -# if chunk.choices[0].delta.content is not None: -# print(chunk.choices[0].delta.content, end="") diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py index 4931c3d6c..82f35cd27 100644 --- a/tests/integration/inference/test_rerank.py +++ b/tests/integration/inference/test_rerank.py @@ -6,7 +6,7 @@ import pytest from llama_stack_client import BadRequestError as LlamaStackBadRequestError -from llama_stack_client.types import InferenceRerankResponse +from llama_stack_client.types.alpha import InferenceRerankResponse from llama_stack_client.types.shared.interleaved_content import ( ImageContentItem, ImageContentItemImage, @@ -97,7 +97,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type): skip_if_provider_doesnt_support_rerank(inference_provider_type) - response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) assert isinstance(response, list) # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client. 
assert len(response) <= len(items) @@ -129,9 +129,9 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError ) with pytest.raises(error_type): - client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) + client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) else: - response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) assert isinstance(response, list) assert len(response) <= len(items) @@ -144,7 +144,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2] max_num_results = 2 - response = client_with_models.inference.rerank( + response = client_with_models.alpha.inference.rerank( model=rerank_model_id, query=DUMMY_STRING, items=items, @@ -160,7 +160,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i skip_if_provider_doesnt_support_rerank(inference_provider_type) items = [DUMMY_STRING, DUMMY_STRING2] - response = client_with_models.inference.rerank( + response = client_with_models.alpha.inference.rerank( model=rerank_model_id, query=DUMMY_STRING, items=items, @@ -208,7 +208,7 @@ def test_rerank_semantic_correctness( ): skip_if_provider_doesnt_support_rerank(inference_provider_type) - response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) _validate_rerank_response(response, items) _validate_semantic_ranking(response, items, expected_first_item)
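
For reviewers, the client-side call pattern this patch migrates the tests to is shown in the minimal sketch below. It assumes a running Llama Stack distribution with a rerank-capable model registered; the base URL and model identifier are illustrative, and the `alpha.inference.rerank` path plus the `index`/`relevance_score` result fields follow what the updated tests exercise.

```python
# Minimal sketch of the migrated rerank call path (assumptions: a Llama Stack
# server is reachable at the given base URL and a rerank model is registered;
# the model id below is illustrative).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Rerank now lives under the alpha namespace on the client
# (previously client.inference.rerank).
response = client.alpha.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="What is the capital of France?",
    items=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "The Eiffel Tower is in Paris.",
    ],
    max_num_results=2,
)

# The response is a list; each result carries the original item index and a
# relevance score, ordered from most to least relevant.
for rank, result in enumerate(response, start=1):
    print(f"{rank}. index={result.index}, score={result.relevance_score:.3f}")
```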