mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-15 06:37:58 +00:00

fix client building

This commit is contained in:
parent 8daf2e78be
commit 8bbe13e8d1

1 changed file with 18 additions and 19 deletions
@@ -44,7 +44,6 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
-    client: Together

     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
@@ -54,20 +53,7 @@
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-
-        self.client = Together(api_key=together_api_key)
-
-        return
+        pass

     async def shutdown(self) -> None:
        pass
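Note: the hunk above removes the eager client setup. `initialize()` runs once at startup, before any request has arrived, so a Together client built there can never pick up an API key delivered per request via the X-LlamaStack-ProviderData header. The commit instead builds the client lazily at call time (see `_get_client` in the next hunk). A minimal sketch of that pattern follows; `make_client` and its parameters are hypothetical stand-ins for `self.config.api_key` and `provider_data.together_api_key` in the diff, and it assumes the `together` SDK's `Together(api_key=...)` constructor.

# Sketch only: lazy, per-request client construction.
from typing import Optional

from together import Together  # pip install together


def make_client(config_key: Optional[str], request_key: Optional[str]) -> Together:
    # Prefer the statically configured key; otherwise fall back to the
    # key carried in the current request's provider data.
    api_key = config_key if config_key is not None else request_key
    if not api_key:
        raise ValueError("no Together API key available")
    return Together(api_key=api_key)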
@@ -94,11 +80,24 @@
         else:
             return await self._nonstream_completion(request)

+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params_for_completion(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
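For callers, the ValueError message above spells out the contract: the key travels in the X-LlamaStack-ProviderData header as JSON. Below is a hedged example of satisfying that contract with `requests`; the header name and JSON shape come straight from the diff, while the server URL, route, model id, and request body are illustrative placeholders, not taken from this commit.

import json

import requests  # pip install requests

resp = requests.post(
    "http://localhost:5000/inference/chat_completion",  # placeholder endpoint
    headers={
        # Header name and payload shape are from the ValueError in the diff.
        "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"})
    },
    json={
        "model": "Llama3.1-8B-Instruct",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code)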
@@ -106,7 +105,7 @@

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk

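The wrapper touched in the hunk above exists because the synchronous Together client returns a plain iterator for streamed results, and the adapter's own comment notes an async client would make it unnecessary. A generic sketch of the same sync-to-async bridging, with illustrative names (untied to this codebase):

from typing import AsyncGenerator, Iterable


async def to_async_generator(chunks: Iterable) -> AsyncGenerator:
    # Wrap a blocking iterator so async callers can use `async for`.
    # Each step of the underlying iterator still runs synchronously on
    # the event loop thread, which is why the adapter's comment favors
    # switching to an async SDK client instead.
    for chunk in chunks:
        yield chunk

# usage: async for chunk in to_async_generator(stream): ...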
@@ -173,7 +172,7 @@
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
@@ -183,7 +182,7 @@

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
