diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index 85e4c95bb..aacc73804 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -64,7 +64,9 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class CentMLInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -143,7 +145,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -152,7 +156,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(
+            stream, self.formatter
+        ):
             yield chunk
 
     #
@@ -242,12 +248,15 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
            input_dict["messages"] = [
-                await convert_message_to_openai_dict(m) for m in request.messages
+                await convert_message_to_openai_dict(m)
+                for m in request.messages
             ]
 
         else:
             # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
+            assert not media_present, (
+                "CentML does not support media for completions"
+            )
             input_dict["prompt"] = await completion_request_to_prompt(
                 request, self.formatter
             )
@@ -256,7 +265,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format
+            ),
         }
 
     def _build_options(
@@ -277,7 +288,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
-            raise NotImplementedError("Grammar response format not supported yet")
+            raise NotImplementedError(
+                "Grammar response format not supported yet"
+            )
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
diff --git a/llama_stack/templates/centml/centml.py b/llama_stack/templates/centml/centml.py
index 891c393dd..0f8c13b7a 100644
--- a/llama_stack/templates/centml/centml.py
+++ b/llama_stack/templates/centml/centml.py
@@ -13,12 +13,17 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.providers.remote.inference.centml.config import CentMLImplConfig
+from llama_stack.providers.remote.inference.centml.config import (
+    CentMLImplConfig,
+)
 
 # If your CentML adapter has a MODEL_ALIASES constant with known model mappings:
 from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES
 
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -33,7 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
-        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "scoring": [
+            "inline::basic",
+            "inline::llm-as-judge",
+            "inline::braintrust",
+        ],
     }
 
     name = "centml"
@@ -94,7 +103,9 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=default_models + [embedding_model],
-                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                default_shields=[
+                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+                ],
             ),
         },
         run_config_env_vars={