From 17fdb47e5e68292020300e339042c80824af6a3c Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Fri, 20 Dec 2024 12:32:49 +1100
Subject: [PATCH] Add Llama 3.3 70B to fireworks (#654)

# What does this PR do?

- Makes Llama 3.3 70B available for fireworks

## Test Plan

```shell
pip install -e . \
  && llama stack build --config distributions/fireworks/build.yaml --image-type conda \
  && llama stack run distributions/fireworks/run.yaml \
  --port 5000
```

```python
response = client.inference.chat_completion(
    model_id="Llama3.3-70B-Instruct",
    messages=[
        {"role": "user", "content": "hello world"},
    ],
)
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
 llama_stack/providers/remote/inference/fireworks/config.py | 2 +-
 .../providers/remote/inference/fireworks/fireworks.py      | 4 ++++
 llama_stack/providers/utils/inference/prompt_adapter.py    | 3 ++-
 llama_stack/templates/fireworks/run.yaml                   | 5 +++++
 4 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py
index e69926942..979e8455a 100644
--- a/llama_stack/providers/remote/inference/fireworks/config.py
+++ b/llama_stack/providers/remote/inference/fireworks/config.py
@@ -22,7 +22,7 @@ class FireworksImplConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
         return {
             "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index d9ef57b15..975ec4893 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -65,6 +65,10 @@ MODEL_ALIASES = [
         "fireworks/llama-v3p2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    build_model_alias(
+        "fireworks/llama-v3p3-70b-instruct",
+        CoreModelId.llama3_3_70b_instruct.value,
+    ),
     build_model_alias(
         "fireworks/llama-guard-3-8b",
         CoreModelId.llama_guard_3_8b.value,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 82fcefe54..f7d2cd84e 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -282,7 +282,8 @@ def chat_completion_request_to_messages(
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
-    elif model.model_family == ModelFamily.llama3_2:
+    elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_2(request)
     else:
         messages = request.messages
diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml
index cb31b4678..99f155a4a 100644
--- a/llama_stack/templates/fireworks/run.yaml
+++ b/llama_stack/templates/fireworks/run.yaml
@@ -110,6 +110,11 @@ models:
   provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-90b-vision-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.3-70B-Instruct
+  provider_id: fireworks
+  provider_model_id: fireworks/llama-v3p3-70b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
   provider_id: fireworks
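
As an optional follow-up check (not part of the recorded Test Plan above), the new alias can be exercised end-to-end once the stack is running. This is a minimal sketch assuming the `llama-stack-client` Python package and a server on port 5000 as in the Test Plan; the `model_id` is taken from the `run.yaml` entry added in this patch, and the listing loop and `stream=True` usage are illustrative, not mandated by the PR:

```python
# Sketch of an end-to-end check; assumes the stack from the Test Plan
# is running on port 5000 and `llama-stack-client` is installed.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

# Confirm the newly registered fireworks alias shows up.
for model in client.models.list():
    print(model.identifier)

# Stream a short completion from the newly added model.
for chunk in client.inference.chat_completion(
    model_id="meta-llama/Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "hello world"}],
    stream=True,
):
    print(chunk)
```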