mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-20 06:38:43 +00:00
implement embedding generation in supported inference providers (#589)
This PR adds the ability to generate embeddings in all supported inference providers. ``` pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024 pytest -v -s -k "vllm" --inferrence-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1" pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=128 pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=768 pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 ```
This commit is contained in:
parent
6a23f24ee0
commit
d362d2d740
32 changed files with 597 additions and 143 deletions
|
|
@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
|
|||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference",
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference",
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
|
@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
|
@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = None
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
fireworks_api_key = self.config.api_key
|
||||
return self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
fireworks_api_key = provider_data.fireworks_api_key
|
||||
return provider_data.fireworks_api_key
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self._get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
async def completion(
|
||||
|
|
@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimensions"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "Fireworks does not support media for embeddings"
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue