Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-08 00:01:28 +00:00)
fixes
commit 5c2cf512b9 (parent 8040f1463e)
1 changed file with 11 additions and 6 deletions
@@ -14,8 +14,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -38,8 +36,9 @@ from llama_stack.apis.inference import (
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    InterleavedContentItem,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
@@ -59,6 +58,8 @@ from llama_stack.models.llama.datatypes import (
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
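Note on the import hunks above: ChatFormat and Tokenizer now come from llama_stack.models.llama.llama3 rather than the external llama_models package, whose imports disappear in the first hunk. A minimal sketch of the resulting import and construction follows; the ChatFormat(Tokenizer.get_instance()) call is my assumption about typical provider usage, not something shown in this diff:

# Old location (removed in the first hunk):
#   from llama_models.llama3.api.chat_format import ChatFormat
#   from llama_models.llama3.api.tokenizer import Tokenizer

# New location (added in this hunk):
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

# Assumed usage: build the chat formatter from the singleton tokenizer.
formatter = ChatFormat(Tokenizer.get_instance())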
@@ -391,7 +392,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
            return self._streaming_completion(content, converted_sampling_params)
        else:
            streaming_result = None
            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
            async for _ in self._streaming_completion(content, converted_sampling_params):
                pass
            return CompletionResponse(
                content=streaming_result.delta,
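The hunk above touches the non-streaming branch of completion(), which drains the streaming generator and keeps only its final chunk. A small self-contained sketch of that pattern, with a toy async generator standing in for _streaming_completion() (an assumption for illustration), shows why the loop variable must be the same name that is read after the loop: spelling it _ would leave streaming_result as None.

import asyncio


async def fake_streaming_completion():
    # Stand-in for VLLMInferenceImpl._streaming_completion(): yields chunks,
    # where the final chunk carries the fully accumulated result.
    for piece in ("Hel", "Hello", "Hello, world"):
        yield piece


async def non_streaming_completion() -> str:
    last_chunk = None
    # Drain the stream; Python leaves the last yielded value bound to the
    # loop variable, so last_chunk holds the final chunk after the loop.
    async for last_chunk in fake_streaming_completion():
        pass
    if last_chunk is None:
        raise RuntimeError("stream produced no chunks")
    return last_chunk


print(asyncio.run(non_streaming_completion()))  # -> Hello, world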
@@ -402,7 +403,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],  # type: ignore
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
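The embeddings() signature above follows the current Inference API: contents accepts either a list of plain strings or a list of InterleavedContentItem, and text_truncation, output_dimension, and task_type are optional hints. The body still raises NotImplementedError. A minimal sketch of how such a union-typed contents parameter can be normalized, using a dataclass stand-in for the text content item (an assumption for illustration, not the real llama_stack type):

from dataclasses import dataclass
from typing import List, Union


@dataclass
class TextContentItem:
    # Toy stand-in for the text variant of InterleavedContentItem.
    text: str


def normalize_contents(contents: Union[List[str], List[TextContentItem]]) -> List[str]:
    # Collapse the union accepted by embeddings() into plain strings
    # before handing them to the embedding model.
    return [c if isinstance(c, str) else c.text for c in contents]


print(normalize_contents(["hello", "world"]))
print(normalize_contents([TextContentItem(text="hello")]))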
@@ -410,7 +414,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -419,6 +423,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
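The last two hunks pair a signature change with a body change: the sampling_params default moves from SamplingParams() to None, and the method rebuilds it with sampling_params or SamplingParams(). This is the usual way to avoid a mutable default argument, which is created once at function-definition time and shared across calls. A minimal sketch of the pitfall and the fix, using a toy SamplingParams rather than the real llama_stack type:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SamplingParams:
    # Toy stand-in with a mutable field, so sharing across calls is observable.
    stop: List[str] = field(default_factory=list)


def chat_bad(params: SamplingParams = SamplingParams()) -> SamplingParams:
    # The default instance is created once and reused for every call,
    # so mutations leak between unrelated requests.
    params.stop.append("<|eot|>")
    return params


def chat_good(params: Optional[SamplingParams] = None) -> SamplingParams:
    # Pattern used in the diff: default to None, then build a fresh instance.
    params = params or SamplingParams()
    params.stop.append("<|eot|>")
    return params


print(chat_bad().stop)   # ['<|eot|>']
print(chat_bad().stop)   # ['<|eot|>', '<|eot|>']  (shared default was mutated)
print(chat_good().stop)  # ['<|eot|>']
print(chat_good().stop)  # ['<|eot|>']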