mirror of https://github.com/meta-llama/llama-stack.git
update nvidia inference provider to use model_store
parent 2ae1d7f4e6
commit bb4ff1dd1f
1 changed file with 16 additions and 8 deletions
@@ -126,6 +126,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         return _get_client_for_base_url(base_url)
 
+    async def _get_provider_model_id(self, model_id: str) -> str:
+        if not self.model_store:
+            raise RuntimeError("Model store is not set")
+        model = await self.model_store.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model {model_id} is unknown")
+        return model.provider_model_id
+
     async def completion(
         self,
         model_id: str,
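For context, the helper added above assumes only a minimal model-store surface: an async get_model(model_id) lookup that returns an object carrying provider_model_id, or None for an unknown model. A sketch of that assumed interface follows (these Protocol definitions are illustrative, not the actual llama-stack types):

from typing import Optional, Protocol


class ModelLike(Protocol):
    # the single attribute _get_provider_model_id reads off the result
    provider_model_id: str


class ModelStoreLike(Protocol):
    # async lookup keyed by the user-facing model_id; None means unknown
    async def get_model(self, model_id: str) -> Optional[ModelLike]: ...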
@@ -144,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # removing this health check as NeMo customizer endpoint health check is returning 404
         # await check_health(self._config)  # this raises errors
 
-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
                 model=provider_model_id,
@@ -188,7 +196,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #
         flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
-        model = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
 
         extra_body = {}
@@ -211,8 +219,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._get_client(model).embeddings.create(
-                model=model,
+            response = await self._get_client(provider_model_id).embeddings.create(
+                model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
             )
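The surrounding embeddings code (only partially visible in this hunk) threads a task type through extra_body because the OpenAI-compatible client has no first-class input_type parameter; NVIDIA's embedding endpoints use it to distinguish query inputs from passage inputs. A hedged sketch of what that lookup plausibly looks like (the option values are assumptions based on NVIDIA's NIM embedding API, not copied from this file):

# hypothetical shape of the task_type_options mapping referenced in the hunk;
# NVIDIA NIM asymmetric embedding models expect "query" or "passage"
task_type_options = {
    "query": "query",
    "document": "passage",
}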
@@ -246,10 +254,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         # await check_health(self._config)  # this raises errors
 
-        provider_model_id = self.get_provider_model_id(model_id)
+        provider_model_id = await self._get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 messages=messages,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -294,7 +302,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)
 
         params = await prepare_openai_completion_params(
             model=provider_model_id,
@@ -347,7 +355,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
-        provider_model_id = self.get_provider_model_id(model)
+        provider_model_id = await self._get_provider_model_id(model)
 
         params = await prepare_openai_completion_params(
             model=provider_model_id,
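Taken together, every entry point now resolves the caller-supplied model_id through the async model store instead of the synchronous ModelRegistryHelper lookup. A self-contained sketch of the new resolution path against a stub store (the stub names are invented for illustration):

import asyncio
from types import SimpleNamespace


class StubModelStore:
    """Stand-in for the routing-table model store; illustrative only."""

    async def get_model(self, model_id: str):
        # return any object exposing provider_model_id, as the helper expects
        return SimpleNamespace(provider_model_id="meta/llama-3.1-8b-instruct")


async def get_provider_model_id(model_store, model_id: str) -> str:
    # free-standing copy of the adapter's new _get_provider_model_id helper
    if not model_store:
        raise RuntimeError("Model store is not set")
    model = await model_store.get_model(model_id)
    if model is None:
        raise ValueError(f"Model {model_id} is unknown")
    return model.provider_model_id


print(asyncio.run(get_provider_model_id(StubModelStore(), "llama3.1-8b")))
# prints: meta/llama-3.1-8b-instruct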