diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index 767c6dc37..c5b400d55 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -64,9 +64,8 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
+                             NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -138,16 +137,14 @@ class CentMLInferenceAdapter(
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response)
 
-    async def _stream_completion(
-        self, request: CompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_completion(self,
+                                 request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -156,9 +153,7 @@ class CentMLInferenceAdapter(
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-            stream
-        ):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     #
@@ -200,8 +195,7 @@ class CentMLInferenceAdapter(
         return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
@@ -214,8 +208,7 @@ class CentMLInferenceAdapter(
         return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+            self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -228,29 +221,37 @@ class CentMLInferenceAdapter(
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-            stream, request):
+                stream, request):
             yield chunk
 
     #
     # HELPER METHODS
     #
 
-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(
+            self, request: Union[ChatCompletionRequest,
+                                 CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
 
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
-                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
+                input_dict["messages"] = [
+                    await convert_message_to_openai_dict(m)
+                    for m in request.messages
+                ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
+                    request, llama_model)
         else:
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         params = {
-            "model": request.model,
+            "model":
+            request.model,
             **input_dict,
-            "stream": request.stream,
+            "stream":
+            request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
         logcat.debug("inference", f"params to centml: {params}")
@@ -267,6 +268,8 @@ class CentMLInferenceAdapter(
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["response_format"] = {
                     "type": "json_object",
+                    # CentML currently does not support guided decoding,
+                    # the following setting is currently ignored by the server.
                     "schema": fmt.json_schema,
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
@@ -295,8 +298,7 @@ class CentMLInferenceAdapter(
         model = await self.model_store.get_model(model_id)
         # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
-            "CentML does not support media for embeddings"
-        )
+            "CentML does not support media for embeddings")
 
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,