This commit is contained in:
Honglin Cao 2025-01-21 17:02:28 -05:00
parent e20228a7ca
commit 6c1b1722b4
2 changed files with 35 additions and 11 deletions

View file

@ -64,7 +64,9 @@ MODEL_ALIASES = [
]
class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class CentMLInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
"""
Adapter to use CentML's serverless inference endpoints,
which adhere to the OpenAI chat/completions API spec,
@ -143,7 +145,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
response = self._get_client().completions.create(**params)
return process_completion_response(response, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
@ -152,7 +156,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(
stream, self.formatter
):
yield chunk
#
@ -242,12 +248,15 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
if isinstance(request, ChatCompletionRequest):
# For chat requests, always build "messages" from the user messages
input_dict["messages"] = [
await convert_message_to_openai_dict(m) for m in request.messages
await convert_message_to_openai_dict(m)
for m in request.messages
]
else:
# Non-chat (CompletionRequest)
assert not media_present, "CentML does not support media for completions"
assert not media_present, (
"CentML does not support media for completions"
)
input_dict["prompt"] = await completion_request_to_prompt(
request, self.formatter
)
@ -256,7 +265,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
**self._build_options(
request.sampling_params, request.response_format
),
}
def _build_options(
@ -277,7 +288,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
raise NotImplementedError(
"Grammar response format not supported yet"
)
else:
raise ValueError(f"Unknown response format {fmt.type}")