From 302555b11abe61b8c5bd92710dd23535fe2c4355 Mon Sep 17 00:00:00 2001
From: krgutha
Date: Wed, 23 Oct 2024 17:18:08 -0700
Subject: [PATCH] Implement embeddings for ollama

---
 llama_stack/apis/inference/inference.py      | 12 +++++++++
 llama_stack/distribution/routers/routers.py  |  6 +++++
 .../adapters/inference/ollama/ollama.py      | 27 ++++++++++++++++++-
 .../tests/inference/test_inference.py        | 21 +++++++++++++++
 4 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 4ee01acae..e1bba5c0b 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -195,6 +195,15 @@ class BatchChatCompletionResponse(BaseModel):
     batch: List[ChatCompletionResponse]
 
 
+@json_schema_type
+class EmbeddingRequest(BaseModel):
+    model: str
+    contents: List[InterleavedTextMedia]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    truncate: Optional[bool] = True
+    logprobs: Optional[LogProbConfig] = None
+
+
 @json_schema_type
 class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
@@ -241,4 +250,7 @@ class Inference(Protocol):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 31b8efa48..d957917bd 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -128,10 +128,16 @@ class InferenceRouter(Inference):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
         return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
         )
 
 
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index d4fe75cfa..bf7cdf75a 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -234,5 +234,30 @@
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+
+        request = EmbeddingRequest(
+            model=model,
+            contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
+        )
+        return await self._embeddings(request)
+
+    async def _embeddings(self, request: EmbeddingRequest) -> EmbeddingsResponse:
+        params = self._get_params_for_embeddings(request)
+        r = await self.client.embed(**params)
+        return EmbeddingsResponse(embeddings=r["embeddings"])
+
+    def _get_params_for_embeddings(self, request: EmbeddingRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "input": request.contents,
+            "options": sampling_options,
+            "truncate": request.truncate,
+        }
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index ad49448e2..60c3d14d2 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -170,6 +170,27 @@ async def test_completion(inference_settings):
     assert last.stop_reason == StopReason.out_of_tokens
 
 
+@pytest.mark.asyncio
+async def test_embed(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in ("remote::ollama",):
+        pytest.skip("Other inference providers don't support embeddings() yet")
+
+    response = await inference_impl.embeddings(
+        contents=["Roses are red"],
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) > 0
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
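
Usage note (not part of the patch): _get_params_for_embeddings() translates the
stack-level request into the same `embed` call exposed by the Ollama Python client,
with the keys model, input, options, and truncate. Below is a minimal sketch of that
underlying call, assuming a local Ollama server on the default port and an
embedding-capable model pulled as "all-minilm"; both the host and the model name are
assumptions for illustration, not values taken from this patch.

    import asyncio

    from ollama import AsyncClient


    async def main() -> None:
        # Assumed local Ollama server; adjust the host if yours differs.
        client = AsyncClient(host="http://localhost:11434")

        # Mirrors the parameter dict built by _get_params_for_embeddings().
        r = await client.embed(
            model="all-minilm",  # hypothetical embedding-capable model
            input=["Roses are red"],
            truncate=True,
            options={"temperature": 0.0},  # stands in for get_sampling_options() output
        )

        # The adapter reads r["embeddings"]; print its shape as a quick sanity check.
        print(len(r["embeddings"]), len(r["embeddings"][0]))


    asyncio.run(main())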