Add support for additional watsonx models in Llama Stack

Sajikumar JS 2025-03-20 10:17:50 +05:30
parent e7b7b102cf
commit cc5bedea01
5 changed files with 109 additions and 9 deletions

View file

@@ -621,5 +621,37 @@
     "vllm",
     "sentence-transformers --no-deps",
     "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
-  ]
+  ],
+  "watsonx": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "datasets",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "ibm_watson_machine_learning"
+  ]
 }
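
A quick way to sanity-check that this dependency set resolved in a built environment is to import a few key modules. A minimal sketch; note that faiss-cpu installs as the module faiss, and the other import names are assumed to match their package names:

import importlib

# spot-check a few of the watsonx distribution's dependencies
for module in ("ibm_watson_machine_learning", "faiss", "fastapi", "openai"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as err:
        print(f"{module}: missing ({err})")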

View file

@@ -194,6 +194,7 @@ class SamplingParams(BaseModel):
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    additional_params: Optional[dict] = {}


 class CheckpointQuantizationFormat(Enum):
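
The new additional_params field gives provider adapters a place to receive knobs the portable schema does not model. A minimal construction sketch, assuming only the fields shown in this hunk (the key names match the watsonx mapping further below):

# portable fields stay first-class; provider-only knobs ride in additional_params
params = SamplingParams(
    max_tokens=512,
    repetition_penalty=1.1,
    additional_params={"top_p": 0.9, "min_new_tokens": 16, "time_limit": 10_000},
)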
@@ -233,6 +234,7 @@ class CoreModelId(Enum):
     llama3_70b = "Llama-3-70B"
     llama3_8b_instruct = "Llama-3-8B-Instruct"
     llama3_70b_instruct = "Llama-3-70B-Instruct"
+    llama3_405b_instruct = "Llama-3-405B-Instruct"

     # Llama 3.1 family
     llama3_1_8b = "Llama3.1-8B"
@@ -289,6 +291,7 @@ def model_family(model_id) -> ModelFamily:
         CoreModelId.llama3_70b,
         CoreModelId.llama3_8b_instruct,
         CoreModelId.llama3_70b_instruct,
+        CoreModelId.llama3_405b_instruct,
     ]:
         return ModelFamily.llama3
     elif model_id in [
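
Since the enum and the family mapping are extended together, the new id classifies like its siblings; a one-line check, assuming the definitions above:

# the new 405B id resolves to the Llama 3 family like the other llama3_* ids
assert model_family(CoreModelId.llama3_405b_instruct) == ModelFamily.llama3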

View file

@@ -27,6 +27,27 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/llama-3-2-11b-vision-instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-1b-instruct",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-3b-instruct",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-90b-vision-instruct",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    # build_hf_repo_model_entry(
+    #     "meta-llama/llama-3-405b-instruct",
+    #     CoreModelId.llama3_405b_instruct.value,
+    # ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-guard-3-11b-vision",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
 ]
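
Each entry binds a watsonx provider model id to its canonical Llama descriptor. A hedged lookup sketch; provider_model_id and llama_model are assumed attribute names on the objects build_hf_repo_model_entry returns:

def canonical_model(provider_model_id: str) -> str | None:
    # walk the registry and return the canonical descriptor, if mapped
    for entry in MODEL_ENTRIES:
        if entry.provider_model_id == provider_model_id:
            return entry.llama_model
    return None

# e.g. canonical_model("meta-llama/llama-3-2-3b-instruct") -> "Llama3.2-3B-Instruct"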

View file

@@ -61,10 +61,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         self._config = config
         self._project_id = self._config.project_id
-        self.params = {
-            GenParams.MAX_NEW_TOKENS: 4096,
-            GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
-        }

     async def initialize(self) -> None:
         pass
@@ -210,17 +206,41 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
-        input_dict = {}
+        input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "watsonx does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

+        if request.sampling_params:
+            if request.sampling_params.strategy:
+                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
+            if request.sampling_params.max_tokens:
+                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
+            if request.sampling_params.repetition_penalty:
+                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
+            if request.sampling_params.additional_params.get("top_p"):
+                input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
+            if request.sampling_params.additional_params.get("top_k"):
+                input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
+            if request.sampling_params.additional_params.get("temperature"):
+                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
+            if request.sampling_params.additional_params.get("length_penalty"):
+                input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params["length_penalty"]
+            if request.sampling_params.additional_params.get("random_seed"):
+                input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
+            if request.sampling_params.additional_params.get("min_new_tokens"):
+                input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params["min_new_tokens"]
+            if request.sampling_params.additional_params.get("stop_sequences"):
+                input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params["stop_sequences"]
+            if request.sampling_params.additional_params.get("time_limit"):
+                input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
+            if request.sampling_params.additional_params.get("truncate_input_tokens"):
+                input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params["truncate_input_tokens"]
+            if request.sampling_params.additional_params.get("return_options"):
+                input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params["return_options"]

         params = {
             **input_dict,
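
The one-knob-per-if chain above is correct but repetitive. A hedged, behavior-preserving sketch of the same additional_params mapping, table-driven, using only GenParams names that already appear in this hunk (the helper name is hypothetical):

def _additional_gen_params(additional_params: dict) -> dict:
    # same truthiness semantics as the if-chain: falsy values are skipped
    key_map = {
        "top_p": GenParams.TOP_P,
        "top_k": GenParams.TOP_K,
        "temperature": GenParams.TEMPERATURE,
        "length_penalty": GenParams.LENGTH_PENALTY,
        "random_seed": GenParams.RANDOM_SEED,
        "min_new_tokens": GenParams.MIN_NEW_TOKENS,
        "stop_sequences": GenParams.STOP_SEQUENCES,
        "time_limit": GenParams.TIME_LIMIT,
        "truncate_input_tokens": GenParams.TRUNCATE_INPUT_TOKENS,
        "return_options": GenParams.RETURN_OPTIONS,
    }
    return {
        gen_param: additional_params[key]
        for key, gen_param in key_map.items()
        if additional_params.get(key)
    }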

View file

@@ -95,6 +95,30 @@ models:
   provider_id: watsonx
   provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-1b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-1b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-3b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-90b-vision-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-405b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-guard-3-11b-vision
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-guard-3-11b-vision
+  model_type: llm
 shields: []
 vector_dbs: []
 datasets: []
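
Once a stack built from this template is running, the registrations above should be visible through the models API. A hedged usage sketch; the base URL/port and the attribute names on the returned model objects are assumptions:

from llama_stack_client import LlamaStackClient

# list what the running stack registered from this run.yaml
client = LlamaStackClient(base_url="http://localhost:8321")
for model in client.models.list():
    print(model.identifier, "->", model.provider_resource_id)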