From e20228a7ca475c43e3142fff9a2e2c4f960eb63e Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 21 Jan 2025 16:40:08 -0500
Subject: [PATCH] ruff fix and format

---
 .../remote/inference/centml/__init__.py      |  8 ++--
 .../remote/inference/centml/centml.py        | 44 +++++++++---------
 2 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/remote/inference/centml/__init__.py b/llama_stack/providers/remote/inference/centml/__init__.py
index 3aae4dddc..4bfc27b9e 100644
--- a/llama_stack/providers/remote/inference/centml/__init__.py
+++ b/llama_stack/providers/remote/inference/centml/__init__.py
@@ -8,9 +8,11 @@ from pydantic import BaseModel
 
 from .config import CentMLImplConfig
 
+
 class CentMLProviderDataValidator(BaseModel):
     centml_api_key: str
 
+
 async def get_adapter_impl(config: CentMLImplConfig, _deps):
     """
     Factory function to construct and initialize the CentML adapter.
@@ -21,9 +23,9 @@ async def get_adapter_impl(config: CentMLImplConfig, _deps):
     from .centml import CentMLInferenceAdapter
 
     # Ensure the provided config is indeed a CentMLImplConfig
-    assert isinstance(
-        config, CentMLImplConfig
-    ), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, CentMLImplConfig), (
+        f"Unexpected config type: {type(config)}"
+    )
 
     # Instantiate and initialize the adapter
     adapter = CentMLInferenceAdapter(config)
diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index f76763026..85e4c95bb 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -42,7 +42,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -65,8 +64,7 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
-                             NeedsRequestProviderData):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -138,14 +136,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-            self, request: CompletionRequest) -> ChatCompletionResponse:
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
 
-    async def _stream_completion(self,
-                                 request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -154,8 +152,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-                stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk
 
     #
@@ -195,7 +192,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
@@ -208,7 +206,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return process_chat_completion_response(response, self.formatter)
 
     async def _stream_chat_completion(
-            self, request: ChatCompletionRequest) -> AsyncGenerator:
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -221,7 +220,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-                stream, self.formatter):
+            stream, self.formatter
+        ):
             yield chunk
 
     #
@@ -229,8 +229,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     #
 
     async def _get_params(
-            self, request: Union[ChatCompletionRequest,
-                                 CompletionRequest]) -> dict:
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
         """
         Build the 'params' dict that the OpenAI (CentML) client expects.
         For chat requests, we always prefer "messages" so that it calls
@@ -242,22 +242,20 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
 
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
+                await convert_message_to_openai_dict(m) for m in request.messages
             ]
         else:
             # Non-chat (CompletionRequest)
             assert not media_present, "CentML does not support media for completions"
             input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+                request, self.formatter
+            )
 
         return {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
+            "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
 
@@ -279,8 +277,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                     "schema": fmt.json_schema,
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError(
-                    "Grammar response format not supported yet")
+                raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
@@ -297,8 +294,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         # CentML does not support media
-        assert all(not content_has_media(c) for c in contents), \
-            "CentML does not support media for embeddings"
+        assert all(not content_has_media(c) for c in contents), (
+            "CentML does not support media for embeddings"
+        )
 
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
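
Note (outside the patch, for review context only): the reformatted _get_params
hunk above builds one of two payload shapes for the OpenAI-compatible CentML
client. A minimal sketch of the resulting dicts, with an assumed model id and
message used purely for illustration:

    # Chat path -> passed to client.chat.completions.create(**params)
    chat_params = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # assumed example id
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
        # ...plus whatever _build_options() adds for sampling / response format
    }

    # Non-chat path -> passed to client.completions.create(**params)
    completion_params = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # assumed example id
        "prompt": "...",  # rendered by completion_request_to_prompt()
        "stream": False,
    }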