Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 04:50:39 +00:00
chore: made together async and faiss index search as async
This commit is contained in:
parent b5a505aa8d
commit 501467846d
2 changed files with 21 additions and 32 deletions
Changed file 1 of 2: the FAISS vector index provider (class FaissIndex):

```diff
@@ -8,6 +8,7 @@ import base64
 import io
 import json
 import logging
+import asyncio
 from typing import Any, Dict, List, Optional
 
 import faiss
@@ -99,7 +100,9 @@ class FaissIndex(EmbeddingIndex):
         await self._save_index()
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
+        distances, indices = await asyncio.to_thread(
+            self.index.search, embedding.reshape(1, -1).astype(np.float32), k
+        )
 
         chunks = []
         scores = []
```
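For context, a minimal standalone sketch of the asyncio.to_thread pattern adopted above: the CPU-bound faiss search runs in a worker thread so the event loop stays responsive while it executes. The index size, dimensions, and data below are illustrative, not taken from the commit.

```python
import asyncio

import faiss
import numpy as np


async def main() -> None:
    dim = 8
    index = faiss.IndexFlatL2(dim)
    index.add(np.random.rand(16, dim).astype(np.float32))

    query = np.random.rand(1, dim).astype(np.float32)
    # index.search(x, k) blocks until done; asyncio.to_thread runs it in
    # the default executor and awaits the (distances, indices) result.
    distances, indices = await asyncio.to_thread(index.search, query, 3)
    print(distances, indices)


asyncio.run(main())
```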
Changed file 2 of 2: the Together inference adapter (class TogetherInferenceAdapter):

```diff
@@ -1,12 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
 from typing import AsyncGenerator, List, Optional, Union
 
-from together import Together
+from together import AsyncTogether
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
```
```diff
@@ -91,7 +85,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_client(self) -> Together:
+    async def _get_client(self) -> AsyncTogether:
         together_api_key = None
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
```
```diff
@@ -103,23 +97,18 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
                     'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
                 )
             together_api_key = provider_data.together_api_key
-        return Together(api_key=together_api_key)
+        return AsyncTogether(api_key=together_api_key)
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        client = await self._get_client()
+        r = await client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        client = await self._get_client()
+        stream = await client.completions.create(**params)
 
         async for chunk in process_completion_stream_response(stream):
             yield chunk
```
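For reference, a minimal sketch of the AsyncTogether call pattern the diff moves to: create() is awaited, and with stream=True it resolves to an async iterator, which is why the _to_async_generator wrapper above could be deleted. The model name, prompt, and environment-variable auth are illustrative assumptions, not part of the commit.

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment

    # Non-streaming: awaiting create() yields the complete response.
    r = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # illustrative model
        prompt="Say hello.",
        max_tokens=16,
    )
    print(r.choices[0].text)

    # Streaming: awaiting create() yields an async iterator of chunks.
    stream = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        prompt="Count to three.",
        max_tokens=16,
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].text or "", end="", flush=True)


asyncio.run(main())
```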
```diff
@@ -184,25 +173,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        client = await self._get_client()
         if "messages" in params:
-            r = self._get_client().chat.completions.create(**params)
+            r = await client.chat.completions.create(**params)
         else:
-            r = self._get_client().completions.create(**params)
+            r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            if "messages" in params:
-                s = self._get_client().chat.completions.create(**params)
-            else:
-                s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        client = await self._get_client()
+        if "messages" in params:
+            stream = await client.chat.completions.create(**params)
+        else:
+            stream = await client.completions.create(**params)
 
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
```
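The chat path has the same shape; a hedged sketch of the streaming chat endpoint, where incremental text arrives in choices[0].delta (model and message are illustrative):

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
    stream = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # illustrative model
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    async for chunk in stream:
        # Each chunk carries an incremental piece of the reply.
        print(chunk.choices[0].delta.content or "", end="", flush=True)


asyncio.run(main())
```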
```diff
@@ -240,7 +225,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        r = self._get_client().embeddings.create(
+        client = await self._get_client()
+        r = await client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
```
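The embeddings call follows suit: awaiting it returns the full response, with vectors in data[*].embedding. A small sketch (model name illustrative):

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
    r = await client.embeddings.create(
        model="togethercomputer/m2-bert-80M-8k-retrieval",  # illustrative model
        input=["The quick brown fox", "jumps over the lazy dog"],
    )
    print(len(r.data), len(r.data[0].embedding))


asyncio.run(main())
```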