Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
Fix rerank integration test based on client-side changes

commit 6b4940806f
parent bb2eb33fc3

8 changed files with 27 additions and 276 deletions
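The client-side change in question: the llama-stack-client now exposes rerank under the alpha namespace, so imports move from llama_stack_client.types to llama_stack_client.types.alpha and calls move from client.inference.rerank to client.alpha.inference.rerank. A minimal sketch of the call-site migration; the base URL, model id, and inputs are placeholders, not part of this commit:

```python
# Sketch only: illustrates the namespace move this commit adapts to.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL

# Before: response = client.inference.rerank(...)
# After: rerank lives under the alpha namespace.
response = client.alpha.inference.rerank(
    model="my-rerank-model",  # placeholder; must be a registered rerank model
    query="what is the capital of France?",
    items=["Paris is the capital of France.", "Tokyo is in Japan."],
)
```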
@@ -14,4 +14,4 @@ Agents
 
 APIs for creating and interacting with agentic systems.
 
-This section contains documentation for all available providers for the **agents** API.
+This section contains documentation for all available providers for the **agents** API.
@@ -4,8 +4,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
 This API provides the raw interface to the underlying models. Three kinds of models are supported:
 - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
 - Embedding models: these models generate embeddings to be used for semantic search.
-- Rerank models: these models rerank the documents by relevance."
-
+- Rerank models: these models reorder the documents based on their relevance to a query."
 sidebar_label: Inference
 title: Inference
 ---
docs/static/deprecated-llama-stack-spec.html (vendored, 2 lines changed)
@@ -13335,7 +13335,7 @@
       },
       {
         "name": "Inference",
-        "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
+        "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
         "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
       },
       {
docs/static/deprecated-llama-stack-spec.yaml (vendored, 7 lines changed)
@@ -9990,13 +9990,16 @@ tags:
     description: ''
   - name: Inference
     description: >-
-      This API provides the raw interface to the underlying models. Two kinds of models
-      are supported:
+      This API provides the raw interface to the underlying models. Three kinds of
+      models are supported:
 
       - LLM models: these models generate "raw" and "chat" (conversational) completions.
 
       - Embedding models: these models generate embeddings to be used for semantic
       search.
+
+      - Rerank models: these models reorder the documents based on their relevance
+      to a query.
     x-displayName: >-
       Llama Stack Inference API for generating completions, chat completions, and
       embeddings.
docs/static/stainless-llama-stack-spec.html (vendored, 7 lines changed)
@@ -8838,7 +8838,8 @@
         "type": "string",
         "enum": [
           "llm",
-          "embedding"
+          "embedding",
+          "rerank"
         ],
         "title": "ModelType",
         "description": "Enumeration of supported model types in Llama Stack."
@@ -17033,7 +17034,7 @@
       "properties": {
         "model": {
           "type": "string",
-          "description": "The identifier of the reranking model to use."
+          "description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
         },
         "query": {
           "oneOf": [
@@ -18456,7 +18457,7 @@
       },
       {
         "name": "Inference",
-        "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
+        "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
         "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
       },
       {
docs/static/stainless-llama-stack-spec.yaml (vendored, 11 lines changed)
@@ -6603,6 +6603,7 @@ components:
       enum:
         - llm
         - embedding
+        - rerank
       title: ModelType
       description: >-
         Enumeration of supported model types in Llama Stack.
@@ -12693,7 +12694,8 @@ components:
         model:
           type: string
           description: >-
-            The identifier of the reranking model to use.
+            The identifier of the reranking model to use. The model must be a reranking
+            model registered with Llama Stack and available via the /models endpoint.
         query:
           oneOf:
             - type: string
@@ -13774,13 +13776,16 @@ tags:
     description: ''
   - name: Inference
     description: >-
-      This API provides the raw interface to the underlying models. Two kinds of models
-      are supported:
+      This API provides the raw interface to the underlying models. Three kinds of
+      models are supported:
 
       - LLM models: these models generate "raw" and "chat" (conversational) completions.
 
       - Embedding models: these models generate embeddings to be used for semantic
       search.
+
+      - Rerank models: these models reorder the documents based on their relevance
+      to a query.
     x-displayName: >-
       Llama Stack Inference API for generating completions, chat completions, and
       embeddings.
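The regenerated specs above make rerank a first-class model type ("rerank" joins "llm" and "embedding" in ModelType) and require the rerank model to be registered and visible via the /models endpoint. As a hedged sketch, with attribute names (identifier, model_type) assumed from the Llama Stack model schema rather than spelled out in this diff, a caller could discover a rerank-capable model before invoking rerank:

```python
# Sketch only: find a registered rerank model via the models API.
# Assumes `client` is an already-constructed llama-stack client.
models = client.models.list()
rerank_models = [m for m in models if m.model_type == "rerank"]

if not rerank_models:
    raise RuntimeError("no rerank model registered with this stack")
print("using rerank model:", rerank_models[0].identifier)
```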
example.py (deleted, 257 lines)
@@ -1,257 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import os
-
-os.environ["NVIDIA_API_KEY"] = "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"
-# Option 1: Use default NIM URL (will auto-switch to ai.api.nvidia.com for rerank)
-# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com"
-# Option 2: Use AI Foundation URL directly for rerank models
-# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com/v1"
-os.environ["NVIDIA_BASE_URL"] = "https://integrate.api.nvidia.com"
-
-import base64
-import io
-from PIL import Image
-
-from llama_stack.core.library_client import LlamaStackAsLibraryClient
-
-client = LlamaStackAsLibraryClient("nvidia")
-client.initialize()
-
-# # response = client.inference.completion(
-# #     model_id="meta/llama-3.1-8b-instruct",
-# #     content="Complete the sentence using one word: Roses are red, violets are :",
-# #     stream=False,
-# #     sampling_params={
-# #         "max_tokens": 50,
-# #     },
-# # )
-# # print(f"Response: {response.content}")
-
-
-# response = client.inference.chat_completion(
-#     model_id="nvidia/nvidia-nemotron-nano-9b-v2",
-#     messages=[
-#         {
-#             "role": "system",
-#             "content": "/think",
-#         },
-#         {
-#             "role": "user",
-#             "content": "How are you?",
-#         },
-#     ],
-#     stream=False,
-#     sampling_params={
-#         "max_tokens": 1024,
-#     },
-# )
-# print(f"Response: {response}")
-
-
-print(client.models.list())
-rerank_response = client.inference.rerank(
-    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
-    query="query",
-    items=[
-        "item_1",
-        "item_2",
-        "item_3",
-    ]
-)
-
-print(rerank_response)
-for i, result in enumerate(rerank_response):
-    print(f"{i+1}. [Index: {result.index}, "
-          f"Score: {(result.relevance_score):.3f}]")
-
-# # from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
-
-# # tool_definition = ToolDefinition(
-# #     tool_name="get_weather",
-# #     description="Get current weather information for a location",
-# #     parameters={
-# #         "location": ToolParamDefinition(
-# #             param_type="string",
-# #             description="The city and state, e.g. San Francisco, CA",
-# #             required=True
-# #         ),
-# #         "unit": ToolParamDefinition(
-# #             param_type="string",
-# #             description="Temperature unit (celsius or fahrenheit)",
-# #             required=False,
-# #             default="celsius"
-# #         )
-# #     }
-# # )
-
-# # # tool_response = client.inference.chat_completion(
-# # #     model_id="meta-llama/Llama-3.1-8B-Instruct",
-# # #     messages=[
-# # #         {"role": "user", "content": "What's the weather like in San Francisco?"}
-# # #     ],
-# # #     tools=[tool_definition],
-# # # )
-
-# # # print(f"Tool Response: {tool_response.completion_message.content}")
-# # # if tool_response.completion_message.tool_calls:
-# # #     for tool_call in tool_response.completion_message.tool_calls:
-# # #         print(f"Tool Called: {tool_call.tool_name}")
-# # #         print(f"Arguments: {tool_call.arguments}")
-
-
-# # # from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
-
-# # # person_schema = {
-# # #     "type": "object",
-# # #     "properties": {
-# # #         "name": {"type": "string"},
-# # #         "age": {"type": "integer"},
-# # #         "occupation": {"type": "string"},
-# # #     },
-# # #     "required": ["name", "age", "occupation"]
-# # # }
-
-# # # response_format = JsonSchemaResponseFormat(
-# # #     type=ResponseFormatType.json_schema,
-# # #     json_schema=person_schema
-# # # )
-
-# # # structured_response = client.inference.chat_completion(
-# # #     model_id="meta-llama/Llama-3.1-8B-Instruct",
-# # #     messages=[
-# # #         {
-# # #             "role": "user",
-# # #             "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. "
-# # #         }
-# # #     ],
-# # #     response_format=response_format,
-# # # )
-
-# # # print(f"Structured Response: {structured_response.completion_message.content}")
-
-# # # print("\n" + "="*50)
-# # # print("VISION LANGUAGE MODEL (VLM) EXAMPLE")
-# # # print("="*50)
-
-# # def load_image_as_base64(image_path):
-# #     with open(image_path, "rb") as image_file:
-# #         img_bytes = image_file.read()
-# #     return base64.b64encode(img_bytes).decode("utf-8")
-
-# # image_path = "/home/jiayin/llama-stack/docs/dog.jpg"
-# # demo_image_b64 = load_image_as_base64(image_path)
-
-# # vlm_response = client.inference.chat_completion(
-# #     model_id="nvidia/vila",
-# #     messages=[
-# #         {
-# #             "role": "user",
-# #             "content": [
-# #                 {
-# #                     "type": "image",
-# #                     "image": {
-# #                         "data": demo_image_b64,
-# #                     },
-# #                 },
-# #                 {
-# #                     "type": "text",
-# #                     "text": "Please describe what you see in this image in detail.",
-# #                 },
-# #             ],
-# #         }
-# #     ],
-# # )
-
-# # print(f"VLM Response: {vlm_response.completion_message.content}")
-
-# # # print("\n" + "="*50)
-# # # print("EMBEDDING EXAMPLE")
-# # # print("="*50)
-
-# # # # Embedding example
-# # # embedding_response = client.inference.embeddings(
-# # #     model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
-# # #     contents=["Hello world", "How are you today?"],
-# # #     task_type="query"
-# # # )
-
-# # # print(f"Number of embeddings: {len(embedding_response.embeddings)}")
-# # # print(f"Embedding dimension: {len(embedding_response.embeddings[0])}")
-# # # print(f"First few values: {embedding_response.embeddings[0][:5]}")
-
-# # # # from openai import OpenAI
-
-# # # # client = OpenAI(
-# # # #     base_url = "http://10.176.230.61:8000/v1",
-# # # #     api_key = "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"
-# # # # )
-
-# # # # # completion = client.completions.create(
-# # # # #     model="meta/llama-3.1-405b-instruct",
-# # # # #     prompt="How are you?",
-# # # # #     temperature=0.2,
-# # # # #     top_p=0.7,
-# # # # #     max_tokens=1024,
-# # # # #     stream=False
-# # # # # )
-
-# # # # # # completion = client.chat.completions.create(
-# # # # # #     model="meta/llama-3.1-8b-instruct",
-# # # # # #     messages=[{"role":"user","content":"hi"}],
-# # # # # #     temperature=0.2,
-# # # # # #     top_p=0.7,
-# # # # # #     max_tokens=1024,
-# # # # # #     stream=True
-# # # # # # )
-
-# # # # # for chunk in completion:
-# # # # #     if chunk.choices[0].delta.content is not None:
-# # # # #         print(chunk.choices[0].delta.content, end="")
-
-
-# # # # # response = client.inference.completion(
-# # # # #     model_id="meta/llama-3.1-8b-instruct",
-# # # # #     content="Complete the sentence using one word: Roses are red, violets are :",
-# # # # #     stream=False,
-# # # # #     sampling_params={
-# # # # #         "max_tokens": 50,
-# # # # #     },
-# # # # # )
-# # # # # print(f"Response: {response.content}")
-
-
-
-
-# from openai import OpenAI
-
-# client = OpenAI(
-#     base_url = "https://integrate.api.nvidia.com/v1",
-#     api_key = "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"
-# )
-
-# completion = client.chat.completions.create(
-#     model="nvidia/nvidia-nemotron-nano-9b-v2",
-#     messages=[{"role":"system","content":"/think"}],
-#     temperature=0.6,
-#     top_p=0.95,
-#     max_tokens=2048,
-#     frequency_penalty=0,
-#     presence_penalty=0,
-#     stream=True,
-#     extra_body={
-#         "min_thinking_tokens": 1024,
-#         "max_thinking_tokens": 2048
-#     }
-# )
-
-# for chunk in completion:
-#     reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
-#     if reasoning:
-#         print(reasoning, end="")
-#     if chunk.choices[0].delta.content is not None:
-#         print(chunk.choices[0].delta.content, end="")
@@ -6,7 +6,7 @@
 
 import pytest
 from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import InferenceRerankResponse
+from llama_stack_client.types.alpha import InferenceRerankResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
     ImageContentItemImage,
@@ -97,7 +97,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
 def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, list)
     # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
     assert len(response) <= len(items)
@@ -129,9 +129,9 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
             ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
 
         assert isinstance(response, list)
         assert len(response) <= len(items)
@@ -144,7 +144,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
 
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -160,7 +160,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -208,7 +208,7 @@ def test_rerank_semantic_correctness(
 ):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
 
     _validate_rerank_response(response, items)
     _validate_semantic_ranking(response, items, expected_first_item)
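For context on what these assertions check: rerank returns a list of scored results, each pointing back at an input item by index, and never longer than the input list. A hedged sketch of that response shape, using the field names from the tests and the removed example.py; the model id and items are placeholders:

```python
# Sketch only: response items expose `index` and `relevance_score`,
# mirroring the prints in the removed example.py. Assumes `client` is a
# llama-stack client with a registered rerank model.
response = client.alpha.inference.rerank(
    model="my-rerank-model",  # placeholder rerank model id
    query="query",
    items=["item_1", "item_2", "item_3"],
)

assert isinstance(response, list)
assert len(response) <= 3  # never more results than input items

for rank, result in enumerate(response, start=1):
    print(f"{rank}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
```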