Ashwin Bharambe 2025-03-07 13:29:45 -08:00
parent 8040f1463e
commit 5c2cf512b9

@@ -14,8 +14,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -38,8 +36,9 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InterleavedContentItem,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
@@ -59,6 +58,8 @@ from llama_stack.models.llama.datatypes import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
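
Note: the two imports added in this hunk replace the removed llama_models imports above, pulling ChatFormat and Tokenizer from the copies vendored under llama_stack.models.llama. A minimal sketch of the usual wiring, assuming the vendored Tokenizer still exposes the get_instance() accessor used elsewhere in the tree:

from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

# Assumed accessor: get_instance() returns the shared Llama 3 tokenizer,
# and ChatFormat wraps it to encode/decode chat messages.
formatter = ChatFormat(Tokenizer.get_instance())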
@@ -391,7 +392,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
return self._streaming_completion(content, converted_sampling_params)
else:
streaming_result = None
-async for streaming_result in self._streaming_completion(content, converted_sampling_params):
+async for _ in self._streaming_completion(content, converted_sampling_params):
pass
return CompletionResponse(
content=streaming_result.delta,
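
Note: the non-streaming completion path in this hunk drains its own streaming generator and builds the final CompletionResponse from the last chunk it saw. A self-contained sketch of that drain pattern, with hypothetical names (drain_stream, last):

from typing import AsyncIterator, TypeVar

T = TypeVar("T")

async def drain_stream(chunks: AsyncIterator[T]) -> T:
    # Consume the async generator and keep only the last yielded chunk;
    # the provider reads delta/stop_reason/logprobs off that final chunk.
    last = None
    async for last in chunks:
        pass
    if last is None:
        raise RuntimeError("stream yielded no chunks")
    return last

The pattern relies on the loop variable outliving the loop, because the final chunk is only read after iteration completes.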
@@ -402,7 +403,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent], # type: ignore
+contents: List[str] | List[InterleavedContentItem],
+text_truncation: Optional[TextTruncation] = TextTruncation.none,
+output_dimension: Optional[int] = None,
+task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
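
Note: the embeddings signature now accepts plain strings or interleaved content items, plus optional truncation, output-dimension, and task-type hints, although this inline provider still raises NotImplementedError. A hedged call-side sketch of the new shape; the helper name, the specific enum members (TextTruncation.end, EmbeddingTaskType.document), and the TextTruncation import location are assumptions:

from typing import List

from llama_stack.apis.inference import (
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    TextTruncation,
)

async def embed_documents(client: Inference, model_id: str, texts: List[str]) -> EmbeddingsResponse:
    # Hypothetical helper: forwards the new optional parameters; a provider
    # that implements embeddings would honor them, this one raises instead.
    return await client.embeddings(
        model_id=model_id,
        contents=texts,
        text_truncation=TextTruncation.end,    # assumed enum member
        output_dimension=None,                 # let the model pick its native size
        task_type=EmbeddingTaskType.document,  # assumed enum member
    )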
@@ -410,7 +414,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message], # type: ignore
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -419,6 +423,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
+sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"