From 921f8b1125b22ae856bf481a4ea2f1a09544ff8e Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com>
Date: Tue, 11 Mar 2025 03:55:01 +0530
Subject: [PATCH] chore: Together async client (#1510)
# What does this PR do?
Switches the Together inference provider to the SDK's async client (`AsyncTogether`) instead of the sync `Together` client.
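The adapter now lazily creates a single `AsyncTogether` client, reuses it across requests, and awaits its calls directly, so the old `_to_async_generator` wrapper around the sync client goes away. A minimal sketch of the pattern (illustrative only; key handling is simplified here to an environment variable, whereas the real adapter resolves it from config or the request provider data):

```python
import os

from together import AsyncTogether


class AdapterSketch:
    """Illustrative only; the actual adapter is in together.py below."""

    def __init__(self) -> None:
        self._client: AsyncTogether | None = None  # created lazily

    def _get_client(self) -> AsyncTogether:
        # One async client, created on first use and cached for reuse.
        if self._client is None:
            self._client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        return self._client

    async def complete(self, **params):
        # Calls are awaited directly; with stream=True the SDK returns an
        # async iterator, so no sync-to-async wrapper is needed.
        return await self._get_client().completions.create(**params)
```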
## Test Plan
The command used to run the tests is in the image below. (Two tests fail, but they were already failing with the same errors on the old stable version.)
---------
Co-authored-by: sarthakdeshpande
---
.../remote/inference/together/together.py | 75 ++++++++++----------
1 file changed, 38 insertions(+), 37 deletions(-)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index dfc9ae6d3..a4e02f2cb 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,7 +6,7 @@
from typing import AsyncGenerator, List, Optional, Union
-from together import Together
+from together import AsyncTogether
from llama_stack.apis.common.content_types import (
InterleavedContent,
@@ -59,12 +59,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
+ self._client = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
- pass
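+ # Release the cached async client, if one was created.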
+ if self._client:
+ await self._client.close()
+ self._client = None
async def completion(
self,
@@ -91,35 +95,34 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
else:
return await self._nonstream_completion(request)
- def _get_client(self) -> Together:
- together_api_key = None
- config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
- if config_api_key:
- together_api_key = config_api_key
- else:
- provider_data = self.get_request_provider_data()
- if provider_data is None or not provider_data.together_api_key:
- raise ValueError(
- 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": }'
- )
- together_api_key = provider_data.together_api_key
- return Together(api_key=together_api_key)
+ def _get_client(self) -> AsyncTogether:
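+ # Create the AsyncTogether client once and reuse it for the adapter's lifetime.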
+ if not self._client:
+ together_api_key = None
+ config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+ if config_api_key:
+ together_api_key = config_api_key
+ else:
+ provider_data = self.get_request_provider_data()
+ if provider_data is None or not provider_data.together_api_key:
+ raise ValueError(
+ 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": }'
+ )
+ together_api_key = provider_data.together_api_key
+ self._client = AsyncTogether(api_key=together_api_key)
+ return self._client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
- r = self._get_client().completions.create(**params)
+ client = self._get_client()
+ r = await client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
-
- # if we shift to TogetherAsyncClient, we won't need this wrapper
- async def _to_async_generator():
- s = self._get_client().completions.create(**params)
- for chunk in s:
- yield chunk
-
- stream = _to_async_generator()
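+ # The async client returns an async iterator for streaming calls, so no wrapper is needed.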
+ client = self._get_client()
+ stream = await client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@@ -184,25 +187,22 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
+ client = self._get_client()
if "messages" in params:
- r = self._get_client().chat.completions.create(**params)
+ r = await client.chat.completions.create(**params)
else:
- r = self._get_client().completions.create(**params)
+ r = await client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
+ client = self._get_client()
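+ # Requests carrying chat messages go to chat.completions; others to completions.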
+ if "messages" in params:
+ stream = await client.chat.completions.create(**params)
+ else:
+ stream = await client.completions.create(**params)
- # if we shift to TogetherAsyncClient, we won't need this wrapper
- async def _to_async_generator():
- if "messages" in params:
- s = self._get_client().chat.completions.create(**params)
- else:
- s = self._get_client().completions.create(**params)
- for chunk in s:
- yield chunk
-
- stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
@@ -240,7 +240,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
- r = self._get_client().embeddings.create(
+ client = self._get_client()
+ r = await client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)