ruff fix and format

Honglin Cao 2025-01-21 16:40:08 -05:00
parent dc1ff40413
commit e20228a7ca
2 changed files with 26 additions and 26 deletions

[file 1 of 2]

@@ -8,9 +8,11 @@ from pydantic import BaseModel
 from .config import CentMLImplConfig

+
 class CentMLProviderDataValidator(BaseModel):
     centml_api_key: str

+
 async def get_adapter_impl(config: CentMLImplConfig, _deps):
     """
     Factory function to construct and initialize the CentML adapter.
@@ -21,9 +23,9 @@ async def get_adapter_impl(config: CentMLImplConfig, _deps):
     from .centml import CentMLInferenceAdapter

     # Ensure the provided config is indeed a CentMLImplConfig
-    assert isinstance(
-        config, CentMLImplConfig
-    ), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, CentMLImplConfig), (
+        f"Unexpected config type: {type(config)}"
+    )

     # Instantiate and initialize the adapter
     adapter = CentMLInferenceAdapter(config)

[file 2 of 2]

@@ -42,7 +42,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
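The one-line deletion above is an unused-import fix rather than formatting: `chat_completion_request_to_prompt` is no longer referenced in the file, so `ruff check --fix` removes it (rule F401). The same fix on a toy module:

# F401 demo: `ruff check --fix` deletes imports that are never used.
import json      # referenced below, kept
import textwrap  # unused, flagged as F401 and auto-removed by the fixer

print(json.dumps({"ok": True}))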
@@ -65,8 +64,7 @@ MODEL_ALIASES = [
 ]

-class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
-                             NeedsRequestProviderData):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -138,14 +136,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)

     async def _nonstream_completion(
-        self, request: CompletionRequest) -> ChatCompletionResponse:
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)

-    async def _stream_completion(self,
-                                 request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         async def _to_async_generator():
@@ -154,8 +152,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-                stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk

     #
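Both streaming methods route the client's synchronous stream through a local `_to_async_generator` before handing it to the chunk processor. The diff only shows fragments of that helper, so here is a self-contained sketch of the pattern it appears to implement (the sync source below stands in for the CentML client's streaming response):

import asyncio
from typing import AsyncGenerator, Iterator

def sync_stream() -> Iterator[str]:
    # stand-in for client.completions.create(stream=True, ...)
    yield from ["chunk-1", "chunk-2"]

async def to_async_generator() -> AsyncGenerator[str, None]:
    # re-yield each chunk so callers can use `async for`; note the sync
    # iteration itself still blocks the event loop per chunk
    for chunk in sync_stream():
        yield chunk

async def main() -> None:
    async for chunk in to_async_generator():
        print(chunk)

asyncio.run(main())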
@@ -195,7 +192,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)

         # For chat requests, if "messages" is in params -> .chat.completions
@@ -208,7 +206,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return process_chat_completion_response(response, self.formatter)

     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest) -> AsyncGenerator:
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)

         async def _to_async_generator():
@@ -221,7 +220,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-                stream, self.formatter):
+            stream, self.formatter
+        ):
             yield chunk

     #
@@ -229,8 +229,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     #

     async def _get_params(
-        self, request: Union[ChatCompletionRequest,
-                             CompletionRequest]) -> dict:
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
         """
         Build the 'params' dict that the OpenAI (CentML) client expects.
         For chat requests, we always prefer "messages" so that it calls
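The same reflow shows up in four signatures in this commit (`_nonstream_completion`, `_nonstream_chat_completion`, `_stream_chat_completion`, `_get_params`): when a signature fits the line limit, ruff format joins it onto one line, as with `_stream_completion`; when it does not, the closing parenthesis and return annotation get their own line instead of being glued to the last parameter. A toy reproduction with long stand-in names (nothing below is from the adapter):

import asyncio
from typing import Union

class ChatCompletionRequestStub: ...
class CompletionRequestStub: ...

# old style glued the closing paren to the last parameter:
#     async def get_params(
#         self, request: Union[ChatCompletionRequest,
#                              CompletionRequest]) -> dict:
# ruff format style, when the signature must wrap:
async def get_params_for_openai_compatible_client(
    request: Union[ChatCompletionRequestStub, CompletionRequestStub]
) -> dict:
    return {"request_type": type(request).__name__}

print(asyncio.run(get_params_for_openai_compatible_client(CompletionRequestStub())))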
@@ -242,22 +242,20 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
+                await convert_message_to_openai_dict(m) for m in request.messages
             ]
         else:
             # Non-chat (CompletionRequest)
             assert not media_present, "CentML does not support media for completions"
             input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+                request, self.formatter
+            )

         return {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
+            "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
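Assembled, the dict `_get_params` returns for a chat request looks roughly like the sketch below. Every value is a placeholder, and the trailing entries come from `_build_options(...)`, which the diff shows only in part:

# Approximate shape only; values are invented for illustration.
params = {
    "model": "example-provider-model-id",
    "messages": [{"role": "user", "content": "Hello"}],  # chat path
    # "prompt": "..." would appear instead on the completion path
    "stream": False,
    # plus sampling / response-format keys merged in from _build_options()
}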
@@ -279,8 +277,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
-            raise NotImplementedError(
-                "Grammar response format not supported yet")
+            raise NotImplementedError("Grammar response format not supported yet")
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
@@ -297,8 +294,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         # CentML does not support media
-        assert all(not content_has_media(c) for c in contents), \
-            "CentML does not support media for embeddings"
+        assert all(not content_has_media(c) for c in contents), (
+            "CentML does not support media for embeddings"
+        )
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
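This final hunk is the other half of the assert rule seen in the first file: ruff format also rewrites backslash line continuations into parentheses. The same transformation on a standalone snippet (the helper below is a stub, not the real `content_has_media`):

def content_has_media(_c: str) -> bool:
    # stub standing in for the real helper
    return False

contents = ["some text", "more text"]

# before: backslash continuation
# assert all(not content_has_media(c) for c in contents), \
#     "no media allowed"

# after: ruff format prefers a parenthesized message
assert all(not content_has_media(c) for c in contents), (
    "no media allowed"
)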