# What does this PR do?

Add a /v1/inference/embeddings implementation to the NVIDIA provider.

**Open topics**

- *Asymmetric models*. NeMo Retriever includes asymmetric models, which embed differently depending on whether the input is destined for storage or for lookup against storage. The /v1/inference/embeddings API does not allow the user to indicate the type of embedding to perform. See https://github.com/meta-llama/llama-stack/issues/934
- *Truncation*. Embedding models typically have a limited context window, e.g. 1024 tokens is common, though newer models have 8k windows. When the input is larger than this window, the endpoint cannot perform its designed function. There are two options: 0. return an error so the user can reduce the input size and retry; 1. truncate on the user's behalf and proceed (common strategies are left or right truncation). Many users run into context window limits and will struggle to write reliable programs around them, especially without access to the model's tokenizer. The /v1/inference/embeddings API does not allow the user to delegate truncation policy. See https://github.com/meta-llama/llama-stack/issues/933
- *Dimensions*. "Matryoshka" embedding models are available. They allow users to control the number of embedding dimensions the model produces, which is a critical feature for managing storage constraints. Embeddings of 1024 dimensions that achieve 95% recall for an application may not be worth the storage cost if 512 dimensions can achieve 93% recall. Controlling embedding dimensions lets applications determine their recall/storage tradeoff. The /v1/inference/embeddings API does not allow the user to control the output dimensions. See https://github.com/meta-llama/llama-stack/issues/932

## Test Plan

- `llama stack run llama_stack/templates/nvidia/run.yaml`
- `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3`

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
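For reference, a minimal sketch of exercising the new embeddings path through the client SDK, matching the test plan above. The model id `baai/bge-m3` and the base URL come from the test plan; the `embeddings()` call signature is assumed to match `llama_stack_client` at the time of this PR, and note that none of the three open topics above (input type, truncation policy, output dimensions) can be expressed through it.

```python
# Minimal sketch: assumes a stack is already running per the test plan above,
# and that the embeddings() signature matches llama_stack_client at PR time.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.embeddings(
    model_id="baai/bge-m3",                      # model from the test plan
    contents=["Hello, world!", "How are you?"],  # batch of inputs to embed
)

# One embedding vector per input; dimensionality is fixed by the model.
# There is currently no way to request a truncation policy, an input type
# (query vs. passage), or a reduced number of output dimensions.
for vector in response.embeddings:
    print(len(vector))
```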
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest
from llama_stack_client import LlamaStackClient
from report import Report

from llama_stack import LlamaStackAsLibraryClient
from llama_stack.providers.tests.env import get_env_or_fail


def pytest_configure(config):
    config.option.tbstyle = "short"
    config.option.disable_warnings = True
    # Note:
    # if report_path is not provided (aka no option --report in the pytest command),
    # it will be set to False
    # if --report will give None (in this case we infer report_path)
    # if --report /a/b is provided, it will be set to the path provided
    # We want to handle all these cases and hence explicitly check for False
    report_path = config.getoption("--report")
    if report_path is not False:
        config.pluginmanager.register(Report(report_path))


TEXT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
VISION_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def pytest_addoption(parser):
    parser.addoption(
        "--report",
        action="store",
        default=False,
        nargs="?",
        type=str,
        help="Path where the test report should be written, e.g. --report=/path/to/report.md",
    )
    parser.addoption(
        "--inference-model",
        action="store",
        default=TEXT_MODEL,
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--vision-inference-model",
        action="store",
        default=VISION_MODEL,
        help="Specify the vision inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield model to use for testing",
    )
    parser.addoption(
        "--embedding-model",
        action="store",
        default=TEXT_MODEL,
        help="Specify the embedding model to use for testing",
    )


@pytest.fixture(scope="session")
def provider_data():
    # check env for tavily secret, brave secret and inject all into provider data
    provider_data = {}
    if os.environ.get("TAVILY_SEARCH_API_KEY"):
        provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
    if os.environ.get("BRAVE_SEARCH_API_KEY"):
        provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
    return provider_data if len(provider_data) > 0 else None


@pytest.fixture(scope="session")
def llama_stack_client(provider_data):
    if os.environ.get("LLAMA_STACK_CONFIG"):
        client = LlamaStackAsLibraryClient(
            get_env_or_fail("LLAMA_STACK_CONFIG"),
            provider_data=provider_data,
            skip_logger_removal=True,
        )
        if not client.initialize():
            raise RuntimeError("Initialization failed")
    elif os.environ.get("LLAMA_STACK_BASE_URL"):
        client = LlamaStackClient(
            base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
            provider_data=provider_data,
        )
    else:
        raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
    return client


def pytest_generate_tests(metafunc):
    if "text_model_id" in metafunc.fixturenames:
        metafunc.parametrize(
            "text_model_id",
            [metafunc.config.getoption("--inference-model")],
            scope="session",
        )
    if "vision_model_id" in metafunc.fixturenames:
        metafunc.parametrize(
            "vision_model_id",
            [metafunc.config.getoption("--vision-inference-model")],
            scope="session",
        )
    if "embedding_model_id" in metafunc.fixturenames:
        metafunc.parametrize(
            "embedding_model_id",
            [metafunc.config.getoption("--embedding-model")],
            scope="session",
        )
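For context, a hypothetical test showing how the options above reach test code: `pytest_generate_tests` parametrizes the `embedding_model_id` fixture from `--embedding-model`, and `llama_stack_client` supplies the connected client. The test body below is illustrative only and is not the actual `tests/client-sdk/inference/test_embedding.py`.

```python
# Hypothetical usage sketch (not the real test_embedding.py): shows how a test
# under tests/client-sdk consumes the fixtures wired up by this conftest.
def test_embedding_smoke(llama_stack_client, embedding_model_id):
    # embedding_model_id comes from --embedding-model (default: TEXT_MODEL);
    # llama_stack_client is either a library client or an HTTP client,
    # depending on LLAMA_STACK_CONFIG / LLAMA_STACK_BASE_URL.
    response = llama_stack_client.inference.embeddings(
        model_id=embedding_model_id,
        contents=["Hello, world!"],
    )
    assert len(response.embeddings) == 1
    assert all(isinstance(x, float) for x in response.embeddings[0])
```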