From 501467846da1499fa89aa80eb7808f72bd5f2b56 Mon Sep 17 00:00:00 2001
From: sarthakdeshpande
Date: Sun, 9 Mar 2025 17:35:27 +0530
Subject: [PATCH] chore: made together async and faiss index search as async

---
 .../providers/inline/vector_io/faiss/faiss.py |  5 +-
 .../remote/inference/together/together.py     | 48 +++++++------------
 2 files changed, 21 insertions(+), 32 deletions(-)

diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 410d8bd8b..e86872fea 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -8,6 +8,7 @@ import base64
 import io
 import json
 import logging
+import asyncio
 from typing import Any, Dict, List, Optional
 
 import faiss
@@ -99,7 +100,9 @@ class FaissIndex(EmbeddingIndex):
         await self._save_index()
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
+        distances, indices = await asyncio.to_thread(
+            self.index.search, embedding.reshape(1, -1).astype(np.float32), k
+        )
 
         chunks = []
         scores = []
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index dfc9ae6d3..f75aba69e 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -1,12 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
 from typing import AsyncGenerator, List, Optional, Union
 
-from together import Together
+from together import AsyncTogether
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -91,7 +85,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         else:
             return await self._nonstream_completion(request)
 
-    def _get_client(self) -> Together:
+    async def _get_client(self) -> AsyncTogether:
         together_api_key = None
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
@@ -103,23 +97,18 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
             )
             together_api_key = provider_data.together_api_key
-        return Together(api_key=together_api_key)
+        return AsyncTogether(api_key=together_api_key)
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        client = await self._get_client()
+        r = await client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        client = await self._get_client()
+        stream = await client.completions.create(**params)
 
         async for chunk in process_completion_stream_response(stream):
             yield chunk
@@ -184,25 +173,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        client = await self._get_client()
         if "messages" in params:
-            r = self._get_client().chat.completions.create(**params)
+            r = await client.chat.completions.create(**params)
         else:
-            r = self._get_client().completions.create(**params)
+            r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        client = await self._get_client()
+        if "messages" in params:
+            stream = await client.chat.completions.create(**params)
+        else:
+            stream = await client.completions.create(**params)
 
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            if "messages" in params:
-                s = self._get_client().chat.completions.create(**params)
-            else:
-                s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -240,7 +225,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        r = self._get_client().embeddings.create(
+        client = await self._get_client()
+        r = await client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
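
The faiss change above keeps the blocking index search off the event loop by running it through asyncio.to_thread. Below is a minimal, self-contained sketch of that pattern outside the llama_stack codebase; it assumes faiss-cpu and numpy are installed, and the toy index, dimension, and query vector are invented for illustration.

# Sketch: run a blocking faiss search in a worker thread via asyncio.to_thread,
# mirroring the FaissIndex.query change in the patch. Index contents and the
# query vector below are illustrative only.
import asyncio

import faiss
import numpy as np


async def search(index: faiss.IndexFlatL2, query: np.ndarray, k: int):
    # index.search is synchronous C++ code; running it in a thread lets other
    # coroutines make progress while the search executes.
    distances, indices = await asyncio.to_thread(
        index.search, query.reshape(1, -1).astype(np.float32), k
    )
    return distances, indices


async def main():
    dim = 8
    index = faiss.IndexFlatL2(dim)
    index.add(np.random.rand(100, dim).astype(np.float32))  # toy vectors
    distances, indices = await search(index, np.random.rand(dim).astype(np.float32), 3)
    print(indices, distances)


if __name__ == "__main__":
    asyncio.run(main())

Note that asyncio.to_thread only hides the latency from the event loop; the search still occupies a worker thread, so CPU-bound faiss work is not itself parallelized by this change.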
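On the together.py side, the patch swaps the synchronous Together client for AsyncTogether, so the create calls can be awaited directly instead of being driven through the removed _to_async_generator wrappers. A rough usage sketch of the async client in isolation, assuming TOGETHER_API_KEY is set and using a placeholder model id and prompt:

# Sketch: calling AsyncTogether directly, analogous to what the adapter now does
# in _nonstream_chat_completion. Model id and prompt are placeholders.
import asyncio
import os

from together import AsyncTogether


async def main():
    client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
    # On the async client, chat.completions.create is a coroutine, so it is
    # awaited rather than iterated from a hand-rolled async generator.
    response = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())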