refactor get_max_tokens and build_options

Dinesh Yeduguru 2024-10-23 19:11:04 -07:00
parent 5965ef3979
commit 4a073fcee5
7 changed files with 33 additions and 38 deletions

@@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **get_sampling_options(request.sampling_params),
         }

     async def embeddings(

@@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if prompt.startswith("<|begin_of_text|>"):
             prompt = prompt[len("<|begin_of_text|>") :]

-        options = get_sampling_options(request)
+        options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
         if fmt := request.response_format:

@@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             return await self._nonstream_completion(request)

     def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request)
+        sampling_options = get_sampling_options(request.sampling_params)
         # This is needed since the Ollama API expects num_predict to be set
         # for early truncation instead of max_tokens.
         if sampling_options["max_tokens"] is not None:
@@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request),
+            "options": get_sampling_options(request.sampling_params),
             "raw": True,
             "stream": request.stream,
         }
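Note on the num_predict comment in the Ollama hunk above: Ollama caps generation length with a num_predict option rather than max_tokens, so the adapter has to translate the generic sampling option before calling the API. The helper below is a hedged, self-contained sketch of that translation; the function name and the literal dict are illustrative and not code from this commit.

def remap_max_tokens_for_ollama(sampling_options: dict) -> dict:
    # Move the generic "max_tokens" value over to Ollama's "num_predict"
    # when it is set, leaving all other sampling options untouched.
    if sampling_options.get("max_tokens") is not None:
        sampling_options["num_predict"] = sampling_options.pop("max_tokens")
    return sampling_options

print(remap_max_tokens_for_ollama({"temperature": 0.2, "max_tokens": 128}))
# {'temperature': 0.2, 'num_predict': 128}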

@@ -90,20 +90,21 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)

-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = completion_request_to_prompt_model_input_info(
-            request, self.formatter
-        )
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+    def _get_max_new_tokens(self, sampling_params, input_tokens):
+        return min(
+            sampling_params.max_tokens or (self.max_tokens - input_tokens),
             self.max_tokens - input_tokens - 1,
         )

-        options = get_sampling_options(request)
+    def _build_options(
+        self,
+        sampling_params: Optional[SamplingParams] = None,
+        fmt: ResponseFormat = None,
+    ):
+        options = get_sampling_options(sampling_params)
         # delete key "max_tokens" from options since its not supported by the API
         options.pop("max_tokens", None)
-
-        if fmt := request.response_format:
+        if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["grammar"] = {
                     "type": "json",
@@ -114,13 +115,22 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unexpected response format: {fmt.type}")

+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+            request, self.formatter
+        )
+
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=self._get_max_new_tokens(
+                request.sampling_params, input_tokens
+            ),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         )

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -129,7 +139,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)
             async for chunk in s:
-
                 token_result = chunk.token
                 finish_reason = None
                 if chunk.details:
@ -230,29 +239,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
prompt, input_tokens = chat_completion_request_to_model_input_info( prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter request, self.formatter
) )
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
options = get_sampling_options(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return dict( return dict(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
details=True, details=True,
max_new_tokens=max_new_tokens, max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
stop_sequences=["<|eom_id|>", "<|eot_id|>"], stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options, **self._build_options(request.sampling_params, request.response_format),
) )
async def embeddings( async def embeddings(
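The clamp performed by the new _get_max_new_tokens helper can be reproduced standalone. In the sketch below, max_tokens_limit stands in for self.max_tokens on the adapter; it is an illustration of the arithmetic in the hunks above, not additional code from this commit.

from typing import Optional

def get_max_new_tokens(max_tokens_limit: int, requested: Optional[int], input_tokens: int) -> int:
    # Fall back to the remaining context budget when the request sets no limit,
    # then clamp so prompt + generation stays one token under the window.
    return min(
        requested or (max_tokens_limit - input_tokens),
        max_tokens_limit - input_tokens - 1,
    )

print(get_max_new_tokens(4096, None, 1000))  # 3095: unset request, clamped budget
print(get_max_new_tokens(4096, 512, 1000))   # 512: explicit request fits, kept as-is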

@@ -131,7 +131,7 @@ class TogetherInferenceAdapter(
         yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        options = get_sampling_options(request)
+        options = get_sampling_options(request.sampling_params)
         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["response_format"] = {

@@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             "model": VLLM_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **get_sampling_options(request.sampling_params),
         }

     async def embeddings(

@@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]


-def get_sampling_options(request: ChatCompletionRequest) -> dict:
+def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
-    if params := request.sampling_params:
+    if params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
                 options[attr] = getattr(params, attr)
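To illustrate the new call pattern end to end, here is a minimal, self-contained sketch. The SamplingParams dataclass below is a stand-in for the real llama_stack type, and callers now pass request.sampling_params rather than the whole request.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:  # stand-in for the real type, illustration only
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None

def get_sampling_options(params: Optional[SamplingParams]) -> dict:
    # Copy only the sampling attributes that are actually set.
    options = {}
    if params:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(params, attr):
                options[attr] = getattr(params, attr)
    return options

print(get_sampling_options(SamplingParams(temperature=0.7, max_tokens=256)))
# {'temperature': 0.7, 'max_tokens': 256}  (key order may vary)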