fix client building

Dinesh Yeduguru 2024-10-25 11:28:22 -07:00
parent 8daf2e78be
commit 8bbe13e8d1


@@ -44,7 +44,6 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
-    client: Together

     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
@@ -54,20 +53,7 @@ class TogetherInferenceAdapter(
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-        self.client = Together(api_key=together_api_key)
-        return
+        pass

     async def shutdown(self) -> None:
         pass
@@ -94,11 +80,24 @@ class TogetherInferenceAdapter(
         else:
             return await self._nonstream_completion(request)

+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params_for_completion(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
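
Note: with `_get_client()`, the API key is now resolved on every call, either from `config.api_key` or from the per-request provider data, instead of once at `initialize()`. A minimal sketch of how a caller might supply the key via the header named in the ValueError above; the header name and JSON shape come from that message, while the endpoint URL, port, model name, and payload are illustrative assumptions:

import json

import requests

# Header name and JSON shape are taken from the error message in _get_client();
# everything else below is hypothetical.
provider_data = {"together_api_key": "<your api key>"}

response = requests.post(
    "http://localhost:5000/inference/chat_completion",  # hypothetical endpoint
    headers={"X-LlamaStack-ProviderData": json.dumps(provider_data)},
    json={
        "model": "Llama3.1-8B-Instruct",  # illustrative model name
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())
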
@@ -106,7 +105,7 @@ class TogetherInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
@@ -173,7 +172,7 @@ class TogetherInferenceAdapter(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
@@ -183,7 +182,7 @@ class TogetherInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
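
Note: the `_to_async_generator` wrapper recurs at each streaming call site because the synchronous Together client returns a plain iterator, while the adapter's callers consume an AsyncGenerator. A standalone sketch of the same pattern, generalized; the helper name and the asyncio.sleep(0) call are additions of this sketch, not part of the adapter:

import asyncio
from typing import AsyncGenerator, Iterable, TypeVar

T = TypeVar("T")

async def to_async_generator(chunks: Iterable[T]) -> AsyncGenerator[T, None]:
    # Re-yield each chunk of a blocking iterator so callers can `async for` it,
    # mirroring the adapter's _to_async_generator above.
    for chunk in chunks:
        yield chunk
        # Let other tasks run between chunks; the adapter omits this, which is
        # presumably why its comment points at TogetherAsyncClient as the
        # eventual replacement for the wrapper.
        await asyncio.sleep(0)

Usage: `async for chunk in to_async_generator(stream): ...`.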