From d5b7de3897d0329e6d98aa4ab55e8af8f18d3663 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Wed, 29 Jan 2025 14:59:40 -0500
Subject: [PATCH 1/6] Fix link to selection guide and change "docker" to "container" (#898)

The current link doesn't work. Also changed docs to be consistent with https://github.com/meta-llama/llama-stack/pull/802.
---
 docs/source/distributions/index.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/docs/source/distributions/index.md b/docs/source/distributions/index.md
index f68b8a8ae..ee7f4f23c 100644
--- a/docs/source/distributions/index.md
+++ b/docs/source/distributions/index.md
@@ -7,9 +7,9 @@ You can run a Llama Stack server in one of the following ways:
 
 This is the simplest way to get started. Using Llama Stack as a library means you do not need to start a server. This is especially useful when you are not running inference locally and relying on an external inference service (eg. fireworks, together, groq, etc.) See [Using Llama Stack as a Library](importing_as_library)
 
-**Docker**:
+**Container**:
 
-Another simple way to start interacting with Llama Stack is to just spin up docker which is pre-built with all the providers you need. We provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](distributions/selection) for more details.
+Another simple way to start interacting with Llama Stack is to just spin up a container (via Docker or Podman) which is pre-built with all the providers you need. We provide a number of pre-built images so you can start a Llama Stack server instantly. You can also build your own custom container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](selection) for more details.
 
 **Conda**:
 
@@ -24,4 +24,5 @@ Lastly, if you have a custom or an advanced setup or you are developing on Llama
 importing_as_library
 building_distro
 configuration
+selection
 ```

From 39c34dd25f9365b09000a07de5c46dbdba27e3cb Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Thu, 30 Jan 2025 07:02:12 +1100
Subject: [PATCH 2/6] [#432] Groq Provider tool call tweaks (#811)

# What does this PR do?

Follow up for @ashwinb's comments in https://github.com/meta-llama/llama-stack/pull/630

- [x] Contributes to issue (#432)

## Test Plan
**Environment**

```shell
export GROQ_API_KEY=

# Create environment if not already
conda create --name llamastack-groq python=3.10
conda activate llamastack-groq

wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/meta-llama/llama-stack/918172c7fa92522c9ebc586bdb4f386b1d9ea224/run.yaml

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq

# Test deps
pip install pytest pytest_html pytest_asyncio
```
**Unit tests**

```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s

# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .......................
========================================= 23 passed, 11 warnings in 0.06s =========================================
```
**Integration tests**

```shell
# Tests
pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s

# Results
___________________________ TestInference.test_chat_completion_with_tool_calling[-groq] ___________________________
llama_stack/providers/tests/inference/test_text_inference.py:403: in test_chat_completion_with_tool_calling
    assert len(message.tool_calls) > 0
E   assert 0 > 0
E    +  where 0 = len([])
E    +  where [] = CompletionMessage(role='assistant', content='{"location": "San Francisco, CA"}', stop_reason=, tool_calls=[]).tool_calls
============================================= short test summary info =============================================
FAILED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-groq] - assert 0 > 0
======================== 1 failed, 3 passed, 5 skipped, 99 deselected, 7 warnings in 2.13s ========================
```

(One failure as expected from 3.2 3B - re: https://github.com/meta-llama/llama-stack/pull/630#discussion_r1914056503)
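For reference, a self-contained sketch of the fallback behaviour this PR introduces: tool-call arguments that are not valid JSON are preserved as raw strings instead of raising. The real implementation is `_convert_groq_tool_call` plus the `UnparseableToolCall` model in `groq_utils.py` (see the diff below); the `ParsedToolCall` stand-in and the `convert_tool_call` helper here are illustrative only, not the actual llama-stack types.

```python
import json
from typing import Union

from pydantic import BaseModel


class ParsedToolCall(BaseModel):
    # Hypothetical stand-in for llama-stack's ToolCall: arguments parsed into a dict.
    call_id: str
    tool_name: str
    arguments: dict


class UnparseableToolCall(BaseModel):
    # Mirrors the patch's UnparseableToolCall: arguments kept as the raw string.
    call_id: str
    tool_name: str
    arguments: str


def convert_tool_call(
    call_id: str, name: str, raw_arguments: str
) -> Union[ParsedToolCall, UnparseableToolCall]:
    """Parse arguments when they are valid JSON, otherwise fall back to the raw string."""
    try:
        return ParsedToolCall(call_id=call_id, tool_name=name, arguments=json.loads(raw_arguments))
    except Exception:
        return UnparseableToolCall(call_id=call_id, tool_name=name, arguments=raw_arguments)


# Valid JSON arguments parse as before.
print(convert_tool_call("call-1", "get_weather", '{"location": "San Francisco, CA"}'))
# Python-style arguments (the case exercised in the unit tests above) no longer raise;
# the raw string is returned to the caller instead.
print(convert_tool_call("call-2", "get_flight_info", "(origin=AU, destination=LAX)"))
```

In the non-streaming path the unparseable calls are serialized into the completion message content; in the streaming path they are surfaced as a `ToolCallDelta` with `parse_status=failed`.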
## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. Co-authored-by: Ashwin Bharambe --- .../remote/inference/groq/groq_utils.py | 94 ++++++++++++++----- .../tests/inference/groq/test_groq_utils.py | 55 +++++++++++ 2 files changed, 126 insertions(+), 23 deletions(-) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index bd1a07d7c..99fa8219c 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -6,7 +6,7 @@ import json import warnings -from typing import AsyncGenerator, Literal +from typing import AsyncGenerator, Literal, Union from groq import Stream from groq.types.chat.chat_completion import ChatCompletion @@ -30,6 +30,8 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition +from pydantic import BaseModel + from llama_stack.apis.common.content_types import ( TextDelta, ToolCallDelta, @@ -150,15 +152,26 @@ def convert_chat_completion_response( _convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls ] - return ChatCompletionResponse( - completion_message=CompletionMessage( - tool_calls=tool_calls, - stop_reason=StopReason.end_of_message, - # Content is not optional - content="", - ), - logprobs=None, - ) + if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): + # If we couldn't parse a tool call, jsonify the tool calls and return them + return ChatCompletionResponse( + completion_message=CompletionMessage( + stop_reason=StopReason.end_of_message, + content=json.dumps(tool_calls, default=lambda x: x.model_dump()), + ), + logprobs=None, + ) + else: + # Otherwise, return tool calls as normal + return ChatCompletionResponse( + completion_message=CompletionMessage( + tool_calls=tool_calls, + stop_reason=StopReason.end_of_message, + # Content is not optional + content="", + ), + logprobs=None, + ) else: return ChatCompletionResponse( completion_message=CompletionMessage( @@ -214,15 +227,27 @@ async def convert_chat_completion_response_stream( # We assume Groq produces fully formed tool calls for each chunk tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0]) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), + if isinstance(tool_call, ToolCall): + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + else: + # Otherwise it's an UnparseableToolCall - return the raw tool call + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=tool_call.model_dump_json(), + parse_status=ToolCallParseStatus.failed, + ), + ) ) - ) else: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -234,12 +259,35 
@@ async def convert_chat_completion_response_stream( event_type = ChatCompletionResponseEventType.progress -def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall: +class UnparseableToolCall(BaseModel): + """ + A ToolCall with arguments that are not valid JSON. + Mirrors the ToolCall schema, but with arguments as a string. + """ + + call_id: str + tool_name: str + arguments: str + + +def _convert_groq_tool_call( + tool_call: ChatCompletionMessageToolCall, +) -> Union[ToolCall, UnparseableToolCall]: + """ + Convert a Groq tool call to a ToolCall. + Returns an UnparseableToolCall if the tool call is not valid JSON. + """ + try: + arguments = json.loads(tool_call.function.arguments) + except Exception as e: + return UnparseableToolCall( + call_id=tool_call.id, + tool_name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + return ToolCall( call_id=tool_call.id, tool_name=tool_call.function.name, - # Note that Groq may return a string that is not valid JSON here - # So this may raise a 500 error. Going to leave this as is to see - # how big of an issue this is and what we can do about it. - arguments=json.loads(tool_call.function.arguments), + arguments=arguments, ) diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index f6f593f16..5e0797871 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -23,6 +23,7 @@ from groq.types.chat.chat_completion_message_tool_call import ( from groq.types.shared.function_definition import FunctionDefinition from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponseEventType, @@ -347,6 +348,26 @@ class TestConvertNonStreamChatCompletionResponse: ), ] + def test_converts_unparseable_tool_calls(self): + response = self._dummy_chat_completion_response_with_tool_call() + response.choices[0].message.tool_calls = [ + ChatCompletionMessageToolCall( + id="tool_call_id", + type="function", + function=Function( + name="log", + arguments="(number=10, base=2)", + ), + ), + ] + + converted = convert_chat_completion_response(response) + + assert ( + converted.completion_message.content + == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]' + ) + def _dummy_chat_completion_response(self): return ChatCompletion( id="chatcmpl-123", @@ -478,6 +499,40 @@ class TestConvertStreamChatCompletionResponse: arguments={"origin": "AU", "destination": "LAX"}, ) + @pytest.mark.asyncio + async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self): + def tool_call_stream(): + chunk = self._dummy_chat_completion_chunk_with_tool_call() + chunk.choices[0].delta.tool_calls = [ + ChoiceDeltaToolCall( + index=0, + type="function", + id="tool_call_id", + function=ChoiceDeltaToolCallFunction( + name="get_flight_info", + arguments="(origin=AU, destination=LAX)", + ), + ), + ] + yield chunk + + chunk = self._dummy_chat_completion_chunk_with_tool_call() + chunk.choices[0].delta.content = None + chunk.choices[0].finish_reason = "stop" + yield chunk + + stream = tool_call_stream() + converted = convert_chat_completion_response_stream(stream) + + iter = converted.__aiter__() + chunk = 
await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.start + assert ( + chunk.event.delta.content + == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}' + ) + assert chunk.event.delta.parse_status == ToolCallParseStatus.failed + def _dummy_chat_completion_chunk(self): return ChatCompletionChunk( id="chatcmpl-123", From 80f20324859cc88dbb4e41c95d961c601eb0aec2 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 29 Jan 2025 21:24:22 -0800 Subject: [PATCH 3/6] Fix running stack built with base conda environment (#903) Fixes: #902 For the test verified that llama stack can run if built: * With default "base" conda environment * With new custom conda environment using `--image-name XXX` option In both cases llama stack starts fine (was failing with "base") before this patch. CC: @ashwinb Signed-off-by: Dmitry Rogozhkin --- llama_stack/cli/stack/run.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 62a45ada0..48b443524 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -140,6 +140,10 @@ class StackRun(Subcommand): return def get_conda_prefix(env_name): + # Conda "base" environment does not end with "base" in the + # prefix, so should be handled separately. + if env_name == "base": + return os.environ.get("CONDA_PREFIX") # Get conda environments info conda_env_info = json.loads( subprocess.check_output( From 6f9023d9489e79eec4379f8dc83376dd2a858ec2 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 29 Jan 2025 21:26:04 -0800 Subject: [PATCH 4/6] create a github action for triggering client-sdk tests on new pull-request (#850) # What does this PR do? Create a new github action that runs integration tests on fireworks and together distro upon new PR **Key features:** 1) Run inference client-sdk tests on fireworks and together distro. 
Load distro as a library 2) Pull changes from latest github repo (llama-models) and (llama-stack-client-python) 3) output a test summary **Next steps:** - Expand the ci test action to (llama-models) and (llama-stack-client-python) repo to make sure the changes there does not break the imports in llama-stack ## Test Plan See [the job run triggered by this PR](https://github.com/meta-llama/llama-stack/actions/runs/12926663190?pr=850) --- .github/workflows/tests.yml | 69 +++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..ff13a4cb0 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,69 @@ +name: auto-tests + +on: + # pull_request: + workflow_dispatch: + inputs: + commit_sha: + description: 'Specific Commit SHA to trigger on' + required: false + default: $GITHUB_SHA # default to the last commit of $GITHUB_REF branch + +jobs: + test-llama-stack-as-library: + runs-on: ubuntu-latest + env: + TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} + FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} + TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY }} + strategy: + matrix: + provider: [fireworks, together] + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.commit_sha }} + + - name: Echo commit SHA + run: | + echo "Triggered on commit SHA: ${{ github.event.inputs.commit_sha }}" + git rev-parse HEAD + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt pytest + pip install -e . + + - name: Build providers + run: | + llama stack build --template ${{ matrix.provider }} --image-type venv + + - name: Install the latest llama-stack-client & llama-models packages + run: | + pip install -e git+https://github.com/meta-llama/llama-stack-client-python.git#egg=llama-stack-client + pip install -e git+https://github.com/meta-llama/llama-models.git#egg=llama-models + + - name: Run client-sdk test + working-directory: "${{ github.workspace }}" + env: + REPORT_OUTPUT: md_report.md + shell: bash + run: | + pip install --upgrade pytest-md-report + echo "REPORT_FILE=${REPORT_OUTPUT}" >> "$GITHUB_ENV" + + export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct + LLAMA_STACK_CONFIG=./llama_stack/templates/${{ matrix.provider }}/run.yaml pytest --md-report --md-report-verbose=1 ./tests/client-sdk/inference/test_inference.py --md-report-output "$REPORT_OUTPUT" + + - name: Output reports to the job summary + if: always() + shell: bash + run: | + if [ -f "$REPORT_FILE" ]; then + echo "
Test Report for ${{ matrix.provider }} " >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + cat "$REPORT_FILE" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "
" >> $GITHUB_STEP_SUMMARY + fi From 836f47a82dfa0a234b269f21049a80d4c1d7c49d Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 29 Jan 2025 23:41:25 -0800 Subject: [PATCH 5/6] log probs - mark pytests as xfail for unsupported providers + add support for together (#883) # What does this PR do? 1) As per @mattf's suggestion, we want to mark the pytest as xfail for providers that do not support the functionality. In this diff, we xfail the logProbs inference tests for providers who does not support log probs. ( log probs is only supported by together, fireworks and vllm) 2) Added logProbs support for together according to their developer [doc](https://docs.together.ai/docs/logprobs). ## Test Plan 1) Together & Fireworks ``` export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/together/run.yaml /opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py ``` ``` tests/client-sdk/inference/test_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of planets in our solar system?-Earth] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of the planets that have rings around them?-Saturn] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64_url[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED ========================================================================================== 15 passed, 2 warnings in 19.46s =========================================================================================== ``` ``` export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/fireworks/run.yaml /opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py ``` All tests passed 2) Ollama - LogProbs tests are marked as xfailed. 
``` tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet) tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet) ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../remote/inference/together/together.py | 16 +++++++-- .../utils/inference/openai_compat.py | 30 ++++++++++++++-- tests/client-sdk/inference/test_inference.py | 36 ++++++++++++++----- 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 8f679cb56..605b3ce97 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -161,7 +161,10 @@ class TogetherInferenceAdapter( yield chunk def _build_options( - self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat + self, + sampling_params: Optional[SamplingParams], + logprobs: Optional[LogProbConfig], + fmt: ResponseFormat, ) -> dict: options = get_sampling_options(sampling_params) if fmt: @@ -175,6 +178,13 @@ class TogetherInferenceAdapter( else: raise ValueError(f"Unknown response format {fmt.type}") + if logprobs and logprobs.top_k: + if logprobs.top_k != 1: + raise ValueError( + f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided", + ) + options["logprobs"] = 1 + return options async def chat_completion( @@ -263,7 +273,9 @@ class TogetherInferenceAdapter( "model": request.model, **input_dict, "stream": request.stream, - **self._build_options(request.sampling_params, request.response_format), + **self._build_options( + request.sampling_params, request.logprobs, request.response_format + ), } async def embeddings( diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 6c93f49c0..a0fb23c97 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional, Union from llama_models.datatypes import ( GreedySamplingStrategy, @@ -121,7 +121,31 @@ def convert_openai_completion_logprobs( ) -> Optional[List[TokenLogProbs]]: if not logprobs: return None - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] + if hasattr(logprobs, "top_logprobs"): + return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] + + # Together supports logprobs with top_k=1 only. This means for each token position, + # they return only the logprobs for the selected token (vs. the top n most likely tokens). + # Here we construct the response by matching the selected token with the logprobs. 
+ if logprobs.tokens and logprobs.token_logprobs: + return [ + TokenLogProbs(logprobs_by_token={token: token_lp}) + for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs) + ] + return None + + +def convert_openai_completion_logprobs_stream( + text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]] +): + if logprobs is None: + return None + if isinstance(logprobs, float): + # Adapt response from Together CompletionChoicesChunk + return [TokenLogProbs(logprobs_by_token={text: logprobs})] + if hasattr(logprobs, "top_logprobs"): + return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] + return None def process_completion_response( @@ -188,7 +212,7 @@ async def process_completion_stream_response( yield CompletionResponseStreamChunk( delta=text, stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs(choice.logprobs), + logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), ) if finish_reason: if finish_reason in ["stop", "eos", "eos_token"]: diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index 8ca11521c..6dff1be24 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = { "remote::fireworks": "json", } +PROVIDER_LOGPROBS_TOP_K = set( + { + "remote::together", + "remote::fireworks", + # "remote:vllm" + } +) + @pytest.fixture(scope="session") def provider_tool_format(inference_provider_type): @@ -83,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id): assert "blue" in "".join(streamed_content).lower().strip() -@pytest.mark.skip("Most inference providers don't support log probs yet") -def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id): +def test_completion_log_probs_non_streaming( + llama_stack_client, text_model_id, inference_provider_type +): + if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: + pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") + response = llama_stack_client.inference.completion( content="Complete the sentence: Micheael Jordan is born in ", stream=False, @@ -93,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id): "max_tokens": 5, }, logprobs={ - "top_k": 3, + "top_k": 1, }, ) assert response.logprobs, "Logprobs should not be empty" - assert 1 <= len(response.logprobs) <= 5 - assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs) + assert ( + 1 <= len(response.logprobs) <= 5 + ) # each token has 1 logprob and here max_tokens=5 + assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs) -@pytest.mark.skip("Most inference providers don't support log probs yet") -def test_completion_log_probs_streaming(llama_stack_client, text_model_id): +def test_completion_log_probs_streaming( + llama_stack_client, text_model_id, inference_provider_type +): + if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: + pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") + response = llama_stack_client.inference.completion( content="Complete the sentence: Micheael Jordan is born in ", stream=True, @@ -111,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id): "max_tokens": 5, }, logprobs={ - "top_k": 3, + "top_k": 1, }, ) streamed_content = [chunk for chunk in response] @@ -119,7 +137,7 @@ def 
test_completion_log_probs_streaming(llama_stack_client, text_model_id): if chunk.delta: # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" assert all( - len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs + len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs ) else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" From 7fe25927954d0ac00901091e3a01d06fc0ef09c9 Mon Sep 17 00:00:00 2001 From: snova-edwardm Date: Thu, 30 Jan 2025 09:24:46 -0800 Subject: [PATCH 6/6] SambaNova supports Llama 3.3 (#905) # What does this PR do? - Fix typo - Support Llama 3.3 70B ## Test Plan Run the following scripts and obtain the test results Script ``` pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming --env SAMBANOVA_API_KEY={API_KEY} ``` Result ``` llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-sambanova] PASSED =========================================== 1 passed, 1 warning in 1.26s ============================================ ``` Script ``` pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming --env SAMBANOVA_API_KEY={API_KEY} ``` Result ``` llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-sambanova] PASSED =========================================== 1 passed, 1 warning in 0.52s ============================================ ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [N] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [Y] Ran pre-commit to handle lint / formatting issues. - [Y] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [Y] Updated relevant documentation. - [N] Wrote necessary unit or integration tests. --- docs/source/distributions/self_hosted_distro/sambanova.md | 2 +- .../providers/remote/inference/sambanova/sambanova.py | 4 ++++ llama_stack/templates/sambanova/run.yaml | 5 +++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index 199279990..6dbc0e94e 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -44,7 +44,7 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaBova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://cloud.sambanova.ai/). 
## Running Llama Stack with SambaNova diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index da446567a..b601d4b3f 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -57,6 +57,10 @@ MODEL_ALIASES = [ "Meta-Llama-3.2-3B-Instruct", CoreModelId.llama3_2_3b_instruct.value, ), + build_model_alias( + "Meta-Llama-3.3-70B-Instruct", + CoreModelId.llama3_3_70b_instruct.value, + ), build_model_alias( "Llama-3.2-11B-Vision-Instruct", CoreModelId.llama3_2_11b_vision_instruct.value, diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index c63b5d217..36f07dc73 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -116,6 +116,11 @@ models: provider_id: sambanova provider_model_id: Meta-Llama-3.2-3B-Instruct model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: sambanova + provider_model_id: Meta-Llama-3.3-70B-Instruct + model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-11B-Vision-Instruct provider_id: sambanova