From cc5bedea0139d5c03a7fdfeb19eef50245afc26d Mon Sep 17 00:00:00 2001
From: Sajikumar JS
Date: Thu, 20 Mar 2025 10:17:50 +0530
Subject: [PATCH] Add support for additional models in llama stack

---
 distributions/dependencies.json                 | 32 ++++++++++++++++
 llama_stack/models/llama/datatypes.py           |  3 ++
 .../remote/inference/watsonx/models.py          | 21 ++++++
 .../remote/inference/watsonx/watsonx.py         | 38 ++++++++++++++-----
 llama_stack/templates/watsonx/run.yaml          | 24 ++++++++++++
 5 files changed, 109 insertions(+), 9 deletions(-)

diff --git a/distributions/dependencies.json b/distributions/dependencies.json
index 59b0c9e62..3defd082a 100644
--- a/distributions/dependencies.json
+++ b/distributions/dependencies.json
@@ -621,5 +621,37 @@
     "vllm",
     "sentence-transformers --no-deps",
     "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "watsonx": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "datasets",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "ibm_watson_machine_learning"
   ]
 }
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index b25bf0ea9..0ae4f5382 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -194,6 +194,7 @@ class SamplingParams(BaseModel):
 
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    additional_params: Optional[dict] = {}
 
 
 class CheckpointQuantizationFormat(Enum):
@@ -233,6 +234,7 @@ class CoreModelId(Enum):
     llama3_70b = "Llama-3-70B"
     llama3_8b_instruct = "Llama-3-8B-Instruct"
     llama3_70b_instruct = "Llama-3-70B-Instruct"
+    llama3_405b_instruct = "llama-3-405b-instruct"
 
     # Llama 3.1 family
     llama3_1_8b = "Llama3.1-8B"
@@ -289,6 +291,7 @@ def model_family(model_id) -> ModelFamily:
         CoreModelId.llama3_70b,
         CoreModelId.llama3_8b_instruct,
         CoreModelId.llama3_70b_instruct,
+        CoreModelId.llama3_405b_instruct,
     ]:
         return ModelFamily.llama3
     elif model_id in [
diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py
index bded586d7..f700f4096 100644
--- a/llama_stack/providers/remote/inference/watsonx/models.py
+++ b/llama_stack/providers/remote/inference/watsonx/models.py
@@ -27,6 +27,27 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/llama-3-2-11b-vision-instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-1b-instruct",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-3b-instruct",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-90b-vision-instruct",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    # build_hf_repo_model_entry(
+    #     "meta-llama/llama-3-405b-instruct",
+    #     CoreModelId.llama3_405b_instruct.value,
+    # ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-guard-3-11b-vision",
+        CoreModelId.llama_guard_3_11b_vision.value,
     )
+
 ]
 
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 9dda8dea5..48786b4e7 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -61,10 +61,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         self._config = config
 
         self._project_id = self._config.project_id
-        self.params = {
-            GenParams.MAX_NEW_TOKENS: 4096,
-            GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
-        }
 
     async def initialize(self) -> None:
         pass
@@ -210,17 +206,41 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
-        input_dict = {}
+        input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present or not llama_model:
-                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
+        if request.sampling_params:
+            if request.sampling_params.strategy:
+                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
+            if request.sampling_params.max_tokens:
+                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
+            if request.sampling_params.repetition_penalty:
+                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
+            if request.sampling_params.additional_params.get("top_p"):
+                input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
+            if request.sampling_params.additional_params.get("top_k"):
+                input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
+            if request.sampling_params.additional_params.get("temperature"):
+                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
+            if request.sampling_params.additional_params.get("length_penalty"):
+                input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params["length_penalty"]
+            if request.sampling_params.additional_params.get("random_seed"):
+                input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
+            if request.sampling_params.additional_params.get("min_new_tokens"):
+                input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params["min_new_tokens"]
+            if request.sampling_params.additional_params.get("stop_sequences"):
+                input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params["stop_sequences"]
+            if request.sampling_params.additional_params.get("time_limit"):
+                input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
+            if request.sampling_params.additional_params.get("truncate_input_tokens"):
+                input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params["truncate_input_tokens"]
+            if request.sampling_params.additional_params.get("return_options"):
+                input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params["return_options"]
 
         params = {
             **input_dict,
diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml
index 851b19810..541e61073 100644
--- a/llama_stack/templates/watsonx/run.yaml
+++ b/llama_stack/templates/watsonx/run.yaml
@@ -95,6 +95,30 @@ models:
   provider_id: watsonx
   provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-1b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-1b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-3b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-90b-vision-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-405b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-405b-instruct
+- metadata: {}
+  model_id: meta-llama/llama-guard-3-11b-vision
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-guard-3-11b-vision
+  model_type: llm
 shields: []
 vector_dbs: []
 datasets: []
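
Usage note (illustrative, not part of the patch): the new additional_params field on SamplingParams is the channel through which watsonx-specific generation options reach _get_params. The sketch below shows how a caller might populate it, assuming the remaining SamplingParams fields keep their defaults; the key names mirror the GenParams mapping added in watsonx.py, and the values are made-up examples.

    # Sketch only: route watsonx generation options via additional_params.
    # Keys mirror the _get_params mapping in this patch; values are examples.
    from llama_stack.models.llama.datatypes import SamplingParams

    sampling_params = SamplingParams(
        max_tokens=512,                           # -> GenParams.MAX_NEW_TOKENS
        repetition_penalty=1.1,                   # -> GenParams.REPETITION_PENALTY
        additional_params={
            "top_p": 0.9,                         # -> GenParams.TOP_P
            "top_k": 50,                          # -> GenParams.TOP_K
            "temperature": 0.7,                   # -> GenParams.TEMPERATURE
            "stop_sequences": ["<|endoftext|>"],  # -> GenParams.STOP_SEQUENCES
            "min_new_tokens": 16,                 # -> GenParams.MIN_NEW_TOKENS
        },
    )

Because _get_params only copies keys that are present and truthy, any option omitted above is simply not forwarded to watsonx, and explicitly passing 0 or 0.0 for a key has the same effect as leaving it out.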