Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
chore: resolved comments

parent 3d4b16b8a4
commit a4ca4508df

1 changed file with 5 additions and 5 deletions
@@ -94,7 +94,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         else:
             return await self._nonstream_completion(request)
 
-    async def _get_client(self) -> AsyncTogether:
+    def _get_client(self) -> AsyncTogether:
         if not self._client:
             together_api_key = None
             config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
@@ -112,7 +112,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        client = await self._get_client()
+        client = self._get_client()
         r = await client.completions.create(**params)
         return process_completion_response(r)
 
@@ -184,7 +184,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        client = await self._get_client()
+        client = self._get_client()
         if "messages" in params:
             r = await client.chat.completions.create(**params)
         else:
@@ -193,7 +193,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        client = await self._get_client()
+        client = self._get_client()
         if "messages" in params:
             stream = await client.chat.completions.create(**params)
         else:
@@ -236,7 +236,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        client = await self._get_client()
+        client = self._get_client()
         r = await client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
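Every hunk above is the same refactor: _get_client becomes a plain synchronous method, so its call sites drop the await. For context, here is a minimal sketch of the lazy-initialization pattern implied by the surrounding context lines after this change. Only the names visible in the diff (AsyncTogether, _client, together_api_key, config_api_key) come from the source; TogetherImplConfig and the constructor wiring are hypothetical stand-ins, not the repository's actual code.

    # Sketch only: TogetherImplConfig and __init__ are assumptions.
    from pydantic import BaseModel, SecretStr
    from together import AsyncTogether


    class TogetherImplConfig(BaseModel):
        api_key: SecretStr | None = None  # hypothetical config field


    class TogetherInferenceAdapter:
        def __init__(self, config: TogetherImplConfig) -> None:
            self.config = config
            self._client: AsyncTogether | None = None

        def _get_client(self) -> AsyncTogether:
            # No await needed: constructing AsyncTogether performs no I/O;
            # only the API calls made on the returned client are awaited.
            if not self._client:
                together_api_key = None
                config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
                if config_api_key:
                    together_api_key = config_api_key
                self._client = AsyncTogether(api_key=together_api_key)
            return self._client

Making client construction synchronous is a common cleanup: an async factory only earns its keep when building the client itself must await something (e.g. fetching credentials), which this commit indicates is not the case here.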