diff --git a/.github/workflows/README.md b/.github/workflows/README.md
index 2e0df58b8..059bb873f 100644
--- a/.github/workflows/README.md
+++ b/.github/workflows/README.md
@@ -5,6 +5,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
 | Name | File | Purpose |
 | ---- | ---- | ------- |
 | Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
+| API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. |
 | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
 | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
 | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml
new file mode 100644
index 000000000..2433b0203
--- /dev/null
+++ b/.github/workflows/conformance.yml
@@ -0,0 +1,57 @@
+# API Conformance Tests
+# This workflow ensures that API changes maintain backward compatibility and don't break existing integrations
+# It runs schema validation and OpenAPI diff checks to catch breaking changes early
+
+name: API Conformance Tests
+
+run-name: Run the API Conformance test suite on the changes.
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+    types: [opened, synchronize, reopened]
+    paths:
+      - 'llama_stack/**'
+      - '!llama_stack/ui/**'
+      - 'tests/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - '.github/workflows/conformance.yml' # This workflow itself
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
+  # Cancel in-progress runs when new commits are pushed to avoid wasting CI resources
+  cancel-in-progress: true
+
+jobs:
+  # Job to check if API schema changes maintain backward compatibility
+  check-schema-compatibility:
+    runs-on: ubuntu-latest
+    steps:
+      # Using specific version 4.1.7 because 5.0.0 fails when trying to run this locally using `act`
+      # This ensures consistent behavior between local testing and CI
+      - name: Checkout PR Code
+        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+
+      # Checkout the base branch to compare against (usually main)
+      # This allows us to diff the current changes against the previous state
+      - name: Checkout Base Branch
+        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+        with:
+          ref: ${{ github.event.pull_request.base.ref }}
+          path: 'base'
+
+      # Install oasdiff: https://github.com/oasdiff/oasdiff, a tool for detecting breaking changes in OpenAPI specs.
+      - name: Install oasdiff
+        run: |
+          curl -fsSL https://raw.githubusercontent.com/oasdiff/oasdiff/main/install.sh | sh
+
+      # Run oasdiff to detect breaking changes in the API specification
+      # This step will fail if incompatible changes are detected, preventing breaking changes from being merged
+      - name: Run OpenAPI Breaking Change Diff
+        run: |
+          oasdiff breaking --fail-on ERR base/docs/_static/llama-stack-spec.yaml docs/_static/llama-stack-spec.yaml --match-path '^/v1/openai/v1' \
+            --match-path '^/v1/vector-io' \
+            --match-path '^/v1/vector-dbs'
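The gate above can also be reproduced outside of CI. The following is a minimal, hypothetical helper (not part of this PR) that shells out to `oasdiff` with the same arguments the workflow passes; it assumes `oasdiff` is already installed and that a checkout of the base branch sits under `base/`, mirroring the `Checkout Base Branch` step.

```python
# Hypothetical local runner for the same oasdiff check the workflow performs.
# Assumes `oasdiff` is on PATH and `base/` holds a checkout of the base branch.
import subprocess
import sys


def run_breaking_check(base_spec: str, head_spec: str, match_paths: list[str]) -> int:
    cmd = ["oasdiff", "breaking", "--fail-on", "ERR", base_spec, head_spec]
    for pattern in match_paths:
        cmd += ["--match-path", pattern]
    # oasdiff exits non-zero when it finds breaking changes at or above the
    # --fail-on level, which is exactly what makes the CI job fail.
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(
        run_breaking_check(
            "base/docs/_static/llama-stack-spec.yaml",
            "docs/_static/llama-stack-spec.yaml",
            ["^/v1/openai/v1", "^/v1/vector-io", "^/v1/vector-dbs"],
        )
    )
```

Scoping the comparison with `--match-path` keeps the check focused on the OpenAI-compatible and vector-store surfaces rather than the entire spec.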
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 7a95fd089..4176f85a6 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -248,7 +248,7 @@ Available Models:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="groq",
-            pip_packages=["litellm"],
+            pip_packages=["litellm", "openai"],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index fd7212de4..888953af0 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -4,30 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-from typing import Any
-
-from openai import AsyncOpenAI
-
-from llama_stack.apis.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChoiceDelta,
-    OpenAIChunkChoice,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    OpenAISystemMessageParam,
-)
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
 
-class GroqInferenceAdapter(LiteLLMOpenAIMixin):
+class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     _config: GroqConfig
 
     def __init__(self, config: GroqConfig):
@@ -40,122 +25,14 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
         )
         self.config = config
 
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        return f"{self.config.url}/openai/v1"
+
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(
-            base_url=f"{self.config.url}/openai/v1",
-            api_key=self.get_api_key(),
-        )
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Groq does not support json_schema response format, so we need to convert it to json_object
-        if response_format and response_format.type == "json_schema":
-            response_format.type = "json_object"
-            schema = response_format.json_schema.get("schema", {})
-            response_format.json_schema = None
-            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
-            if messages and messages[0].role == "system":
-                messages[0].content = messages[0].content + json_instructions
-            else:
-                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
-
-        # Groq returns a 400 error if tools are provided but none are called
-        # So, set tool_choice to "required" to attempt to force a call
-        if tools and (not tool_choice or tool_choice == "auto"):
-            tool_choice = "required"
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        # Groq does not support streaming requests that set response_format
-        fake_stream = False
-        if stream and response_format:
-            params["stream"] = False
-            fake_stream = True
-
-        response = await self._get_openai_client().chat.completions.create(**params)
-
-        if fake_stream:
-            chunk_choices = []
-            for choice in response.choices:
-                delta = OpenAIChoiceDelta(
-                    content=choice.message.content,
-                    role=choice.message.role,
-                    tool_calls=choice.message.tool_calls,
-                )
-                chunk_choice = OpenAIChunkChoice(
-                    delta=delta,
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=None,
-                )
-                chunk_choices.append(chunk_choice)
-            chunk = OpenAIChatCompletionChunk(
-                id=response.id,
-                choices=chunk_choices,
-                object="chat.completion.chunk",
-                created=response.created,
-                model=response.model,
-            )
-
-            async def _fake_stream_generator():
-                yield chunk
-
-            return _fake_stream_generator()
-        else:
-            return response
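The large deletion above works because the adapter now only has to supply two small hooks, `get_base_url()` and `get_api_key()`, and lets an OpenAI-compatible mixin own the client plumbing that `_get_openai_client()` and the hand-rolled `openai_chat_completion()` used to carry. The sketch below illustrates that shape only; `ExampleOpenAIMixin` is a made-up stand-in, not the real `OpenAIMixin` from `llama_stack.providers.utils.inference.openai_mixin`.

```python
# Illustrative stand-in for the hook pattern; NOT the real OpenAIMixin.
from openai import AsyncOpenAI


class ExampleOpenAIMixin:
    """Builds an OpenAI-compatible client from two overridable hooks."""

    def get_api_key(self) -> str:
        raise NotImplementedError  # adapters supply this (Groq delegates to LiteLLMOpenAIMixin)

    def get_base_url(self) -> str:
        raise NotImplementedError  # Groq returns f"{self.config.url}/openai/v1"

    @property
    def client(self) -> AsyncOpenAI:
        # The real mixin may cache the client or add headers; a fresh client
        # keeps the sketch minimal.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())
```

Building the client directly on the `openai` SDK is presumably also why `openai` joins `litellm` in the provider registry's `pip_packages` above.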
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
index b371cf907..f4b3201e9 100644
--- a/tests/unit/providers/inference/test_inference_client_caching.py
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
         with request_provider_data_context(
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
         ):
-            openai_client = inference_adapter._get_openai_client()
-            assert openai_client.api_key == api_key
+            assert inference_adapter.client.api_key == api_key
 
 
 def test_openai_provider_openai_client_caching():
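One detail worth noting from the adapter change: `GroqInferenceAdapter` now lists both `OpenAIMixin` and `LiteLLMOpenAIMixin` as bases and pins `get_api_key = LiteLLMOpenAIMixin.get_api_key` explicitly. Assuming both mixins define `get_api_key` (an assumption, since only the Groq file is shown in this diff), Python's method resolution order would otherwise resolve the call to the first base class. A tiny self-contained illustration with made-up classes:

```python
# Minimal MRO illustration with made-up classes A and B (not the real mixins).
class A:
    def get_api_key(self) -> str:
        return "key-from-A"


class B:
    def get_api_key(self) -> str:
        return "key-from-B"


class Adapter(A, B):
    # With `class Adapter(A, B)` the MRO picks A.get_api_key; re-exporting B's
    # version mirrors what the Groq adapter does with LiteLLMOpenAIMixin.get_api_key.
    get_api_key = B.get_api_key


assert Adapter().get_api_key() == "key-from-B"
```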