fix top k, add in comments

Honglin Cao 2025-03-05 19:38:05 -05:00
parent 941d5f1b18
commit d1f67d90ca


@@ -64,9 +64,8 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
+                             NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -138,16 +137,14 @@ class CentMLInferenceAdapter(
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response)
 
-    async def _stream_completion(
-        self, request: CompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_completion(self,
+                                 request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -156,9 +153,7 @@ class CentMLInferenceAdapter(
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-            stream
-        ):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     #
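For readers skimming the reflowed streaming code, here is a minimal, self-contained sketch of the sync-to-async bridging that `_to_async_generator` performs in the hunks above: a blocking client stream is wrapped in an async generator so callers can `async for` over chunks. Every name below is an illustrative stand-in, not code from this commit.

```python
import asyncio
from typing import AsyncGenerator, Iterator


def sync_stream() -> Iterator[str]:
    # Stand-in for client.completions.create(stream=True, ...)
    yield from ["Hello", ", ", "world"]


async def to_async_generator() -> AsyncGenerator[str, None]:
    # Iterate the blocking stream inside an async generator, as the
    # adapter does; a production version might offload iteration to a
    # thread to avoid blocking the event loop.
    for chunk in sync_stream():
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator():
        print(chunk, end="")


asyncio.run(main())
```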
@@ -200,8 +195,7 @@ class CentMLInferenceAdapter(
         return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
@@ -214,8 +208,7 @@ class CentMLInferenceAdapter(
         return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+            self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -228,29 +221,37 @@ class CentMLInferenceAdapter(
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
                 stream, request):
             yield chunk
 
     #
     # HELPER METHODS
     #
 
-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(
+            self, request: Union[ChatCompletionRequest,
+                                 CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
 
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
-                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
+                input_dict["messages"] = [
+                    await convert_message_to_openai_dict(m)
+                    for m in request.messages
+                ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
+                    request, llama_model)
         else:
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         params = {
-            "model": request.model,
+            "model":
+            request.model,
             **input_dict,
-            "stream": request.stream,
+            "stream":
+            request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
         logcat.debug("inference", f"params to centml: {params}")
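To make the reflowed `_get_params` easier to review, a rough sketch of the two payload shapes it builds follows. The model id, option names, and values are assumptions for illustration; the actual options come from `_build_options`, which per the commit title now handles top_k correctly.

```python
# Hypothetical payloads produced by _get_params (values are assumed).

# Chat path: media present or no Llama model mapping -> OpenAI-style messages.
chat_params = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # assumed model id
    "messages": [{"role": "user", "content": "Hi"}],
    "stream": False,
    "temperature": 0.7,  # merged in by _build_options
    "top_k": 40,         # sampling option the commit title says was fixed
}

# Prompt path: a known Llama model with text-only input -> rendered prompt.
completion_params = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "<|begin_of_text|>...",  # from chat_completion_request_to_prompt
    "stream": True,
    "top_k": 40,
}
```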
@@ -267,6 +268,8 @@ class CentMLInferenceAdapter(
         if fmt.type == ResponseFormatType.json_schema.value:
             options["response_format"] = {
                 "type": "json_object",
+                # CentML currently does not support guided decoding;
+                # the schema below is ignored by the server.
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
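The new comment is the "add in comments" half of the commit title. For context, a hedged sketch of the branch it annotates, written as a standalone function; only the json_schema branch is confirmed by the diff, so the grammar behavior below is an assumption.

```python
def build_response_format_options(fmt) -> dict:
    # Sketch of the response_format handling above; `fmt` mirrors the
    # request's ResponseFormat object.
    options: dict = {}
    if fmt.type == "json_schema":
        options["response_format"] = {
            "type": "json_object",
            # Forwarded but ignored server-side until CentML supports
            # guided decoding, per the comment added in this commit.
            "schema": fmt.json_schema,
        }
    elif fmt.type == "grammar":
        # Assumed behavior: the grammar branch body is not shown in the diff.
        raise NotImplementedError("grammar-constrained decoding not supported")
    return options
```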
@@ -295,8 +298,7 @@ class CentMLInferenceAdapter(
         model = await self.model_store.get_model(model_id)
 
         # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
-            "CentML does not support media for embeddings"
-        )
+            "CentML does not support media for embeddings")
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
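Finally, a usage sketch for the embeddings path the last hunk touches. The base URL and embedding model id are assumptions; the only behavior the diff confirms is that media content is rejected by the assert above.

```python
from openai import OpenAI

# Hypothetical direct call against CentML's OpenAI-compatible API;
# the adapter's _get_client() builds a client along these lines.
client = OpenAI(
    base_url="https://api.centml.com/openai/v1",  # assumed endpoint
    api_key="YOUR_CENTML_API_KEY",
)

resp = client.embeddings.create(
    model="BAAI/bge-large-en-v1.5",  # assumed embedding model id
    input=["first document", "second document"],  # text only: media is rejected
)
vectors = [item.embedding for item in resp.data]
print(len(vectors), "embeddings of dim", len(vectors[0]))
```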