Implement embeddings for ollama

krgutha 2024-10-23 17:18:08 -07:00
parent 0cec86453b
commit 302555b11a
4 changed files with 65 additions and 1 deletion

View file

@@ -195,6 +195,15 @@ class BatchChatCompletionResponse(BaseModel):
     batch: List[ChatCompletionResponse]


+@json_schema_type
+class EmbeddingRequest(BaseModel):
+    model: str
+    contents: List[InterleavedTextMedia]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    truncate: Optional[bool] = True
+    logprobs: Optional[LogProbConfig] = None
+
+
 @json_schema_type
 class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
@@ -241,4 +250,7 @@ class Inference(Protocol):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse: ...
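
These two hunks define the request/response schema and widen the embeddings() signature on the Inference protocol. As a minimal sketch of the new model in use (the import path and model id are assumptions for illustration, not taken from this commit):

    from llama_stack.apis.inference import EmbeddingRequest  # assumed import path

    request = EmbeddingRequest(
        model="all-minilm",          # hypothetical model id
        contents=["Roses are red"],  # a plain string is valid InterleavedTextMedia
    )
    # Fields left unset fall back to the defaults declared above.
    assert request.truncate is True
    assert request.logprobs is None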

View file

@@ -128,10 +128,16 @@ class InferenceRouter(Inference):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
         return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
         )
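
The router adds no logic of its own: it looks up whichever provider serves the requested model and forwards every argument unchanged. A toy sketch of the delegation pattern it relies on (class and variable names here are illustrative, not the actual llama-stack routing table):

    class ToyRoutingTable:
        """Maps a model id to the provider implementation serving it."""

        def __init__(self, providers: dict):
            # e.g. {"all-minilm": OllamaInferenceAdapter(...)}
            self.providers = providers

        def get_provider_impl(self, model: str):
            try:
                return self.providers[model]
            except KeyError:
                raise ValueError(f"no provider registered for model {model!r}")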

View file

@@ -234,5 +234,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        request = EmbeddingRequest(
+            model=model,
+            contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
+        )
+        return await self._embeddings(request)
+
+    async def _embeddings(self, request: EmbeddingRequest) -> EmbeddingsResponse:
+        params = self._get_params_for_embeddings(request)
+        r = await self.client.embed(**params)
+        return EmbeddingsResponse(embeddings=r["embeddings"])
+
+    def _get_params_for_embeddings(self, request: EmbeddingRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "input": request.contents,
+            "options": sampling_options,
+            "truncate": request.truncate,
+        }
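
The params dict built by _get_params_for_embeddings() becomes the keyword arguments of the Ollama client's embed() call, with OLLAMA_SUPPORTED_MODELS translating the llama-stack model id into an Ollama model tag. A standalone sketch of the equivalent direct call, assuming the ollama Python package (0.3+) and a locally pulled embedding model (the model tag is illustrative):

    import asyncio

    from ollama import AsyncClient

    async def main() -> None:
        # Mirrors what _embeddings() sends, minus the model-id translation.
        r = await AsyncClient().embed(
            model="all-minilm",        # illustrative Ollama model tag
            input=["Roses are red"],
            truncate=True,
        )
        vectors = r["embeddings"]
        print(len(vectors), len(vectors[0]))  # batch size, embedding dimensionality

    asyncio.run(main())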

View file

@@ -170,6 +170,27 @@ async def test_completion(inference_settings):
     assert last.stop_reason == StopReason.out_of_tokens


+@pytest.mark.asyncio
+async def test_embed(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in ("remote::ollama",):
+        pytest.skip("Other inference providers don't support embeddings() yet")
+
+    response = await inference_impl.embeddings(
+        contents=["Roses are red"],
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) > 0
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
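
The test only asserts that a non-empty batch comes back. Since EmbeddingsResponse.embeddings is a plain list of float vectors, a natural follow-up when adapting this test is a similarity check between rows; a small self-contained helper (standard cosine similarity, not part of this commit):

    import math

    def cosine_similarity(a: list[float], b: list[float]) -> float:
        # dot(a, b) / (|a| * |b|); assumes equal-length, nonzero vectors
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        return dot / (norm_a * norm_b)

    # e.g. with two rows of response.embeddings:
    #   score = cosine_similarity(response.embeddings[0], response.embeddings[1])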