Mirror of https://github.com/meta-llama/llama-stack.git

Merge branch 'main' into add-nvidia-inference-adapter

commit 5fbfb9d854
92 changed files with 2145 additions and 678 deletions
@@ -5,11 +5,9 @@
 # the root directory of this source tree.

 from typing import Optional

 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field


 @json_schema_type
 class BedrockBaseConfig(BaseModel):
     aws_access_key_id: Optional[str] = Field(
         default=None,
@@ -57,3 +55,7 @@ class BedrockBaseConfig(BaseModel):
         default=3600,
         description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
     )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {}
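The new `sample_run_config()` classmethod presumably feeds generated sample provider configs. Below is a minimal runnable sketch of that usage, assuming pydantic is installed; the trimmed field set, the `session_ttl` field name, and the `"remote::bedrock"` provider_type string are illustrative assumptions, not taken from this diff.

```python
# Minimal sketch (assumptions flagged): a trimmed re-declaration of
# BedrockBaseConfig showing how sample_run_config() can feed a generated
# sample provider entry. The session_ttl field name and the
# "remote::bedrock" provider_type are illustrative guesses.
import json
from typing import Optional

from pydantic import BaseModel, Field


class BedrockBaseConfig(BaseModel):
    aws_access_key_id: Optional[str] = Field(default=None)
    session_ttl: Optional[int] = Field(
        default=3600,
        description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
    )

    @classmethod
    def sample_run_config(cls, **kwargs):
        # Empty dict: every field falls back to its pydantic default
        # (credentials are typically picked up from the environment).
        return {}


# A docs/template generator might render the sample entry like this:
entry = {"provider_type": "remote::bedrock", "config": BedrockBaseConfig.sample_run_config()}
print(json.dumps(entry, indent=2))
```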
@@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
     ]


-def supported_inference_models() -> List[str]:
+def supported_inference_models() -> List[Model]:
     return [
-        m.descriptor()
+        m
         for m in all_registered_models()
         if (
             m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
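Call-site impact of the signature change, as a hedged sketch: consumers that previously received descriptor strings now get `Model` objects and must map through `.descriptor()` themselves. The import path and the example descriptor string are assumptions, not shown in this diff.

```python
# Assumed import path; the diff does not show the module's location.
from llama_stack.providers.utils.inference import supported_inference_models

models = supported_inference_models()           # now List[Model], not List[str]
descriptors = [m.descriptor() for m in models]  # e.g. "Llama3.1-8B-Instruct" (assumed format)
print("Llama3.1-8B-Instruct" in descriptors)
```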
@@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
         cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages

-    if model.descriptor() not in supported_inference_models():
+    allowed_models = supported_inference_models()
+    descriptors = [m.descriptor() for m in allowed_models]
+    if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
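Why the membership test had to change, shown with a self-contained toy example (`ToyModel` is hypothetical, standing in for llama_models' `Model`): a descriptor string never compares equal to a model object, so the old check would have flagged every model as unsupported once the helper started returning `Model` objects.

```python
# Self-contained illustration; ToyModel is a hypothetical stand-in.
class ToyModel:
    def __init__(self, descriptor: str) -> None:
        self._descriptor = descriptor

    def descriptor(self) -> str:
        return self._descriptor


allowed = [ToyModel("Llama3.1-8B-Instruct"), ToyModel("Llama3.2-3B-Instruct")]

# Old-style check: a string is never equal to a ToyModel, so this is always
# True and every model would be reported as unsupported.
print("Llama3.1-8B-Instruct" not in allowed)  # True (wrongly "unsupported")

# New-style check: compare descriptor strings with descriptor strings.
descriptors = [m.descriptor() for m in allowed]
print("Llama3.1-8B-Instruct" not in descriptors)  # False (correctly supported)
```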