chore: Revert together changes

This commit is contained in:
sarthakdeshpande 2025-03-09 22:00:00 +05:30
parent 6551a44971
commit f770e60e6a

View file

@ -6,7 +6,7 @@
from typing import AsyncGenerator, List, Optional, Union
from together import AsyncTogether
from together import Together
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -91,7 +91,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
else:
return await self._nonstream_completion(request)
async def _get_client(self) -> AsyncTogether:
def _get_client(self) -> Together:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
@ -103,18 +103,23 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
return Together(api_key=together_api_key)
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = await self._get_client()
r = await client.completions.create(**params)
r = self._get_client().completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = await self._get_client()
stream = await client.completions.create(**params)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -179,21 +184,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = await self._get_client()
if "messages" in params:
r = await client.chat.completions.create(**params)
r = self._get_client().chat.completions.create(**params)
else:
r = await client.completions.create(**params)
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = await self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
if "messages" in params:
s = self._get_client().chat.completions.create(**params)
else:
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
@ -231,8 +240,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
client = await self._get_client()
r = await client.embeddings.create(
r = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)