From 8bbe13e8d1189cc9f8b910383f3d331839c3e396 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 25 Oct 2024 11:28:22 -0700 Subject: [PATCH] fix client building --- .../adapters/inference/together/together.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index f4c6a712d..8c92836f9 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -44,7 +44,6 @@ TOGETHER_SUPPORTED_MODELS = { class TogetherInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): - client: Together def __init__(self, config: TogetherImplConfig) -> None: ModelRegistryHelper.__init__( @@ -54,20 +53,7 @@ class TogetherInferenceAdapter( self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: - together_api_key = None - if self.config.api_key is not None: - together_api_key = self.config.api_key - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key - - self.client = Together(api_key=together_api_key) - - return + pass async def shutdown(self) -> None: pass @@ -94,11 +80,24 @@ class TogetherInferenceAdapter( else: return await self._nonstream_completion(request) + def _get_client(self) -> Together: + together_api_key = None + if self.config.api_key is not None: + together_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + together_api_key = provider_data.together_api_key + return Together(api_key=together_api_key) + async def _nonstream_completion( self, request: CompletionRequest ) -> ChatCompletionResponse: params = self._get_params_for_completion(request) - r = self.client.completions.create(**params) + r = self._get_client().completions.create(**params) return process_completion_response(r, self.formatter) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: @@ -106,7 +105,7 @@ class TogetherInferenceAdapter( # if we shift to TogetherAsyncClient, we won't need this wrapper async def _to_async_generator(): - s = self.client.completions.create(**params) + s = self._get_client().completions.create(**params) for chunk in s: yield chunk @@ -173,7 +172,7 @@ class TogetherInferenceAdapter( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: params = self._get_params(request) - r = self.client.completions.create(**params) + r = self._get_client().completions.create(**params) return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( @@ -183,7 +182,7 @@ class TogetherInferenceAdapter( # if we shift to TogetherAsyncClient, we won't need this wrapper async def _to_async_generator(): - s = self.client.completions.create(**params) + s = self._get_client().completions.create(**params) for chunk in s: yield chunk