refactor get_max_tokens and build_options

Dinesh Yeduguru 2024-10-23 19:11:04 -07:00
parent 5965ef3979
commit 4a073fcee5
7 changed files with 33 additions and 38 deletions
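
The refactor pulls the token-budget clamp and the response-format/options plumbing out of the two request-builder methods and into shared helpers, _get_max_new_tokens and _build_options. As a rough illustration of the clamp, the sketch below reproduces the arithmetic of _get_max_new_tokens outside the adapter; the helper name and the min() logic come from the diff, while the standalone function, the 4096-token budget, and the 1000-token prompt are made-up examples.

from typing import Optional

# Hypothetical standalone version of the clamp in _get_max_new_tokens (illustration only).
def get_max_new_tokens(max_tokens: int, input_tokens: int, requested: Optional[int]) -> int:
    # Honor an explicit request if one was given, otherwise spend the remaining budget,
    # but never exceed the remaining context window minus one token of headroom.
    return min(
        requested or (max_tokens - input_tokens),
        max_tokens - input_tokens - 1,
    )

print(get_max_new_tokens(4096, 1000, None))  # 3095: remaining budget minus 1
print(get_max_new_tokens(4096, 1000, 256))   # 256: the smaller explicit request wins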


@@ -90,20 +90,21 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = completion_request_to_prompt_model_input_info(
-            request, self.formatter
-        )
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+    def _get_max_new_tokens(self, sampling_params, input_tokens):
+        return min(
+            sampling_params.max_tokens or (self.max_tokens - input_tokens),
             self.max_tokens - input_tokens - 1,
         )
-        options = get_sampling_options(request)
+
+    def _build_options(
+        self,
+        sampling_params: Optional[SamplingParams] = None,
+        fmt: ResponseFormat = None,
+    ):
+        options = get_sampling_options(sampling_params)
         # delete key "max_tokens" from options since its not supported by the API
         options.pop("max_tokens", None)
-        if fmt := request.response_format:
+        if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["grammar"] = {
                     "type": "json",
@@ -114,13 +115,22 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unexpected response format: {fmt.type}")
 
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+            request, self.formatter
+        )
+
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=self._get_max_new_tokens(
+                request.sampling_params, input_tokens
+            ),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         )
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -129,7 +139,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)
             async for chunk in s:
                 token_result = chunk.token
                 finish_reason = None
                 if chunk.details:
@@ -230,29 +239,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         prompt, input_tokens = chat_completion_request_to_model_input_info(
             request, self.formatter
         )
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
-            self.max_tokens - input_tokens - 1,
-        )
-        options = get_sampling_options(request)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["grammar"] = {
-                    "type": "json",
-                    "value": fmt.schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                raise ValueError("Grammar response format not supported yet")
-            else:
-                raise ValueError(f"Unexpected response format: {fmt.type}")
-
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=self._get_max_new_tokens(
+                request.sampling_params, input_tokens
+            ),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         )
 
     async def embeddings(
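
With both the completion and chat-completion paths delegating to the same helpers, the keyword arguments ultimately handed to client.text_generation look the same in either case. The sketch below is a hypothetical example of such a params dict: the key names and stop sequences are taken from the diff, while the concrete values, the temperature option, and the JSON schema are illustrative only.

# Hypothetical shape of the params dict either path now builds for
# self.client.text_generation(**params); values are illustrative.
params = dict(
    prompt="<|begin_of_text|>...",  # produced by the prompt/model-input helpers
    stream=False,
    details=True,
    max_new_tokens=3095,  # result of _get_max_new_tokens(request.sampling_params, input_tokens)
    stop_sequences=["<|eom_id|>", "<|eot_id|>"],
    temperature=0.7,  # illustrative sampling option merged in via _build_options
    grammar={"type": "json", "value": {"type": "object"}},  # only present for a json_schema response_format
)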