fix client building

This commit is contained in:
Dinesh Yeduguru 2024-10-25 11:28:22 -07:00
parent 8daf2e78be
commit 8bbe13e8d1

View file

@ -44,7 +44,6 @@ TOGETHER_SUPPORTED_MODELS = {
class TogetherInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
client: Together
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(
@ -54,20 +53,7 @@ class TogetherInferenceAdapter(
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
self.client = Together(api_key=together_api_key)
return
pass
async def shutdown(self) -> None:
pass
@ -94,11 +80,24 @@ class TogetherInferenceAdapter(
else:
return await self._nonstream_completion(request)
def _get_client(self) -> Together:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return Together(api_key=together_api_key)
async def _nonstream_completion(
self, request: CompletionRequest
) -> ChatCompletionResponse:
params = self._get_params_for_completion(request)
r = self.client.completions.create(**params)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@ -106,7 +105,7 @@ class TogetherInferenceAdapter(
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self.client.completions.create(**params)
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
@ -173,7 +172,7 @@ class TogetherInferenceAdapter(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = self.client.completions.create(**params)
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
@ -183,7 +182,7 @@ class TogetherInferenceAdapter(
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self.client.completions.create(**params)
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk