Make vllm inference better

Tests still don't pass completely (some hang), so there may still be some threading issues.
Ashwin Bharambe 2024-10-24 22:30:49 -07:00
parent cb43caa2c3
commit 70d59b0f5d
2 changed files with 84 additions and 84 deletions


@@ -15,13 +15,24 @@ class VLLMConfig(BaseModel):
     """Configuration for the vLLM inference provider."""
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     tensor_parallel_size: int = Field(
         default=1,
         description="Number of tensor parallel replicas (number of GPUs to use).",
     )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of tokens to generate.",
+    )
+    enforce_eager: bool = Field(
+        default=False,
+        description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
+    )
+    gpu_memory_utilization: float = Field(
+        default=0.3,
+    )
     @field_validator("model")
     @classmethod
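
For context, a minimal sketch of how these config fields could be forwarded to a vLLM engine. Only the VLLMConfig fields come from the diff above; the vllm calls, the resolved Hugging Face model id, and the provider wiring are assumptions, not the actual implementation in this commit.

# Sketch only, assuming VLLMConfig is the class shown in the diff above.
from vllm import LLM, SamplingParams

config = VLLMConfig()  # picks up the new defaults (Llama3.2-3B-Instruct, etc.)

engine = LLM(
    # Assumed resolution of the `llama model list` descriptor to a HF repo id.
    model="meta-llama/Llama-3.2-3B-Instruct",
    tensor_parallel_size=config.tensor_parallel_size,
    enforce_eager=config.enforce_eager,
    gpu_memory_utilization=config.gpu_memory_utilization,
)

sampling = SamplingParams(max_tokens=config.max_tokens)
outputs = engine.generate(["Hello, world"], sampling)
print(outputs[0].outputs[0].text)

A low gpu_memory_utilization default like 0.3 leaves most of the GPU free for other processes; raising it lets vLLM allocate a larger KV cache.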