mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
update way of checking available models
This commit is contained in:
parent
f844e18d47
commit
d739fe77a9
2 changed files with 23 additions and 18 deletions
|
@ -178,6 +178,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
|
|
||||||
def __init__(self, config: SambaNovaImplConfig):
|
def __init__(self, config: SambaNovaImplConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.environment_available_models = []
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
model_entries=MODEL_ENTRIES,
|
model_entries=MODEL_ENTRIES,
|
||||||
|
@ -250,15 +251,18 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
model_id = self.get_provider_model_id(model.provider_resource_id)
|
model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||||
|
|
||||||
list_models_url = self.config.url + "/models"
|
list_models_url = self.config.url + "/models"
|
||||||
try:
|
if len(self.environment_available_models) == 0:
|
||||||
response = requests.get(list_models_url)
|
try:
|
||||||
response.raise_for_status()
|
response = requests.get(list_models_url)
|
||||||
except requests.exceptions.RequestException as e:
|
response.raise_for_status()
|
||||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
except requests.exceptions.RequestException as e:
|
||||||
available_models = [model.get("id") for model in response.json().get("data", {})]
|
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||||
if len(available_models) == 0 or model_id.split("sambanova/")[-1] not in available_models:
|
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||||
logger.warning(f"Model {model_id} not available in {self.config.url}/models")
|
|
||||||
|
if model_id.split("sambanova/")[-1] not in self.environment_available_models:
|
||||||
|
logger.warning(f"Model {model_id} not available in {list_models_url}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
|
|
@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
|
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
|
||||||
def __init__(self, config: SambaNovaSafetyConfig) -> None:
|
def __init__(self, config: SambaNovaSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.environment_available_models = []
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
list_models_url = self.config.url + "/models"
|
list_models_url = self.config.url + "/models"
|
||||||
try:
|
if len(self.environment_available_models) == 0:
|
||||||
response = requests.get(list_models_url)
|
try:
|
||||||
response.raise_for_status()
|
response = requests.get(list_models_url)
|
||||||
except requests.exceptions.RequestException as e:
|
response.raise_for_status()
|
||||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
except requests.exceptions.RequestException as e:
|
||||||
available_models = [model.get("id") for model in response.json().get("data", {})]
|
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||||
|
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||||
if (
|
if (
|
||||||
len(available_models) == 0
|
"guard" not in shield.provider_resource_id.lower()
|
||||||
or "guard" not in shield.provider_resource_id.lower()
|
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
||||||
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
|
|
||||||
):
|
):
|
||||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {self.config.url}/models")
|
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue