Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
Implement embeddings for ollama
parent 0cec86453b, commit 302555b11a

4 changed files with 65 additions and 1 deletion
```diff
@@ -195,6 +195,15 @@ class BatchChatCompletionResponse(BaseModel):
     batch: List[ChatCompletionResponse]


+@json_schema_type
+class EmbeddingRequest(BaseModel):
+    model: str
+    contents: List[InterleavedTextMedia]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    truncate: Optional[bool] = True
+    logprobs: Optional[LogProbConfig] = None
+
+
 @json_schema_type
 class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
```
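For a quick feel of how these two new schema types behave, here is a minimal standalone sketch. It substitutes plain `str` contents for `InterleavedTextMedia` and drops the sampling/logprobs fields, so it illustrates the shape only and is not the llama-stack definition:

```python
from typing import List, Optional

from pydantic import BaseModel


class EmbeddingRequestSketch(BaseModel):
    # Stand-in for EmbeddingRequest: contents is plain text here,
    # and sampling_params / logprobs are omitted for brevity.
    model: str
    contents: List[str]
    truncate: Optional[bool] = True


class EmbeddingsResponseSketch(BaseModel):
    # One vector per input item, mirroring List[List[float]] above.
    embeddings: List[List[float]]


req = EmbeddingRequestSketch(model="all-minilm", contents=["Roses are red"])
resp = EmbeddingsResponseSketch(embeddings=[[0.1, 0.2, 0.3]])
print(req.model, len(resp.embeddings[0]))
```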
```diff
@@ -241,4 +250,7 @@ class Inference(Protocol):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse: ...
```
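The protocol method above is only a typed stub (`...` body); providers satisfy it structurally rather than by inheritance. A small generic sketch of that pattern, with invented names and `str` contents standing in for `InterleavedTextMedia`:

```python
from typing import List, Optional, Protocol


class EmbedderProtocol(Protocol):
    # Same "..." stub pattern as the Inference protocol above.
    async def embeddings(
        self,
        model: str,
        contents: List[str],
        truncate: Optional[bool] = True,
    ) -> List[List[float]]: ...


class LengthEmbedder:
    # Satisfies EmbedderProtocol structurally; no base class needed.
    async def embeddings(self, model, contents, truncate=True):
        return [[float(len(text))] for text in contents]


embedder: EmbedderProtocol = LengthEmbedder()
```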
```diff
@@ -128,10 +128,16 @@ class InferenceRouter(Inference):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
         return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
         )
```
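The router adds no logic of its own: it resolves the provider registered for `model` and forwards every argument unchanged. A dict-based sketch of that dispatch pattern (the `SimpleRoutingTable` and `SimpleRouter` classes are invented for illustration, not llama-stack code):

```python
class SimpleRoutingTable:
    # Invented stand-in: maps a model name to the provider that serves it.
    def __init__(self, providers: dict):
        self._providers = providers

    def get_provider_impl(self, model: str):
        return self._providers[model]


class SimpleRouter:
    def __init__(self, routing_table: SimpleRoutingTable):
        self.routing_table = routing_table

    async def embeddings(self, model: str, contents, **kwargs):
        # Look up the provider for this model and forward everything as-is,
        # which is all InferenceRouter.embeddings does above.
        provider = self.routing_table.get_provider_impl(model)
        return await provider.embeddings(model=model, contents=contents, **kwargs)


# Usage: router = SimpleRouter(SimpleRoutingTable({"all-minilm": some_provider}))
#        vectors = await router.embeddings(model="all-minilm", contents=["hi"])
```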
```diff
@@ -234,5 +234,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        request = EmbeddingRequest(
+            model=model,
+            contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
+        )
+        return await self._embeddings(request)
+
+    async def _embeddings(self, request: EmbeddingRequest) -> EmbeddingsResponse:
+        params = self._get_params_for_embeddings(request)
+        r = await self.client.embed(**params)
+        return EmbeddingsResponse(embeddings=r["embeddings"])
+
+    def _get_params_for_embeddings(self, request: EmbeddingRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "input": request.contents,
+            "options": sampling_options,
+            "truncate": request.truncate,
+        }
```
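The adapter ultimately issues a single `client.embed(...)` call. The same request can be exercised against Ollama directly; the sketch below assumes a local Ollama server, an already-pulled `all-minilm` embedding model, and a recent `ollama` Python client that exposes `embed()` (all assumptions beyond what this commit shows):

```python
import asyncio

from ollama import AsyncClient


async def main():
    client = AsyncClient(host="http://localhost:11434")
    # Same call shape the adapter builds in _get_params_for_embeddings().
    r = await client.embed(
        model="all-minilm",  # assumption: this embedding model is pulled locally
        input=["Roses are red", "Violets are blue"],
        truncate=True,
    )
    vectors = r["embeddings"]
    print(len(vectors), "vectors of dimension", len(vectors[0]))


asyncio.run(main())
```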
```diff
@@ -170,6 +170,27 @@ async def test_completion(inference_settings):
     assert last.stop_reason == StopReason.out_of_tokens


+@pytest.mark.asyncio
+async def test_embed(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in ("remote::ollama",):
+        pytest.skip("Other inference providers don't support embeddings() yet")
+
+    response = await inference_impl.embeddings(
+        contents=["Roses are red"],
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) > 0
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
```
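Once `EmbeddingsResponse.embeddings` comes back as `List[List[float]]`, a typical downstream step is similarity scoring. A short, dependency-free sketch (the helper below is illustrative, not part of llama-stack):

```python
import math
from typing import List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# e.g. after calling inference_impl.embeddings(...) as in the test above:
# score = cosine_similarity(response.embeddings[0], response.embeddings[1])
print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))
```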