Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 09:21:45 +00:00)

Commit 5821ec9ef3 ("address feedback"), parent e167e9eb93
12 changed files with 61 additions and 76 deletions

@@ -205,12 +205,10 @@ API responses, specify the adapter here.
 def remote_provider_spec(
     api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
 ) -> RemoteProviderSpec:
-    if api_dependencies is None:
-        api_dependencies = []
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
-        api_dependencies=api_dependencies,
+        api_dependencies=api_dependencies or [],
     )
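
The `or []` fallback above replaces the explicit `None` check. A minimal, standalone sketch of the idiom; the function name below is illustrative and not from the repository:

from typing import List, Optional

def pick_dependencies(api_dependencies: Optional[List[str]] = None) -> List[str]:
    # Same fallback style as the updated helper: None (or an empty list) becomes [].
    return api_dependencies or []

assert pick_dependencies() == []
assert pick_dependencies(None) == []
assert pick_dependencies(["inference"]) == ["inference"]
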
@@ -84,7 +84,7 @@ class MetaReferenceInferenceImpl(
     async def register_model(self, model: Model) -> Model:
         model = await self.model_registry_helper.register_model(model)
         if model.model_type == ModelType.embedding_model:
-            self._get_embedding_model(model.provider_resource_id)
+            self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def completion(
@@ -48,7 +48,7 @@ class SentenceTransformersInferenceImpl(
         )

     async def register_model(self, model: Model) -> None:
-        _ = self._get_embedding_model(model.provider_resource_id)
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -63,7 +63,7 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncGenerator]:
-        raise NotImplementedError("Sentence transformers don't support completion")
+        raise ValueError("Sentence transformers don't support completion")

     async def chat_completion(
         self,
@@ -77,4 +77,4 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError("Sentence transformers don't support chat completion")
+        raise ValueError("Sentence transformers don't support chat completion")
@@ -20,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (

 from llama_stack.apis.inference import *  # noqa: F403


 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
+

 model_aliases = [
@@ -452,7 +454,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model = await self.model_store.get_model(model_id)
         embeddings = []
         for content in contents:
-            input_text = str(content) if not isinstance(content, str) else content
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
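
The added check rejects media inputs before flattening each interleaved content item to a plain string; the Fireworks, Ollama, Together, and vLLM hunks below apply the same pattern. A standalone sketch of that pattern, with trivial stand-ins for the real content_has_media and interleaved_text_media_as_str helpers so it runs without llama-stack installed:

from typing import List, Union

def content_has_media(content: Union[str, dict]) -> bool:
    # Stand-in: treat anything that is not a plain string as media.
    return not isinstance(content, str)

def interleaved_text_media_as_str(content: Union[str, dict]) -> str:
    # Stand-in: the real helper flattens interleaved text/media into one string.
    return content if isinstance(content, str) else str(content)

def prepare_embedding_inputs(contents: List[Union[str, dict]]) -> List[str]:
    assert all(
        not content_has_media(content) for content in contents
    ), "This provider does not support media for embeddings"
    return [interleaved_text_media_as_str(content) for content in contents]

print(prepare_embedding_inputs(["Hello, world!", "Testing embeddings"]))
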
@@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
@@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)

-    def _get_openai_client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
-
     async def completion(
         self,
         model_id: str,
@@ -272,12 +269,16 @@ class FireworksInferenceAdapter(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

-        client = self._get_openai_client()
         kwargs = {}
         if model.metadata.get("embedding_dimensions"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
-        response = client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )

         embeddings = [data.embedding for data in response.data]
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -323,8 +324,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]

@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -254,8 +255,12 @@ class TogetherInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -206,10 +207,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model = await self.model_store.get_model(model_id)

         kwargs = {}
-        if model.metadata.get("embedding_dimensions"):
-            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )

         embeddings = [data.embedding for data in response.data]
@@ -84,24 +84,3 @@ def pytest_generate_tests(metafunc):
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
-
-    if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
-        if not model:
-            raise ValueError(
-                "No embedding model specified. Please provide a valid embedding model."
-            )
-        params = [pytest.param(model, id="")]
-
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "embedding_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
-            metafunc.config,
-            {
-                "inference": INFERENCE_FIXTURES,
-            },
-        ):
-            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-            metafunc.parametrize("embedding_stack", fixtures, indirect=True)
@@ -37,13 +37,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)


-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    test_stack = await construct_stack_for_test(
-        [Api.inference],
-        {"inference": inference_fixture.providers},
-        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
-    )
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def embedding_stack(request, embedding_model):
-    fixture_name = request.param
-    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding_model,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
             )
         ],
     )
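
With the separate embedding_model/embedding_stack fixtures removed, the single inference_stack fixture now derives the model type from the EMBEDDING_DIMENSION environment variable. A standalone sketch of that selection logic; plain strings stand in for the ModelType enum, get_env_or_fail is replaced by os.environ, and "384" is only an example value:

import os
from typing import Dict, Tuple

def resolve_model_registration() -> Tuple[str, Dict[str, str]]:
    model_type = "llm"
    metadata: Dict[str, str] = {}
    if os.getenv("EMBEDDING_DIMENSION"):
        model_type = "embedding_model"
        metadata["embedding_dimension"] = os.environ["EMBEDDING_DIMENSION"]
    return model_type, metadata

os.environ["EMBEDDING_DIMENSION"] = "384"
assert resolve_model_registration() == ("embedding_model", {"embedding_dimension": "384"})
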
@@ -14,15 +14,15 @@ from llama_stack.apis.inference import EmbeddingsResponse, ModelType

 class TestEmbeddings:
     @pytest.mark.asyncio
-    async def test_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)

         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")

         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=["Hello, world!"],
         )
         assert isinstance(response, EmbeddingsResponse)
@@ -35,9 +35,9 @@ class TestEmbeddings:
         )

     @pytest.mark.asyncio
-    async def test_batch_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)

         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
@@ -45,7 +45,7 @@ class TestEmbeddings:
         texts = ["Hello, world!", "This is a test", "Testing embeddings"]

         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=texts,
         )

@@ -26,11 +26,13 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._get_embedding_model(model.provider_resource_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)

-    def _get_embedding_model(self, model: str) -> "SentenceTransformer":
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS

         loaded_model = EMBEDDING_MODELS.get(model)
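
For reference, _load_sentence_transformer_model keeps a module-level EMBEDDING_MODELS cache so each model is only loaded once. A standalone sketch of that lazy-load-and-cache pattern; a trivial class stands in for sentence_transformers.SentenceTransformer, and the model name is only an example:

from typing import Dict

class FakeSentenceTransformer:
    # Stand-in for sentence_transformers.SentenceTransformer.
    def __init__(self, name: str) -> None:
        self.name = name

EMBEDDING_MODELS: Dict[str, FakeSentenceTransformer] = {}

def load_sentence_transformer_model(model: str) -> FakeSentenceTransformer:
    loaded_model = EMBEDDING_MODELS.get(model)
    if loaded_model is None:
        loaded_model = FakeSentenceTransformer(model)
        EMBEDDING_MODELS[model] = loaded_model
    return loaded_model

# The second call reuses the cached instance instead of loading again.
assert load_sentence_transformer_model("all-MiniLM-L6-v2") is load_sentence_transformer_model("all-MiniLM-L6-v2")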