mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
update register model method
This commit is contained in:
parent
bfd1ee6951
commit
fc5ad35e0d
1 changed files with 15 additions and 0 deletions
|
@ -7,6 +7,7 @@
|
|||
import json
|
||||
from collections.abc import Iterable
|
||||
|
||||
import requests
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
)
|
||||
|
@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
@ -246,6 +248,19 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
list_models_url = self.config.url + "/models"
|
||||
try:
|
||||
response = requests.get(list_models_url)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||
available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||
if len(available_models) == 0 or model_id.split("sambanova/")[-1] not in available_models:
|
||||
logger.warning(f"Model {model_id} not found as available in SambaNova models")
|
||||
return model
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue