ruff fix and format

Honglin Cao 2025-01-21 16:40:08 -05:00
parent dc1ff40413
commit e20228a7ca
2 changed files with 26 additions and 26 deletions

centml/__init__.py

@@ -8,9 +8,11 @@ from pydantic import BaseModel
 from .config import CentMLImplConfig
 
+
 class CentMLProviderDataValidator(BaseModel):
     centml_api_key: str
 
+
 async def get_adapter_impl(config: CentMLImplConfig, _deps):
     """
     Factory function to construct and initialize the CentML adapter.
@@ -21,9 +23,9 @@ async def get_adapter_impl(config: CentMLImplConfig, _deps):
     from .centml import CentMLInferenceAdapter
 
     # Ensure the provided config is indeed a CentMLImplConfig
-    assert isinstance(
-        config, CentMLImplConfig
-    ), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, CentMLImplConfig), (
+        f"Unexpected config type: {type(config)}"
+    )
 
     # Instantiate and initialize the adapter
     adapter = CentMLInferenceAdapter(config)
centml/centml.py

@@ -42,7 +42,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -65,8 +64,7 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
-                             NeedsRequestProviderData):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
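
Since CentML's endpoints follow the OpenAI chat/completions spec, the adapter can drive them with a stock OpenAI client. A rough sketch, with a hypothetical base URL (the real one comes from the adapter's config):

    from openai import OpenAI

    client = OpenAI(
        base_url="https://api.centml.example/v1",  # hypothetical URL
        api_key="sk-...",
    )
    resp = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)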
@@ -138,14 +136,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-            self, request: CompletionRequest) -> ChatCompletionResponse:
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
 
-    async def _stream_completion(self,
-                                 request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -154,8 +152,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-                stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk
 
     #
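
The `_to_async_generator` helper seen here wraps the client's synchronous streaming iterator so `process_completion_stream_response` can consume it with `async for`. A sketch of the pattern (not the adapter's exact body):

    async def _to_async_generator():
        # params includes "stream": True, so the OpenAI-compatible
        # client returns a synchronous iterator of chunks
        s = self._get_client().completions.create(**params)
        for chunk in s:
            yield chunk

    stream = _to_async_generator()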
@@ -195,7 +192,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
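
The trailing comment refers to the endpoint dispatch inside `_nonstream_chat_completion`; roughly (a sketch of the branch the comment describes):

    if "messages" in params:
        response = self._get_client().chat.completions.create(**params)
    else:
        response = self._get_client().completions.create(**params)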
@@ -208,7 +206,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return process_chat_completion_response(response, self.formatter)
 
     async def _stream_chat_completion(
-            self, request: ChatCompletionRequest) -> AsyncGenerator:
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -221,7 +220,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-                stream, self.formatter):
+            stream, self.formatter
+        ):
             yield chunk
 
     #
@@ -229,8 +229,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     #
 
     async def _get_params(
-            self, request: Union[ChatCompletionRequest,
-                                 CompletionRequest]) -> dict:
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
         """
         Build the 'params' dict that the OpenAI (CentML) client expects.
         For chat requests, we always prefer "messages" so that it calls
@@ -242,22 +242,20 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
+                await convert_message_to_openai_dict(m) for m in request.messages
             ]
         else:
             # Non-chat (CompletionRequest)
             assert not media_present, "CentML does not support media for completions"
             input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+                request, self.formatter
+            )
 
         return {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
+            "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
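
For a plain chat request, the dict assembled above comes out roughly like this (model id and message content are illustrative):

    {
        "model": "meta-llama/Llama-3.1-8B-Instruct",      # request.model
        "messages": [{"role": "user", "content": "Hi"}],  # via convert_message_to_openai_dict
        "stream": False,                                   # request.stream
        # ...plus sampling/response-format options from _build_options()
    }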
@@ -279,8 +277,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
-            raise NotImplementedError(
-                "Grammar response format not supported yet")
+            raise NotImplementedError("Grammar response format not supported yet")
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
@@ -297,8 +294,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
         # CentML does not support media
-        assert all(not content_has_media(c) for c in contents), \
+        assert all(not content_has_media(c) for c in contents), (
             "CentML does not support media for embeddings"
+        )
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
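
The embeddings call is truncated at this point in the diff; a hedged sketch of how it plausibly continues, assuming the imported `interleaved_content_as_str` helper flattens each content item (the remaining arguments are not visible here):

    resp = self._get_client().embeddings.create(
        model=model.provider_resource_id,
        input=[interleaved_content_as_str(c) for c in contents],  # assumed
    )
    embeddings = [item.embedding for item in resp.data]
    return EmbeddingsResponse(embeddings=embeddings)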