pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -6,9 +6,8 @@
from typing import AsyncGenerator, List, Optional, Union
from together import Together
from together import AsyncTogether
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -32,9 +31,8 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -54,27 +52,34 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
if self._client:
await self._client.close()
self._client = None
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -89,34 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
else:
return await self._nonstream_completion(request)
def _get_client(self) -> Together:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return Together(api_key=together_api_key)
def _get_client(self) -> AsyncTogether:
if not self._client:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
self._client = AsyncTogether(api_key=together_api_key)
return self._client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
client = self._get_client()
r = await client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
client = await self._get_client()
stream = await client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -151,7 +154,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -160,6 +163,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -179,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
r = self._get_client().chat.completions.create(**params)
r = await client.chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
r = await client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
if "messages" in params:
s = self._get_client().chat.completions.create(**params)
else:
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
@ -220,7 +221,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logcat.debug("inference", f"params to together: {params}")
logger.debug(f"params to together: {params}")
return params
async def embeddings(
@ -235,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
r = self._get_client().embeddings.create(
client = self._get_client()
r = await client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)