Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)
The remote vLLM client is now initialized lazily, directly before first use, instead of in VLLMInferenceAdapter.initialize().
This commit is contained in:
parent b5a9ef4c6d
commit f1fd382d51

1 changed file with 22 additions and 8 deletions
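
The change moves client construction out of initialize() and into a pair of small helpers that build the AsyncOpenAI client the first time it is needed. A minimal, standalone sketch of that pattern, reusing the config fields shown in the diff but with an illustrative class name:

```python
# Minimal sketch of the lazy-initialization pattern this commit adopts.
# The config fields mirror the diff; the class name itself is illustrative.
from typing import Optional

import httpx
from openai import AsyncOpenAI


class LazyVLLMClient:
    def __init__(self, url: str, api_token: str, tls_verify: bool = True):
        self.url = url
        self.api_token = api_token
        self.tls_verify = tls_verify
        self.client: Optional[AsyncOpenAI] = None  # nothing is built up front

    def _create_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            base_url=self.url,
            api_key=self.api_token,
            # A custom httpx client is only needed when TLS verification is disabled.
            http_client=None if self.tls_verify else httpx.AsyncClient(verify=False),
        )

    def _lazy_initialize_client(self) -> None:
        if self.client is not None:  # already initialized: keep the existing client
            return
        self.client = self._create_client()
```

The hunks below show the same two helpers added to VLLMInferenceAdapter and initialize() reduced to a no-op.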
```diff
@@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
-        log.info(f"Initializing VLLM client with base_url={self.config.url}")
-        self.client = AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
-        )
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
+    def _lazy_initialize_client(self):
+        if self.client is not None:
+            return
+
+        log.info(f"Initializing VLLM client with base_url={self.config.url}")
+        self.client = self._create_client()
+
+    def _create_client(self):
+        return AsyncOpenAI(
+            base_url=self.config.url,
+            api_key=self.config.api_token,
+            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
+        )
+
     async def completion(
         self,
         model_id: str,
```
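
The inference entry points in the later hunks call _lazy_initialize_client() before touching self.client. Using the sketch above, a quick check that repeated calls are idempotent (the URL and token are placeholders; constructing the client sends no request):

```python
# Idempotency check for the sketch above; URL and token are placeholders.
holder = LazyVLLMClient(url="http://localhost:8000/v1", api_token="dummy-token")
assert holder.client is None           # initialize() no longer creates anything
holder._lazy_initialize_client()       # first use builds the AsyncOpenAI client
first = holder.client
holder._lazy_initialize_client()       # later calls are no-ops
assert holder.client is first          # the same client object is reused
```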
```diff
@@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -357,9 +368,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        assert self.client is not None
+        client = self._create_client() if self.client is None else self.client
         model = await self.register_helper.register_model(model)
-        res = await self.client.models.list()
+        res = await client.models.list()
         available_models = [m.id async for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
```
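
register_model (last hunk above) is the one call site that does not go through _lazy_initialize_client(): it reuses self.client when it already exists and otherwise builds a short-lived client just for the availability check, without caching it on the adapter. A standalone sketch of that check, with a placeholder endpoint, token, and model id:

```python
# Standalone sketch of the availability check register_model performs;
# the endpoint, token, and model id passed below are placeholders.
import asyncio

from openai import AsyncOpenAI


async def check_model_served(base_url: str, api_token: str, model_id: str) -> None:
    client = AsyncOpenAI(base_url=base_url, api_key=api_token)
    available = [m.id async for m in await client.models.list()]
    if model_id not in available:
        raise ValueError(f"{model_id} is not being served by the vLLM endpoint")


# asyncio.run(check_model_served("http://localhost:8000/v1", "dummy-token", "my-model"))
```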
```diff
@@ -410,6 +421,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        self._lazy_initialize_client()
         assert self.client is not None
         model = await self._get_model(model_id)
 
@@ -449,6 +461,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
+        self._lazy_initialize_client()
         model_obj = await self._get_model(model)
 
         extra_body: Dict[str, Any] = {}
@@ -505,6 +518,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        self._lazy_initialize_client()
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
```