Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-12 20:12:33 +00:00
fix client building

parent 8daf2e78be
commit 8bbe13e8d1

1 changed file with 18 additions and 19 deletions
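Context for the diff below: the adapter used to build its `Together` client once in `initialize()`, presumably failing when the API key only arrives per request via provider data, since there is no request context at initialization time. The commit defers construction to a `_get_client()` helper invoked on every call. A minimal sketch of that lazy, per-request pattern; `Together` is the real SDK client, every other name here is illustrative:

# Illustrative sketch of the per-request client construction this commit
# adopts; only `Together` is the real SDK class, other names are made up.
from typing import Optional

from together import Together


class LazyClientAdapter:
    def __init__(self, static_api_key: Optional[str] = None) -> None:
        # Deliberately no client here: at init time there is no request,
        # so request-scoped credentials cannot be read yet.
        self.static_api_key = static_api_key

    def _request_api_key(self) -> Optional[str]:
        # Stand-in for get_request_provider_data(): would return the key
        # carried by the current request, if any.
        return None

    def _get_client(self) -> Together:
        # Resolve the key at call time: static config wins, else per request.
        api_key = self.static_api_key or self._request_api_key()
        if not api_key:
            raise ValueError("no Together API key configured or provided")
        return Together(api_key=api_key)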
@@ -44,7 +44,6 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
-    client: Together
 
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
@@ -54,20 +53,7 @@ class TogetherInferenceAdapter(
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-
-        self.client = Together(api_key=together_api_key)
-
-        return
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -94,11 +80,24 @@ class TogetherInferenceAdapter(
         else:
             return await self._nonstream_completion(request)
 
+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params_for_completion(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -106,7 +105,7 @@ class TogetherInferenceAdapter(
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -173,7 +172,7 @@ class TogetherInferenceAdapter(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = self.client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
@@ -183,7 +182,7 @@ class TogetherInferenceAdapter(
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self.client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
 
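On the caller side, the error message in `_get_client` spells out the contract: if no key is set in the adapter config, each request must carry one in the `X-LlamaStack-ProviderData` header as `{ "together_api_key": <your api key> }`. A sketch of such a request; the header name and JSON key come from the diff, while the endpoint URL and request body are placeholders, not a documented API:

# Sketch of a request satisfying _get_client()'s provider-data contract.
# Only the header name and the "together_api_key" field are from the diff;
# the URL and body below are placeholders for illustration.
import json

import requests

response = requests.post(
    "http://localhost:5000/inference/chat_completion",  # placeholder endpoint
    headers={
        "X-LlamaStack-ProviderData": json.dumps(
            {"together_api_key": "<your api key>"}
        )
    },
    json={"messages": [{"role": "user", "content": "Hello"}]},  # placeholder body
)
print(response.status_code)

Since the client is now rebuilt on every call, different requests can supply different keys without reinitializing the adapter; the in-diff comment also notes that the sync-to-async generator wrapper could be dropped if the adapter moved to Together's async client.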