Apply formatting to source files

Fred Reiss 2024-12-19 11:46:31 -08:00
parent 6ec9eabbeb
commit c8580d3b0c
2 changed files with 15 additions and 12 deletions

@@ -88,23 +88,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         """
         Callback that is called when the server associates an inference endpoint
         with an inference provider.
 
         :param model: Object that encapsulates parameters necessary for identifying
             a specific LLM.
 
         :returns: The input ``Model`` object. It may or may not be permissible
            to change fields before returning this object.
         """
         log.info(f"Registering model {model.identifier} with vLLM inference provider.")
 
         # The current version of this provided is hard-coded to serve only
         # the model specified in the YAML config file.
         configured_model = resolve_model(self.config.model)
         registered_model = resolve_model(model.model_id)
         if configured_model.core_model_id != registered_model.core_model_id:
-            raise ValueError(f"Requested model '{model.identifier}' is different from "
-                             f"model '{self.config.model}' that this provider "
-                             f"is configured to serve")
+            raise ValueError(
+                f"Requested model '{model.identifier}' is different from "
+                f"model '{self.config.model}' that this provider "
+                f"is configured to serve"
+            )
         return model
 
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
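
The only substantive change in this hunk is black-style wrapping of the long raise ValueError(...) call: the implicitly concatenated f-strings move onto their own indented lines inside the parentheses, with the closing parenthesis on its own line. For context, what follows is a minimal, self-contained sketch of the model-compatibility check the surrounding code performs; ModelDescriptor and the toy resolve_model below are simplified stand-ins, not the real resolver the provider imports.

from dataclasses import dataclass


@dataclass
class ModelDescriptor:
    core_model_id: str


def resolve_model(name: str) -> ModelDescriptor:
    # Stand-in: collapse aliases to a canonical core model id.
    return ModelDescriptor(core_model_id=name.lower().removeprefix("meta-llama/"))


def check_model(requested: str, configured: str) -> None:
    # Mirrors the check above: the requested model must resolve to the same
    # core model that the provider was configured to serve.
    if resolve_model(requested).core_model_id != resolve_model(configured).core_model_id:
        raise ValueError(
            f"Requested model '{requested}' is different from "
            f"model '{configured}' that this provider "
            f"is configured to serve"
        )


check_model("meta-llama/Llama-3.1-8B-Instruct", "Llama-3.1-8B-Instruct")  # passes
# check_model("meta-llama/Llama-3.1-8B-Instruct", "Llama-3.1-70B-Instruct")  # would raise ValueError
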
@@ -169,8 +171,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
 
         request_id = _random_uuid()
-        prompt = await chat_completion_request_to_prompt(request, self.config.model,
-                                                         self.formatter)
+        prompt = await chat_completion_request_to_prompt(
+            request, self.config.model, self.formatter
+        )
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
@@ -226,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             stream, self.formatter
         ):
             yield chunk
 
     async def embeddings(
         self, model_id: str, contents: List[InterleavedContent]
     ) -> EmbeddingsResponse:
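
For reference, the yield chunk line above is part of the provider's streaming path: chat completion results come from an async generator that yields response chunks as the engine produces them. Below is a minimal, self-contained sketch of that async-generator pattern; fake_stream and its word-splitting are illustrative stand-ins, not the provider's real API.

import asyncio
from typing import AsyncIterator


async def fake_stream(prompt: str) -> AsyncIterator[str]:
    # Stand-in for an engine that emits tokens one at a time.
    for token in prompt.split():
        await asyncio.sleep(0)  # pretend we waited on the engine
        yield token


async def main() -> None:
    # Consume the stream the same way a caller consumes the provider's chunks.
    async for chunk in fake_stream("hello from a streaming generator"):
        print(chunk)


asyncio.run(main())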

@@ -292,6 +292,6 @@ async def inference_stack(request, inference_model):
     # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
     yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
 
     # Cleanup code that runs after test case completion
     await test_stack.impls[Api.inference].shutdown()
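
The fixture touched in the second changed file is a pytest yield fixture: everything before the yield is setup, the yielded value is handed to the test, and everything after the yield runs as teardown once the test finishes. Below is a minimal synchronous sketch of that pattern (the real fixture above is async and relies on an asyncio-capable pytest plugin); FakeResource is an illustrative stand-in for the test stack, not part of the repository.

import pytest


class FakeResource:
    def __init__(self) -> None:
        self.open = True

    def shutdown(self) -> None:
        self.open = False


@pytest.fixture
def fake_resource():
    resource = FakeResource()  # setup runs before the test body
    yield resource             # the test receives this object
    resource.shutdown()        # cleanup runs after the test completes


def test_uses_resource(fake_resource):
    assert fake_resource.open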