chore: Together async client

sarthakdeshpande 2025-03-10 14:11:10 +05:30
parent a9c5d3cd3d
commit 8a4ce847b8


@@ -1,12 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import AsyncGenerator, List, Optional, Union
-from together import Together
+from together import AsyncTogether
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -59,12 +53,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
+        self._client = None
 
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
-        pass
+        if self._client:
+            await self._client.close()
+            self._client = None
 
     async def completion(
         self,
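
The hunk above introduces the lazy client pattern: the adapter caches a single AsyncTogether instance and releases it in shutdown(). A minimal standalone sketch of that lifecycle, assuming only the together SDK's AsyncTogether class and mirroring the close() call used in the diff (the TogetherClientHolder name is illustrative):

from together import AsyncTogether


class TogetherClientHolder:
    """Illustrative holder for a lazily created AsyncTogether client."""

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key
        self._client = None  # created on first use, not in __init__

    async def get_client(self) -> AsyncTogether:
        # Build the client only when a request actually needs it,
        # so construction does not require a running event loop.
        if not self._client:
            self._client = AsyncTogether(api_key=self._api_key)
        return self._client

    async def shutdown(self) -> None:
        # Release the underlying HTTP session exactly once.
        if self._client:
            await self._client.close()
            self._client = None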
@@ -91,35 +88,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         else:
             return await self._nonstream_completion(request)
 
-    def _get_client(self) -> Together:
-        together_api_key = None
-        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-        if config_api_key:
-            together_api_key = config_api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-        return Together(api_key=together_api_key)
+    async def _get_client(self) -> AsyncTogether:
+        if not self._client:
+            together_api_key = None
+            config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+            if config_api_key:
+                together_api_key = config_api_key
+            else:
+                provider_data = self.get_request_provider_data()
+                if provider_data is None or not provider_data.together_api_key:
+                    raise ValueError(
+                        'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
+                    )
+                together_api_key = provider_data.together_api_key
+            self._client = AsyncTogether(api_key=together_api_key)
+        return self._client
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        client = await self._get_client()
+        r = await client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        client = await self._get_client()
+        stream = await client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
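
With the synchronous Together client, streaming needed the _to_async_generator wrapper removed above; AsyncTogether hands back an async iterator directly when the request carries stream=True. A small sketch of that path, assuming stream=True is set in the request parameters (the model id and prompt are placeholders):

import asyncio

from together import AsyncTogether


async def stream_completion(api_key: str) -> None:
    client = AsyncTogether(api_key=api_key)
    # Awaiting create() with stream=True yields an async iterator,
    # so chunks are consumed with a plain `async for` -- no wrapper needed.
    stream = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model id
        prompt="Write one sentence about event loops.",
        stream=True,
    )
    async for chunk in stream:
        print(chunk)
    await client.close()


if __name__ == "__main__":
    asyncio.run(stream_completion("your-together-api-key"))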
@@ -184,25 +178,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        client = await self._get_client()
         if "messages" in params:
-            r = self._get_client().chat.completions.create(**params)
+            r = await client.chat.completions.create(**params)
         else:
-            r = self._get_client().completions.create(**params)
+            r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            if "messages" in params:
-                s = self._get_client().chat.completions.create(**params)
-            else:
-                s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        client = await self._get_client()
+        if "messages" in params:
+            stream = await client.chat.completions.create(**params)
+        else:
+            stream = await client.completions.create(**params)
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
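
The chat path keeps the same branching: when params contains "messages" the request goes through chat.completions.create, otherwise it falls back to the plain completions endpoint. A sketch of the streaming chat branch, with a placeholder model id and message:

from together import AsyncTogether


async def stream_chat(api_key: str) -> None:
    client = AsyncTogether(api_key=api_key)
    # Mirrors the "messages" branch of _stream_chat_completion above.
    stream = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model id
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)
    await client.close()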
@@ -240,7 +230,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        r = self._get_client().embeddings.create(
+        client = await self._get_client()
+        r = await client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
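
Overall, the commit swaps per-call Together(...) construction for one cached AsyncTogether instance, awaits every SDK call, and drops the sync-to-async generator wrappers in both streaming paths. Embeddings follow the same shape: one awaited call on the shared client. A sketch under the assumption that the response exposes an OpenAI-style data[*].embedding list (the embedding model id is a placeholder):

from together import AsyncTogether


async def embed_texts(api_key: str, texts: list[str]) -> list[list[float]]:
    client = AsyncTogether(api_key=api_key)
    r = await client.embeddings.create(
        model="togethercomputer/m2-bert-80M-8k-retrieval",  # placeholder embedding model id
        input=texts,
    )
    # Assumes an OpenAI-style response object with .data[i].embedding.
    embeddings = [item.embedding for item in r.data]
    await client.close()
    return embeddings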