diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 9ba00b8f6..dfc9ae6d3 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,7 +6,7 @@
 
 from typing import AsyncGenerator, List, Optional, Union
 
-from together import AsyncTogether
+from together import Together
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -91,7 +91,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         else:
             return await self._nonstream_completion(request)
 
-    async def _get_client(self) -> AsyncTogether:
+    def _get_client(self) -> Together:
        together_api_key = None
        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
        if config_api_key:
@@ -103,18 +103,23 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
                     'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
                 )
             together_api_key = provider_data.together_api_key
-        return AsyncTogether(api_key=together_api_key)
+        return Together(api_key=together_api_key)
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        client = await self._get_client()
-        r = await client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        client = await self._get_client()
-        stream = await client.completions.create(**params)
+
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            s = self._get_client().completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
@@ -179,21 +184,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        client = await self._get_client()
         if "messages" in params:
-            r = await client.chat.completions.create(**params)
+            r = self._get_client().chat.completions.create(**params)
         else:
-            r = await client.completions.create(**params)
+            r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        client = await self._get_client()
-        if "messages" in params:
-            stream = await client.chat.completions.create(**params)
-        else:
-            stream = await client.completions.create(**params)
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            if "messages" in params:
+                s = self._get_client().chat.completions.create(**params)
+            else:
+                s = self._get_client().completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -231,8 +240,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        client = await self._get_client()
-        r = await client.embeddings.create(
+        r = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )