mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 20:49:48 +00:00
This PR adds the ability to generate embeddings in all supported inference providers. ``` pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024 pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1" pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=128 pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=768 pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 ```
62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
|
|
|
|
# How to run this test:
|
|
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
|
|
|
|
|
|
class TestEmbeddings:
    """Integration tests for the embeddings API of inference providers.

    Both tests depend on the ``inference_model`` and ``inference_stack``
    fixtures supplied by the surrounding test harness, and skip themselves
    when the registered model is not an embedding model.
    """

    @pytest.mark.asyncio
    async def test_embeddings(self, inference_model, inference_stack):
        """Single-input request: verify response type and vector contents."""
        impl, registry = inference_stack
        registered = await registry.get_model(inference_model)

        # Only embedding models can serve this endpoint.
        if registered.model_type != ModelType.embedding_model:
            pytest.skip("This test is only applicable for embedding models")

        result = await impl.embeddings(
            model_id=inference_model,
            contents=["Hello, world!"],
        )

        assert isinstance(result, EmbeddingsResponse)
        assert len(result.embeddings) > 0
        # Each embedding is a list of floats.
        for vector in result.embeddings:
            assert isinstance(vector, list)
            assert all(isinstance(component, float) for component in vector)

    @pytest.mark.asyncio
    async def test_batch_embeddings(self, inference_model, inference_stack):
        """Batched request: one vector per input, all of equal width."""
        impl, registry = inference_stack
        registered = await registry.get_model(inference_model)

        # Only embedding models can serve this endpoint.
        if registered.model_type != ModelType.embedding_model:
            pytest.skip("This test is only applicable for embedding models")

        inputs = ["Hello, world!", "This is a test", "Testing embeddings"]

        result = await impl.embeddings(
            model_id=inference_model,
            contents=inputs,
        )

        assert isinstance(result, EmbeddingsResponse)
        assert len(result.embeddings) == len(inputs)
        # Each embedding is a list of floats.
        for vector in result.embeddings:
            assert isinstance(vector, list)
            assert all(isinstance(component, float) for component in vector)

        # Every vector must share the dimensionality of the first one.
        width = len(result.embeddings[0])
        assert all(len(vector) == width for vector in result.embeddings)