chore: mypy for remote::vllm

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-25 11:11:15 -04:00
parent d5b52923c1
commit 3e6c47ce10
5 changed files with 37 additions and 28 deletions

View file

@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """
 
-    model_store: ModelStore
+    model_store: ModelStore | None = None
 
     @webmethod(route="/inference/completion", method="POST")
     async def completion(
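
Declaring model_store as "ModelStore | None = None" means every consumer must rule out None before touching the attribute, which is why the vLLM adapter below gains "assert self.model_store is not None" guards. A minimal sketch of the pattern, with illustrative names that are not part of the patch:

    from typing import Optional, Protocol


    class Store(Protocol):
        def get(self, key: str) -> str: ...


    class Adapter:
        # Optional attribute: mypy sees "Store | None" at every use site.
        store: Optional[Store] = None

        def lookup(self, key: str) -> str:
            # Without this assert, mypy reports that None has no attribute "get".
            assert self.store is not None
            return self.store.get(key)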

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(
 
 
 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []
 
     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
 
         compat_tools.append(compat_tool)
 
-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools
 
 
 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
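
The two hunks above make the helper's body agree with its annotation: with a declared return type of List[dict], both the early "return tools" for a None argument and the trailing "return None" are incompatible return values under mypy, so the function now always returns the (possibly empty) list. A toy version of the same error, with illustrative names:

    from typing import List


    def to_payload(items: List[str]) -> List[dict]:
        out: List[dict] = [{"name": item} for item in items]
        if len(out) > 0:
            return out
        return None  # mypy rejects this: None is not a List[dict]

Returning out unconditionally keeps the annotation truthful and spares callers a needless None check.
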
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
                 )
             elif choice.delta.tool_calls:
                 tool_call = convert_tool_call(choice.delta.tool_calls[0])
-                tool_call_buf.tool_name += tool_call.tool_name
+                tool_call_buf.tool_name += str(tool_call.tool_name)
                 tool_call_buf.call_id += tool_call.call_id
-                tool_call_buf.arguments += tool_call.arguments
+                # TODO: remove str() when dict type for 'arguments' is no longer allowed
+                tool_call_buf.arguments += str(tool_call.arguments)
             else:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
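
The buffer above accumulates streamed tool-call fragments as plain strings, but the converted tool call's fields are unions (the TODO notes that arguments may still arrive as a dict, and a tool name may be an enum member rather than a string), so mypy rejects concatenating them onto a str until they are coerced. A reduced sketch of the same complaint, with illustrative names:

    from typing import Union


    def append_fragment(buffer: str, fragment: Union[str, dict]) -> str:
        # "buffer + fragment" is rejected because the right-hand side may be
        # a dict; str() makes the operand type unambiguous for mypy.
        return buffer + str(fragment)
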
@@ -248,10 +245,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +268,18 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +317,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
         return result
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
         params = await self._get_params(request)
         stream = await client.chat.completions.create(**params)
 
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
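
Both _stream_* helpers are async generators, so their annotated return type names what they yield rather than a value they return. The patch writes AsyncGenerator with a single parameter, which relies on the send type defaulting to None on newer typing versions; the two-argument spelling below is the portable form. A minimal sketch with illustrative names:

    from typing import AsyncGenerator


    async def count_up(limit: int) -> AsyncGenerator[int, None]:
        # An "async def" containing "yield" is an async generator; the
        # annotation gives the yield type (int) and the send type (None).
        for i in range(limit):
            yield i
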
@@ -330,18 +331,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.client is not None
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model) -> None:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -350,14 +354,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
-        return model
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
@@ -368,9 +371,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
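
Comparing fmt.type against ResponseFormatType values tells mypy nothing about which member of the ResponseFormat union it holds, so attribute access such as .json_schema is flagged; isinstance checks narrow the union so each branch sees the concrete class. The same narrowing with stand-in classes (not the real llama_stack types):

    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class JsonSchemaFormat:
        json_schema: dict


    @dataclass
    class GrammarFormat:
        bnf: str


    def to_extra_body(fmt: Union[JsonSchemaFormat, GrammarFormat]) -> dict:
        if isinstance(fmt, JsonSchemaFormat):
            # Narrowed to JsonSchemaFormat, so .json_schema is known to exist.
            return {"guided_json": fmt.json_schema}
        # Only GrammarFormat remains in this branch.
        raise NotImplementedError("grammar-guided decoding is out of scope for this sketch")
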
@@ -393,7 +396,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = self.model_store.get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding

View file

@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             )
         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
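
ModelRegistryHelper subclasses ModelsProtocolPrivate, and once mypy checks the module it treats unimplemented protocol members of an explicit subclass as abstract, so the helper grows a no-op unregister_model. A stripped-down sketch of the idea, with stand-in names:

    from typing import Protocol


    class Registry(Protocol):
        async def register_model(self, name: str) -> None: ...
        async def unregister_model(self, name: str) -> None: ...


    class RegistryHelper(Registry):
        async def register_model(self, name: str) -> None:
            print(f"registering {name}")

        async def unregister_model(self, name: str) -> None:
            # Nothing to clean up in this helper, so the hook is a no-op.
            pass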

View file

@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options
 
 
-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
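
Widening the parameter to SamplingParams | None moves the None handling into the function itself via an early return instead of relying on every caller; after the guard, mypy narrows params to SamplingParams for the rest of the body. The same guard in miniature, with illustrative names:

    from typing import Optional


    def build_options(max_tokens: Optional[int]) -> dict:
        # Early return keeps the rest of the body free of Optional checks;
        # past this point mypy treats max_tokens as int.
        if max_tokens is None:
            return {}
        return {"max_tokens": max_tokens}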

View file

@@ -253,7 +253,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",