Make embedding generation go through inference (#606)
This PR does the following:
1) Adds the ability to generate embeddings in all supported inference providers.
2) Moves all the memory providers to use the inference API, and improves the memory tests to set up the inference stack correctly and to use the embedding models.

This is a merge of #589 and #598.
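Concretely, for point 2), a memory provider no longer embeds text itself; it asks the inference API for embeddings. Below is a minimal sketch of that delegation. It assumes only the inference_api.embeddings(model_id=..., contents=...) call and the EmbeddingsResponse.embeddings field exercised by the new test further down; the class and method names are illustrative, not this PR's actual provider code.

from typing import List


class EmbeddingIndexSketch:
    """Hypothetical memory-provider helper that delegates embedding
    generation to the inference API instead of computing embeddings locally."""

    def __init__(self, inference_api, embedding_model_id: str):
        self.inference_api = inference_api            # an Inference API implementation
        self.embedding_model_id = embedding_model_id  # a registered embedding model

    async def embed_chunks(self, chunks: List[str]) -> List[List[float]]:
        # One inference call embeds the whole batch of document chunks.
        response = await self.inference_api.embeddings(
            model_id=self.embedding_model_id,
            contents=chunks,
        )
        return response.embeddings  # one float vector per input chunk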
This commit is contained in:
parent
a14785af46
commit
96e158eaac
37 changed files with 677 additions and 156 deletions
62  llama_stack/providers/tests/inference/test_embeddings.py  Normal file
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.inference import EmbeddingsResponse, ModelType

# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py


class TestEmbeddings:
    @pytest.mark.asyncio
    async def test_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding_model:
            pytest.skip("This test is only applicable for embedding models")

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=["Hello, world!"],
        )
        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) > 0
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

    @pytest.mark.asyncio
    async def test_batch_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding_model:
            pytest.skip("This test is only applicable for embedding models")

        texts = ["Hello, world!", "This is a test", "Testing embeddings"]

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=texts,
        )

        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) == len(texts)
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

        embedding_dim = len(response.embeddings[0])
        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
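For point 1) of the description, the provider side of the contract is to return an EmbeddingsResponse whose embeddings field holds one list of floats per input, which is exactly what the assertions above check. A minimal sketch of such a provider method follows; it assumes EmbeddingsResponse accepts an embeddings keyword argument and uses sentence-transformers purely as an illustrative local backend, not this PR's actual provider implementations.

from typing import List

from sentence_transformers import SentenceTransformer

from llama_stack.apis.inference import EmbeddingsResponse


class LocalEmbeddingProviderSketch:
    """Illustrative only: shows the response shape the tests above expect."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # Hypothetical choice of local embedding model for the sketch.
        self.model = SentenceTransformer(model_name)

    async def embeddings(self, model_id: str, contents: List[str]) -> EmbeddingsResponse:
        # encode() returns an array of shape (len(contents), dim); convert it to
        # plain lists of floats so the result matches the assertions above.
        vectors = self.model.encode(contents)
        # Assumption: EmbeddingsResponse is constructed with an `embeddings` kwarg.
        return EmbeddingsResponse(embeddings=[v.tolist() for v in vectors])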