diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 358a29d4c..58f92c829 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
-    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -46,6 +45,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -386,7 +388,42 @@ class OllamaInferenceAdapter(
         dimensions: int | None = None,
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        model_obj = await self._get_model(model)
+        if model_obj.model_type != ModelType.embedding:
+            raise ValueError(f"Model {model} is not an embedding model")
+
+        params = {
+            "model": model_obj.provider_resource_id,
+            "input": input,
+        }
+
+        # Note: at the moment Ollama does not support the encoding_format, dimensions, and user parameters
+        if encoding_format is not None:
+            params["encoding_format"] = encoding_format
+        if dimensions is not None:
+            params["dimensions"] = str(dimensions)
+        if user is not None:
+            params["user"] = user
+
+        response = await self.openai_client.embeddings.create(**params)
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response.usage.prompt_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=response.model,
+            usage=usage,
+        )
 
     async def openai_completion(
         self,
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
index 90a91a206..1b8bd9038 100644
--- a/tests/integration/inference/test_openai_embeddings.py
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -51,7 +51,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         "remote::runpod",
         "remote::sambanova",
         "remote::tgi",
-        "remote::ollama",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 
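For reviewers who want to exercise the new path end to end, here is a minimal smoke-test sketch (not part of the diff). It assumes a running Llama Stack server with the Ollama provider and an embedding model already registered; the base URL, API key, and the `all-minilm:l6-v2` model id are illustrative assumptions, not values taken from this patch.

    # Hypothetical smoke test -- base URL, api_key, and model id are assumptions.
    # Uses the standard `openai` client against Llama Stack's OpenAI-compatible route.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8321/v1/openai/v1",  # assumed Llama Stack endpoint
        api_key="unused",  # placeholder; assumed not to be validated locally
    )

    resp = client.embeddings.create(
        model="all-minilm:l6-v2",  # illustrative embedding model id
        input=["Hello, world!", "Goodbye, world!"],
    )

    # The adapter builds one OpenAIEmbeddingData per input string, assigning
    # indices in input order, and copies the backend's token counts into
    # OpenAIEmbeddingUsage.
    assert len(resp.data) == 2
    assert resp.data[0].index == 0
    print(resp.model, resp.usage.prompt_tokens, resp.usage.total_tokens)

Per the in-code comment, encoding_format, dimensions, and user are forwarded when set even though Ollama does not currently support them; whether such requests are ignored or rejected is up to the Ollama backend.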