From b6500974eca169ed053a7b95408ac756c160c004 Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Tue, 3 Dec 2024 20:11:19 -0800 Subject: [PATCH 1/4] removed assertion in ollama.py and fixed typo in the readme (#563) # What does this PR do? 1. removed [incorrect assertion](https://github.com/meta-llama/llama-stack/blob/435f34b05e84f1747b28570234f25878cf0b31c4/llama_stack/providers/remote/inference/ollama/ollama.py#L183) in ollama.py 2. fixed a typo in [this line](https://github.com/meta-llama/llama-stack/blob/435f34b05e84f1747b28570234f25878cf0b31c4/docs/source/distributions/importing_as_library.md?plain=1#L24), as `model=` should be `model_id=` . - [x] Addresses issue ([#issue562](https://github.com/meta-llama/llama-stack/issues/562)) ## Test Plan tested with code: ```python import asyncio import os # pip install aiosqlite ollama faiss from llama_stack_client.lib.direct.direct import LlamaStackDirectClient from llama_stack_client.types import SystemMessage, UserMessage async def main(): os.environ["INFERENCE_MODEL"] = "meta-llama/Llama-3.2-1B-Instruct" client = await LlamaStackDirectClient.from_template("ollama") await client.initialize() response = await client.models.list() print(response) model_name = response[0].identifier response = await client.inference.chat_completion( messages=[ SystemMessage(content="You are a friendly assistant.", role="system"), UserMessage( content="hello world, write me a 2 sentence poem about the moon", role="user", ), ], model_id=model_name, stream=False, ) print("\nChat completion response:") print(response, type(response)) asyncio.run(main()) ``` OUTPUT: ``` python test.py Using template ollama with config: apis: - agents - inference - memory - safety - telemetry conda_env: ollama datasets: [] docker_image: null eval_tasks: [] image_name: ollama memory_banks: [] metadata_store: db_path: /Users/kaiwu/.llama/distributions/ollama/registry.db namespace: null type: sqlite models: - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: ollama provider_model_id: null providers: agents: - config: persistence_store: db_path: /Users/kaiwu/.llama/distributions/ollama/agents_store.db namespace: null type: sqlite provider_id: meta-reference provider_type: inline::meta-reference inference: - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama memory: - config: kvstore: db_path: /Users/kaiwu/.llama/distributions/ollama/faiss_store.db namespace: null type: sqlite provider_id: faiss provider_type: inline::faiss safety: - config: {} provider_id: llama-guard provider_type: inline::llama-guard telemetry: - config: {} provider_id: meta-reference provider_type: inline::meta-reference scoring_fns: [] shields: [] version: '2' [Model(identifier='meta-llama/Llama-3.2-1B-Instruct', provider_resource_id='llama3.2:1b-instruct-fp16', provider_id='ollama', type='model', metadata={})] Chat completion response: completion_message=CompletionMessage(role='assistant', content='Here is a short poem about the moon:\n\nThe moon glows bright in the midnight sky,\nA silver crescent shining, catching the eye.', stop_reason=, tool_calls=[]) logprobs=None ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. 
- [ ] Wrote necessary unit or integration tests. --- docs/source/distributions/importing_as_library.md | 2 +- llama_stack/providers/remote/inference/ollama/ollama.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index 815660fd4..7e15062df 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -21,7 +21,7 @@ print(response) ```python response = await client.inference.chat_completion( messages=[UserMessage(content="What is the capital of France?", role="user")], - model="Llama3.1-8B-Instruct", + model_id="Llama3.1-8B-Instruct", stream=False, ) print("\nChat completion response:") diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 74c0b8601..f89629afc 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -180,7 +180,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) r = await self.client.generate(**params) - assert isinstance(r, dict) choice = OpenAICompatCompletionChoice( finish_reason=r["done_reason"] if r["done"] else None, From 64c6df8392c8ceea321375bca12af2b025f6693e Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 4 Dec 2024 00:15:32 -0500 Subject: [PATCH 2/4] Cerebras Inference Integration (#265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adding Cerebras Inference as an API provider. ## Testing ### Conda ``` $ llama stack build --template cerebras --image-type conda $ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml ... Listening on ['::', '0.0.0.0']:5000 INFO: Started server process [12443] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) ``` ### Chat Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "messages": [ { "role": "user", "content": "What is the temperature in Seattle right now?" 
} ], "stream": false, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 100 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "completion_message": { "role": "assistant", "content": "", "stop_reason": "end_of_message", "tool_calls": [ { "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678", "tool_name": "getTemperature", "arguments": { "location": "Seattle" } } ] }, "logprobs": null } ``` #### Streaming Response ``` data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: 
{"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}} data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}} ``` ### Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "content": "1,2,3,", "stream": true, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 10 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "content": "4,5,6,7,8,", "stop_reason": "out_of_tokens", "logprobs": null } ``` #### Streaming Response ``` data: {"delta":"4","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"5","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"6","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"7","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"8","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null} ``` ### Pre-Commit Checks ``` trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Passed Format files with ยตfmt...................................................Passed ``` ### Testing with `test_inference.py` ``` $ export CEREBRAS_API_KEY= $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================== test session starts =================================================== platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12 cachedir: .pytest_cache rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack configfile: pyproject.toml plugins: anyio-4.6.2.post1, asyncio-0.24.0 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 128 items / 120 deselected / 8 selected llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers inner-inference => cerebras models => __routing_table__ inference => __autorouted__ inspect => __builtin__ Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED ================================ 6 passed, 2 skipped, 120 deselected, 6 warnings in 3.95s ================================= ``` I ran `python llama_stack/scripts/distro_codegen.py` to run codegen. 
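As an additional sanity check, the same chat-completion flow can in principle be driven from Python with the library client used in the previous patch. This is a minimal sketch, not part of the patch itself: it assumes the new `cerebras` template resolves through `LlamaStackDirectClient.from_template` and that `CEREBRAS_API_KEY` is exported; the client calls (`initialize`, `models.list`, `inference.chat_completion` with `model_id=`) mirror the test script from PATCH 1/4.

```python
import asyncio
import os

# pip install llama-stack-client (plus the cerebras template's dependencies)
from llama_stack_client.lib.direct.direct import LlamaStackDirectClient
from llama_stack_client.types import UserMessage


async def main():
    # The adapter reads the key from CEREBRAS_API_KEY (see config.py in this patch).
    assert os.environ.get("CEREBRAS_API_KEY"), "export CEREBRAS_API_KEY first"

    client = await LlamaStackDirectClient.from_template("cerebras")
    await client.initialize()

    # The template registers the llama3.1-8b / llama3.1-70b aliases by default.
    print(await client.models.list())

    response = await client.inference.chat_completion(
        messages=[
            UserMessage(
                content="What is the temperature in Seattle right now?",
                role="user",
            ),
        ],
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        stream=False,
    )
    print(response)


asyncio.run(main())
```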
--- README.md | 2 + distributions/cerebras/build.yaml | 1 + distributions/cerebras/compose.yaml | 16 + distributions/cerebras/run.yaml | 1 + distributions/dependencies.json | 380 ++++++++++-------- docs/source/distributions/building_distro.md | 356 ++++++++++------ .../self_hosted_distro/cerebras.md | 61 +++ docs/source/index.md | 1 + llama_stack/providers/registry/inference.py | 11 + .../remote/inference/cerebras/__init__.py | 21 + .../remote/inference/cerebras/cerebras.py | 191 +++++++++ .../remote/inference/cerebras/config.py | 32 ++ .../providers/tests/inference/fixtures.py | 17 + .../tests/inference/test_text_inference.py | 2 + llama_stack/templates/cerebras/__init__.py | 7 + llama_stack/templates/cerebras/build.yaml | 17 + llama_stack/templates/cerebras/cerebras.py | 71 ++++ .../templates/cerebras/doc_template.md | 60 +++ llama_stack/templates/cerebras/run.yaml | 63 +++ 19 files changed, 1018 insertions(+), 292 deletions(-) create mode 120000 distributions/cerebras/build.yaml create mode 100644 distributions/cerebras/compose.yaml create mode 120000 distributions/cerebras/run.yaml create mode 100644 docs/source/distributions/self_hosted_distro/cerebras.md create mode 100644 llama_stack/providers/remote/inference/cerebras/__init__.py create mode 100644 llama_stack/providers/remote/inference/cerebras/cerebras.py create mode 100644 llama_stack/providers/remote/inference/cerebras/config.py create mode 100644 llama_stack/templates/cerebras/__init__.py create mode 100644 llama_stack/templates/cerebras/build.yaml create mode 100644 llama_stack/templates/cerebras/cerebras.py create mode 100644 llama_stack/templates/cerebras/doc_template.md create mode 100644 llama_stack/templates/cerebras/run.yaml diff --git a/README.md b/README.md index 8e57292c3..0dfb1306d 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Cerebras | Single Node | | :heavy_check_mark: | | | | | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | @@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well |:----------------: |:------------------------------------------: |:-----------------------: | | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | +| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | 
[Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | diff --git a/distributions/cerebras/build.yaml b/distributions/cerebras/build.yaml new file mode 120000 index 000000000..bccbbcf60 --- /dev/null +++ b/distributions/cerebras/build.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/build.yaml \ No newline at end of file diff --git a/distributions/cerebras/compose.yaml b/distributions/cerebras/compose.yaml new file mode 100644 index 000000000..f2e9a6f42 --- /dev/null +++ b/distributions/cerebras/compose.yaml @@ -0,0 +1,16 @@ +services: + llamastack: + image: llamastack/distribution-cerebras + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-cerebras.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/cerebras/run.yaml b/distributions/cerebras/run.yaml new file mode 120000 index 000000000..9f9d20b4b --- /dev/null +++ b/distributions/cerebras/run.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/run.yaml \ No newline at end of file diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 36426e862..80468cc73 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -1,4 +1,152 @@ { + "tgi": [ + "aiohttp", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "huggingface_hub", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "remote-vllm": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "vllm-gpu": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "vllm", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-quantized-gpu": [ + "accelerate", + 
"aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fbgemm-gpu", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchao==0.5.0", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-gpu": [ + "accelerate", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], "hf-serverless": [ "aiohttp", "aiosqlite", @@ -54,88 +202,7 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "vllm-gpu": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "vllm", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "remote-vllm": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "fireworks": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "tgi": [ + "ollama": [ "aiohttp", "aiosqlite", "blobfile", @@ -145,10 +212,10 @@ "fastapi", "fire", "httpx", - "huggingface_hub", "matplotlib", "nltk", "numpy", + "ollama", "pandas", "pillow", "psycopg2-binary", @@ -190,100 +257,6 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "meta-reference-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "meta-reference-quantized-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fbgemm-gpu", - "fire", - 
"httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchao==0.5.0", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "ollama": [ - "aiohttp", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "ollama", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], "hf-endpoint": [ "aiohttp", "aiosqlite", @@ -311,5 +284,58 @@ "uvicorn", "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "fireworks": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "fireworks-ai", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "cerebras": [ + "aiosqlite", + "blobfile", + "cerebras_cloud_sdk", + "chardet", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index a45d07ebf..67d39159c 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -66,121 +66,247 @@ llama stack build --list-templates ``` ``` -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| Template Name | Providers | Description | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM | -| | "inference": "remote::hf::serverless", | inference. | -| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. 
| -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| together | { | Use Together.ai for running LLM inference | -| | "inference": "remote::together", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| fireworks | { | Use Fireworks.ai for running LLM inference | -| | "inference": "remote::fireworks", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| databricks | { | Use Databricks for running LLM inference | -| | "inference": "remote::databricks", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| vllm | { | Like local, but use vLLM for running LLM inference | -| | "inference": "vllm", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| tgi | { | Use TGI for running LLM inference | -| | "inference": "remote::tgi", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| bedrock | { | Use Amazon Bedrock APIs. 
| -| | "inference": "remote::bedrock", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference-quantized", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| ollama | { | Use ollama for running LLM inference | -| | "inference": "remote::ollama", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. | -| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. 
| -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| Template Name | Providers | Description | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| tgi | { | Use (an external) TGI server for running LLM inference | +| | "inference": [ | | +| | "remote::tgi" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| remote-vllm | { | Use (an external) vLLM server for running LLM inference | +| | "inference": [ | | +| | "remote::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference | +| | "inference": [ | | +| | "inline::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference-quantized" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-gpu | { | Use Meta Reference for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | 
"inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::serverless" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| together | { | Use Together.AI for running LLM inference | +| | "inference": [ | | +| | "remote::together" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| ollama | { | Use (an external) Ollama server for running LLM inference | +| | "inference": [ | | +| | "remote::ollama" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| bedrock | { | Use AWS Bedrock for running LLM inference and safety | +| | "inference": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::endpoint" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| fireworks | { | Use Fireworks.AI for running LLM inference | +| | "inference": [ | | +| | "remote::fireworks" | | +| | ], | | +| | "memory": [ | 
| +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| cerebras | { | Use Cerebras for running LLM inference | +| | "inference": [ | | +| | "remote::cerebras" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "memory": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ ``` You may then pick a template to build your distribution with providers fitted to your liking. diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md new file mode 100644 index 000000000..08b35809a --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -0,0 +1,61 @@ +# Cerebras Distribution + +The `llamastack/distribution-cerebras` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| inference | `remote::cerebras` | +| memory | `inline::meta-reference` | +| safety | `inline::llama-guard` | +| telemetry | `inline::meta-reference` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)` +- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)` + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. 
+ +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-cerebras \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/docs/source/index.md b/docs/source/index.md index 291237843..abfaf51b4 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere | **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | Y | Y | Y | Y | Y | +| Cerebras | Single Node | | Y | | | | | Fireworks | Hosted | Y | Y | Y | | | | AWS Bedrock | Hosted | | Y | | Y | | | Together | Hosted | Y | Y | | Y | | diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index c8d061f6c..13d463ad8 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.sample.SampleConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/cerebras/__init__.py b/llama_stack/providers/remote/inference/cerebras/__init__.py new file mode 100644 index 000000000..a24bb2c70 --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import CerebrasImplConfig + + +async def get_adapter_impl(config: CerebrasImplConfig, _deps): + from .cerebras import CerebrasInferenceAdapter + + assert isinstance( + config, CerebrasImplConfig + ), f"Unexpected config type: {type(config)}" + + impl = CerebrasInferenceAdapter(config) + + await impl.initialize() + + return impl diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py new file mode 100644 index 000000000..65022f85e --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
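+
+# Overview: CerebrasInferenceAdapter serves both completion() and
+# chat_completion() through the async Cerebras Cloud SDK completions API.
+# Chat messages are rendered into a single prompt in _get_params(), and
+# `top_k` sampling is rejected because Cerebras does not support it.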
+ +from typing import AsyncGenerator + +from cerebras.cloud.sdk import AsyncCerebras + +from llama_models.llama3.api.chat_format import ChatFormat + +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_stack.apis.inference import * # noqa: F403 + +from llama_models.datatypes import CoreModelId + +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, + completion_request_to_prompt, +) + +from .config import CerebrasImplConfig + + +model_aliases = [ + build_model_alias( + "llama3.1-8b", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3.1-70b", + CoreModelId.llama3_1_70b_instruct.value, + ), +] + + +class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): + def __init__(self, config: CerebrasImplConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=model_aliases, + ) + self.config = config + self.formatter = ChatFormat(Tokenizer.get_instance()) + + self.client = AsyncCerebras( + base_url=self.config.base_url, api_key=self.config.api_key + ) + + async def initialize(self) -> None: + return + + async def shutdown(self) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion( + request, + ) + else: + return await self._nonstream_completion(request) + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_completion_response(r, self.formatter) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + 
response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_chat_completion_response(r, self.formatter) + + async def _stream_chat_completion( + self, request: CompletionRequest + ) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk + + def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + if request.sampling_params and request.sampling_params.top_k: + raise ValueError("`top_k` not supported by Cerebras") + + prompt = "" + if type(request) == ChatCompletionRequest: + prompt = chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ) + elif type(request) == CompletionRequest: + prompt = completion_request_to_prompt(request, self.formatter) + else: + raise ValueError(f"Unknown request type {type(request)}") + + return { + "model": request.model, + "prompt": prompt, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + } + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py new file mode 100644 index 000000000..9bae6ca4d --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +from typing import Any, Dict, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + +DEFAULT_BASE_URL = "https://api.cerebras.ai" + + +@json_schema_type +class CerebrasImplConfig(BaseModel): + base_url: str = Field( + default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), + description="Base URL for the Cerebras API", + ) + api_key: Optional[str] = Field( + default=os.environ.get("CEREBRAS_API_KEY"), + description="Cerebras API Key", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "base_url": DEFAULT_BASE_URL, + "api_key": "${env.CEREBRAS_API_KEY}", + } diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index a427eef12..21e122149 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( ) from llama_stack.providers.remote.inference.bedrock import BedrockConfig +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig @@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_cerebras() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig( + api_key=get_env_or_fail("CEREBRAS_API_KEY"), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( @@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [ "vllm_remote", "remote", "bedrock", + "cerebras", "nvidia", "tgi", ] diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 9e5c67375..aa2f0b413 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -94,6 +94,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip("Other inference providers don't support completion() yet") @@ -139,6 +140,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip( "Other inference providers don't support structured output in completions yet" diff --git a/llama_stack/templates/cerebras/__init__.py b/llama_stack/templates/cerebras/__init__.py new file mode 100644 index 000000000..9f9929b52 --- /dev/null +++ b/llama_stack/templates/cerebras/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .cerebras import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml new file mode 100644 index 000000000..a1fe93099 --- /dev/null +++ b/llama_stack/templates/cerebras/build.yaml @@ -0,0 +1,17 @@ +version: '2' +name: cerebras +distribution_spec: + description: Use Cerebras for running LLM inference + docker_image: null + providers: + inference: + - remote::cerebras + safety: + - inline::llama-guard + memory: + - inline::meta-reference + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py new file mode 100644 index 000000000..58e05adf8 --- /dev/null +++ b/llama_stack/templates/cerebras/cerebras.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig +from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases + +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::cerebras"], + "safety": ["inline::llama-guard"], + "memory": ["inline::meta-reference"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + } + + inference_provider = Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig.sample_run_config(), + ) + + core_model_to_hf_repo = { + m.descriptor(): m.huggingface_repo for m in all_registered_models() + } + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model], + provider_model_id=m.provider_model_id, + ) + for m in model_aliases + ] + + return DistributionTemplate( + name="cerebras", + distro_type="self_hosted", + description="Use Cerebras for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "CEREBRAS_API_KEY": ( + "", + "Cerebras API Key", + ), + }, + ) diff --git a/llama_stack/templates/cerebras/doc_template.md b/llama_stack/templates/cerebras/doc_template.md new file mode 100644 index 000000000..77fc6f478 --- /dev/null +++ b/llama_stack/templates/cerebras/doc_template.md @@ -0,0 +1,60 @@ +# Cerebras Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. 
+ +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} ({{ model.provider_model_id }})` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml new file mode 100644 index 000000000..0b41f5b76 --- /dev/null +++ b/llama_stack/templates/cerebras/run.yaml @@ -0,0 +1,63 @@ +version: '2' +image_name: cerebras +docker_image: null +conda_env: cerebras +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: cerebras + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + memory: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: null + provider_model_id: llama3.1-8b +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: null + provider_model_id: llama3.1-70b +shields: +- params: null + shield_id: meta-llama/Llama-Guard-3-8B + provider_id: null + provider_shield_id: null +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] From caf1dac1145193846c0c77a93af3c4669dc5575d Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Tue, 3 Dec 2024 21:18:30 -0800 Subject: [PATCH 3/4] unregister API for dataset (#507) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 
1) Implement `unregister_dataset(dataset_id)` API in both llama stack routing table and providers: It removes {dataset_id -> Dataset} mapping from routing table and removes the dataset_id references in provider as well (ex. for huggingface, we use a KV store to store the dataset id => dataset. we delete it during unregistering as well) 2) expose the datasets/unregister_dataset api endpoint ## Test Plan **Unit test:** ` pytest llama_stack/providers/tests/datasetio/test_datasetio.py -m "huggingface" -v -s --tb=short --disable-warnings ` **Test on endpoint:** tested llama stack using an ollama distribution template: 1) start an ollama server 2) Start a llama stack server with the default ollama distribution config + dataset/datasetsio APIs + datasetio provider ``` ---- .../ollama-run.yaml ... apis: - agents - inference - memory - safety - telemetry - datasetio - datasets providers: datasetio: - provider_id: localfs provider_type: inline::localfs config: {} ... ``` saw that the new API showed up in startup script ``` Serving API datasets GET /alpha/datasets/get GET /alpha/datasets/list POST /alpha/datasets/register POST /alpha/datasets/unregister ``` 3) query `/alpha/datasets/unregister` through curl (since we have not implemented unregister api in llama stack client) ``` (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ โ”ƒ identifier โ”ƒ provider_id โ”ƒ metadata โ”ƒ type โ”ƒ โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ โ”‚ sixian โ”‚ localfs โ”‚ {} โ”‚ dataset โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets register --dataset-id sixian2 --url https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst --schema {} (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ โ”ƒ identifier โ”ƒ provider_id โ”ƒ metadata โ”ƒ type โ”ƒ โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ โ”‚ sixian โ”‚ localfs โ”‚ {} โ”‚ dataset โ”‚ โ”‚ sixian2 โ”‚ localfs โ”‚ {} โ”‚ dataset โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ โ”ƒ identifier โ”ƒ provider_id โ”ƒ metadata โ”ƒ type โ”ƒ โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ โ”‚ 
sixian2 โ”‚ localfs โ”‚ {} โ”‚ dataset โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ (base) sxyi@sxyi-mbp llama-stack % curl http://localhost:5001/alpha/datasets/unregister \ -H "Content-Type: application/json" \ -d '{"dataset_id": "sixian2"}' null% (base) sxyi@sxyi-mbp llama-stack % llama-stack-client datasets list ``` ## Sources ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- docs/resources/llama-stack-spec.html | 50 +++++++++++++++++++ docs/resources/llama-stack-spec.yaml | 33 ++++++++++++ llama_stack/apis/datasets/client.py | 15 ++++++ llama_stack/apis/datasets/datasets.py | 6 +++ .../distribution/routers/routing_tables.py | 8 +++ llama_stack/providers/datatypes.py | 2 + .../inline/datasetio/localfs/datasetio.py | 3 ++ .../datasetio/huggingface/huggingface.py | 5 ++ .../tests/datasetio/test_datasetio.py | 12 +++++ 9 files changed, 134 insertions(+) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 090253804..4f220ea1e 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2291,6 +2291,39 @@ "required": true } } + }, + "/alpha/datasets/unregister": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Datasets" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnregisterDatasetRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -7917,6 +7950,18 @@ "required": [ "model_id" ] + }, + "UnregisterDatasetRequest": { + "type": "object", + "properties": { + "dataset_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "dataset_id" + ] } }, "responses": {} @@ -8529,6 +8574,10 @@ "name": "UnregisterModelRequest", "description": "" }, + { + "name": "UnregisterDatasetRequest", + "description": "" + }, { "name": "UnstructuredLogEvent", "description": "" @@ -8718,6 +8767,7 @@ "URL", "UnregisterMemoryBankRequest", "UnregisterModelRequest", + "UnregisterDatasetRequest", "UnstructuredLogEvent", "UserMessage", "VectorMemoryBank", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 8ffd9fdef..6564ddf3f 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -3253,6 +3253,14 @@ components: required: - model_id type: object + UnregisterDatasetRequest: + additionalProperties: false + properties: + dataset_id: + type: string + required: + - dataset_id + type: object UnstructuredLogEvent: additionalProperties: false properties: @@ -3789,6 +3797,27 @@ paths: description: OK tags: - Datasets + /alpha/datasets/unregister: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the 
API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UnregisterDatasetRequest' + required: true + responses: + '200': + description: OK + tags: + - Datasets /alpha/eval-tasks/get: get: parameters: @@ -5242,6 +5271,9 @@ tags: - description: name: UnregisterModelRequest +- description: + name: UnregisterDatasetRequest - description: name: UnstructuredLogEvent @@ -5418,6 +5450,7 @@ x-tagGroups: - URL - UnregisterMemoryBankRequest - UnregisterModelRequest + - UnregisterDatasetRequest - UnstructuredLogEvent - UserMessage - VectorMemoryBank diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py index 9e5891e74..c379a49fb 100644 --- a/llama_stack/apis/datasets/client.py +++ b/llama_stack/apis/datasets/client.py @@ -78,6 +78,21 @@ class DatasetsClient(Datasets): return [DatasetDefWithProvider(**x) for x in response.json()] + async def unregister_dataset( + self, + dataset_id: str, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.delete( + f"{self.base_url}/datasets/unregister", + params={ + "dataset_id": dataset_id, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + async def run_main(host: str, port: int): client = DatasetsClient(f"http://{host}:{port}") diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 2ab958782..e1ac4af21 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -64,3 +64,9 @@ class Datasets(Protocol): @webmethod(route="/datasets/list", method="GET") async def list_datasets(self) -> List[Dataset]: ... + + @webmethod(route="/datasets/unregister", method="POST") + async def unregister_dataset( + self, + dataset_id: str, + ) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4df693b26..2fb5a5e1c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_memory_bank(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await p.unregister_dataset(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): ) await self.register_object(dataset) + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) + class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 080204e45..8e89bcc72 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol): async def register_dataset(self, dataset: Dataset) -> None: ... + async def unregister_dataset(self, dataset_id: str) -> None: ... 
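As a quick illustration (not part of the patch), the new endpoint can also be driven through the `unregister_dataset` method added to `DatasetsClient` above; the host, port, and dataset id ("sixian", borrowed from the test plan) are placeholders.

```python
# Hedged sketch: calls the new /datasets/unregister route via the client wrapper
# added in llama_stack/apis/datasets/client.py. Values below are placeholders.
import asyncio

from llama_stack.apis.datasets.client import DatasetsClient


async def main() -> None:
    client = DatasetsClient("http://localhost:5001")
    await client.unregister_dataset("sixian")


asyncio.run(main())
```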
+ class ScoringFunctionsProtocolPrivate(Protocol): async def list_scoring_functions(self) -> List[ScoringFn]: ... diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 4de1850ae..010610056 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_impl=dataset_impl, ) + async def unregister_dataset(self, dataset_id: str) -> None: + del self.dataset_infos[dataset_id] + async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index c2e4506bf..cdd5d9cd3 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): ) self.dataset_infos[dataset_def.identifier] = dataset_def + async def unregister_dataset(self, dataset_id: str) -> None: + key = f"{DATASETS_PREFIX}{dataset_id}" + await self.kvstore.delete(key=key) + del self.dataset_infos[dataset_id] + async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index dd2cbd019..7d88b6115 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -81,6 +81,18 @@ class TestDatasetIO: assert len(response) == 1 assert response[0].identifier == "test_dataset" + with pytest.raises(Exception) as exc_info: + # unregister a dataset that does not exist + await datasets_impl.unregister_dataset("test_dataset2") + + await datasets_impl.unregister_dataset("test_dataset") + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + + with pytest.raises(Exception) as exc_info: + await datasets_impl.unregister_dataset("test_dataset") + @pytest.mark.asyncio async def test_get_rows_paginated(self, datasetio_stack): datasetio_impl, datasets_impl = datasetio_stack From 16769256b7d1f7ffadc09480eb2c8e1367fc2c8b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 4 Dec 2024 09:47:09 -0800 Subject: [PATCH 4/4] [llama stack ui] add native eval & inspect distro & playground pages (#541) # What does this PR do? New Pages Added: - (1) Inspect Distro - (2) Evaluations: - (a) native evaluations (including generation) - (b) application evaluations (no generation, scoring only) - (3) Playground: - (a) chat - (b) RAG ## Test Plan ``` streamlit run app.py ``` #### Playground https://github.com/user-attachments/assets/6ca617e8-32ca-49b2-9774-185020ff5204 #### Inspect https://github.com/user-attachments/assets/01d52b2d-92af-4e3a-b623-a9b8ba22ba99 #### Evaluations (Generation + Scoring) https://github.com/user-attachments/assets/345845c7-2a2b-4095-960a-9ae40f6a93cf #### Evaluations (Scoring) https://github.com/user-attachments/assets/6cc1659f-eba4-49ca-a0a5-7c243557b4f5 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? 
- [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- llama_stack/distribution/ui/README.md | 6 + llama_stack/distribution/ui/app.py | 196 +++---------- .../distribution/ui/modules/__init__.py | 5 + llama_stack/distribution/ui/modules/api.py | 13 +- llama_stack/distribution/ui/modules/utils.py | 11 + llama_stack/distribution/ui/page/__init__.py | 5 + .../ui/page/distribution/datasets.py | 19 ++ .../ui/page/distribution/eval_tasks.py | 22 ++ .../ui/page/distribution/memory_banks.py | 23 ++ .../ui/page/distribution/models.py | 19 ++ .../ui/page/distribution/providers.py | 20 ++ .../ui/page/distribution/resources.py | 52 ++++ .../ui/page/distribution/scoring_functions.py | 22 ++ .../ui/page/distribution/shields.py | 20 ++ .../ui/page/evaluations/__init__.py | 5 + .../ui/page/evaluations/app_eval.py | 148 ++++++++++ .../ui/page/evaluations/native_eval.py | 257 ++++++++++++++++++ .../ui/page/playground/__init__.py | 5 + .../distribution/ui/page/playground/chat.py | 123 +++++++++ .../distribution/ui/page/playground/rag.py | 188 +++++++++++++ llama_stack/distribution/ui/requirements.txt | 1 + .../scoring_fn/fn_defs/llm_as_judge_base.py | 6 +- 22 files changed, 1000 insertions(+), 166 deletions(-) create mode 100644 llama_stack/distribution/ui/modules/__init__.py create mode 100644 llama_stack/distribution/ui/page/__init__.py create mode 100644 llama_stack/distribution/ui/page/distribution/datasets.py create mode 100644 llama_stack/distribution/ui/page/distribution/eval_tasks.py create mode 100644 llama_stack/distribution/ui/page/distribution/memory_banks.py create mode 100644 llama_stack/distribution/ui/page/distribution/models.py create mode 100644 llama_stack/distribution/ui/page/distribution/providers.py create mode 100644 llama_stack/distribution/ui/page/distribution/resources.py create mode 100644 llama_stack/distribution/ui/page/distribution/scoring_functions.py create mode 100644 llama_stack/distribution/ui/page/distribution/shields.py create mode 100644 llama_stack/distribution/ui/page/evaluations/__init__.py create mode 100644 llama_stack/distribution/ui/page/evaluations/app_eval.py create mode 100644 llama_stack/distribution/ui/page/evaluations/native_eval.py create mode 100644 llama_stack/distribution/ui/page/playground/__init__.py create mode 100644 llama_stack/distribution/ui/page/playground/chat.py create mode 100644 llama_stack/distribution/ui/page/playground/rag.py diff --git a/llama_stack/distribution/ui/README.md b/llama_stack/distribution/ui/README.md index a91883067..2cc352c52 100644 --- a/llama_stack/distribution/ui/README.md +++ b/llama_stack/distribution/ui/README.md @@ -2,6 +2,12 @@ [!NOTE] This is a work in progress. +## Prerequisite +- Start up Llama Stack Server +``` +llama stack run +``` + ## Running Streamlit App ``` diff --git a/llama_stack/distribution/ui/app.py b/llama_stack/distribution/ui/app.py index 763b126a7..87a80e235 100644 --- a/llama_stack/distribution/ui/app.py +++ b/llama_stack/distribution/ui/app.py @@ -3,170 +3,54 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- -import json - -import pandas as pd - import streamlit as st -from modules.api import LlamaStackEvaluation - -from modules.utils import process_dataset - -EVALUATION_API = LlamaStackEvaluation() - def main(): - # Add collapsible sidebar - with st.sidebar: - # Add collapse button - if "sidebar_state" not in st.session_state: - st.session_state.sidebar_state = True - - if st.session_state.sidebar_state: - st.title("Navigation") - page = st.radio( - "Select a Page", - ["Application Evaluation"], - index=0, - ) - else: - page = "Application Evaluation" # Default page when sidebar is collapsed - - # Main content area - st.title("๐Ÿฆ™ Llama Stack Evaluations") - - if page == "Application Evaluation": - application_evaluation_page() - - -def application_evaluation_page(): - # File uploader - uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"]) - - if uploaded_file is None: - st.error("No file uploaded") - return - - # Process uploaded file - df = process_dataset(uploaded_file) - if df is None: - st.error("Error processing file") - return - - # Display dataset information - st.success("Dataset loaded successfully!") - - # Display dataframe preview - st.subheader("Dataset Preview") - st.dataframe(df) - - # Select Scoring Functions to Run Evaluation On - st.subheader("Select Scoring Functions") - scoring_functions = EVALUATION_API.list_scoring_functions() - scoring_functions = {sf.identifier: sf for sf in scoring_functions} - scoring_functions_names = list(scoring_functions.keys()) - selected_scoring_functions = st.multiselect( - "Choose one or more scoring functions", - options=scoring_functions_names, - help="Choose one or more scoring functions.", + # Evaluation pages + application_evaluation_page = st.Page( + "page/evaluations/app_eval.py", + title="Evaluations (Scoring)", + icon="๐Ÿ“Š", + default=False, + ) + native_evaluation_page = st.Page( + "page/evaluations/native_eval.py", + title="Evaluations (Generation + Scoring)", + icon="๐Ÿ“Š", + default=False, ) - available_models = EVALUATION_API.list_models() - available_models = [m.identifier for m in available_models] + # Playground pages + chat_page = st.Page( + "page/playground/chat.py", title="Chat", icon="๐Ÿ’ฌ", default=True + ) + rag_page = st.Page("page/playground/rag.py", title="RAG", icon="๐Ÿ’ฌ", default=False) - scoring_params = {} - if selected_scoring_functions: - st.write("Selected:") - for scoring_fn_id in selected_scoring_functions: - scoring_fn = scoring_functions[scoring_fn_id] - st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}") - new_params = None - if scoring_fn.params: - new_params = {} - for param_name, param_value in scoring_fn.params.to_dict().items(): - if param_name == "type": - new_params[param_name] = param_value - continue + # Distribution pages + resources_page = st.Page( + "page/distribution/resources.py", title="Resources", icon="๐Ÿ”", default=False + ) + provider_page = st.Page( + "page/distribution/providers.py", + title="API Providers", + icon="๐Ÿ”", + default=False, + ) - if param_name == "judge_model": - value = st.selectbox( - f"Select **{param_name}** for {scoring_fn_id}", - options=available_models, - index=0, - key=f"{scoring_fn_id}_{param_name}", - ) - new_params[param_name] = value - else: - value = st.text_area( - f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format", - value=json.dumps(param_value, indent=2), - height=80, - ) - try: - new_params[param_name] = json.loads(value) - except json.JSONDecodeError: - st.error( - f"Invalid JSON for 
**{param_name}** in {scoring_fn_id}" - ) - - st.json(new_params) - scoring_params[scoring_fn_id] = new_params - - # Add run evaluation button & slider - total_rows = len(df) - num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows) - - if st.button("Run Evaluation"): - progress_text = "Running evaluation..." - progress_bar = st.progress(0, text=progress_text) - rows = df.to_dict(orient="records") - if num_rows < total_rows: - rows = rows[:num_rows] - - # Create separate containers for progress text and results - progress_text_container = st.empty() - results_container = st.empty() - output_res = {} - for i, r in enumerate(rows): - # Update progress - progress = i / len(rows) - progress_bar.progress(progress, text=progress_text) - - # Run evaluation for current row - score_res = EVALUATION_API.run_scoring( - r, - scoring_function_ids=selected_scoring_functions, - scoring_params=scoring_params, - ) - - for k in r.keys(): - if k not in output_res: - output_res[k] = [] - output_res[k].append(r[k]) - - for fn_id in selected_scoring_functions: - if fn_id not in output_res: - output_res[fn_id] = [] - output_res[fn_id].append(score_res.results[fn_id].score_rows[0]) - - # Display current row results using separate containers - progress_text_container.write( - f"Expand to see current processed result ({i+1}/{len(rows)})" - ) - results_container.json( - score_res.to_json(), - expanded=2, - ) - - progress_bar.progress(1.0, text="Evaluation complete!") - - # Display results in dataframe - if output_res: - output_df = pd.DataFrame(output_res) - st.subheader("Evaluation Results") - st.dataframe(output_df) + pg = st.navigation( + { + "Playground": [ + chat_page, + rag_page, + application_evaluation_page, + native_evaluation_page, + ], + "Inspect": [provider_page, resources_page], + }, + expanded=False, + ) + pg.run() if __name__ == "__main__": diff --git a/llama_stack/distribution/ui/modules/__init__.py b/llama_stack/distribution/ui/modules/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
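For readers new to the multipage pattern that the rewritten `app.py` relies on, here is a self-contained sketch (not part of the patch) of Streamlit's `st.Page`/`st.navigation` API. It uses callables for brevity, whereas `app.py` above points each `st.Page` at a page script path.

```python
# Minimal st.navigation sketch. Each st.Page wraps a page (callable or script
# path); st.navigation renders the grouped sidebar and returns the page the
# user selected, which is then executed with .run().
import streamlit as st


def chat():
    st.title("Chat")


def rag():
    st.title("RAG")


pg = st.navigation(
    {
        "Playground": [
            st.Page(chat, title="Chat", default=True),
            st.Page(rag, title="RAG"),
        ],
    },
    expanded=False,
)
pg.run()
```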
diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index a8d8bf37d..d3852caee 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -11,7 +11,7 @@ from typing import Optional from llama_stack_client import LlamaStackClient -class LlamaStackEvaluation: +class LlamaStackApi: def __init__(self): self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"), @@ -22,14 +22,6 @@ class LlamaStackEvaluation: }, ) - def list_scoring_functions(self): - """List all available scoring functions""" - return self.client.scoring_functions.list() - - def list_models(self): - """List all available judge models""" - return self.client.models.list() - def run_scoring( self, row, scoring_function_ids: list[str], scoring_params: Optional[dict] ): @@ -39,3 +31,6 @@ class LlamaStackEvaluation: return self.client.scoring.score( input_rows=[row], scoring_functions=scoring_params ) + + +llama_stack_api = LlamaStackApi() diff --git a/llama_stack/distribution/ui/modules/utils.py b/llama_stack/distribution/ui/modules/utils.py index f8da2e54e..67cce98fa 100644 --- a/llama_stack/distribution/ui/modules/utils.py +++ b/llama_stack/distribution/ui/modules/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 import os import pandas as pd @@ -29,3 +30,13 @@ def process_dataset(file): except Exception as e: st.error(f"Error processing file: {str(e)}") return None + + +def data_url_from_file(file) -> str: + file_content = file.getvalue() + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type = file.type + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url diff --git a/llama_stack/distribution/ui/page/__init__.py b/llama_stack/distribution/ui/page/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py new file mode 100644 index 000000000..44e314cde --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/datasets.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def datasets(): + st.header("Datasets") + + datasets_info = { + d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list() + } + + selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys())) + st.json(datasets_info[selected_dataset], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py new file mode 100644 index 000000000..4957fb178 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def eval_tasks(): + # Eval Tasks Section + st.header("Eval Tasks") + + eval_tasks_info = { + d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list() + } + + selected_eval_task = st.selectbox( + "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect" + ) + st.json(eval_tasks_info[selected_eval_task], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/memory_banks.py b/llama_stack/distribution/ui/page/distribution/memory_banks.py new file mode 100644 index 000000000..f28010bf2 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/memory_banks.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def memory_banks(): + st.header("Memory Banks") + memory_banks_info = { + m.identifier: m.to_dict() for m in llama_stack_api.client.memory_banks.list() + } + + if len(memory_banks_info) > 0: + selected_memory_bank = st.selectbox( + "Select a memory bank", list(memory_banks_info.keys()) + ) + st.json(memory_banks_info[selected_memory_bank]) + else: + st.info("No memory banks found") diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py new file mode 100644 index 000000000..70b166f2e --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/models.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def models(): + # Models Section + st.header("Models") + models_info = { + m.identifier: m.to_dict() for m in llama_stack_api.client.models.list() + } + + selected_model = st.selectbox("Select a model", list(models_info.keys())) + st.json(models_info[selected_model]) diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py new file mode 100644 index 000000000..69f6bd771 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def providers(): + st.header("๐Ÿ” API Providers") + apis_providers_info = llama_stack_api.client.providers.list() + # selected_api = st.selectbox("Select an API", list(apis_providers_info.keys())) + for api in apis_providers_info.keys(): + st.markdown(f"###### {api}") + st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500) + + +providers() diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py new file mode 100644 index 000000000..6b3ea0e3a --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from page.distribution.datasets import datasets +from page.distribution.eval_tasks import eval_tasks +from page.distribution.memory_banks import memory_banks +from page.distribution.models import models +from page.distribution.scoring_functions import scoring_functions +from page.distribution.shields import shields + +from streamlit_option_menu import option_menu + + +def resources_page(): + options = [ + "Models", + "Memory Banks", + "Shields", + "Scoring Functions", + "Datasets", + "Eval Tasks", + ] + icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"] + selected_resource = option_menu( + None, + options, + icons=icons, + orientation="horizontal", + styles={ + "nav-link": { + "font-size": "12px", + }, + }, + ) + if selected_resource == "Eval Tasks": + eval_tasks() + elif selected_resource == "Memory Banks": + memory_banks() + elif selected_resource == "Datasets": + datasets() + elif selected_resource == "Models": + models() + elif selected_resource == "Scoring Functions": + scoring_functions() + elif selected_resource == "Shields": + shields() + + +resources_page() diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py new file mode 100644 index 000000000..581ae0db7 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def scoring_functions(): + st.header("Scoring Functions") + + scoring_functions_info = { + s.identifier: s.to_dict() + for s in llama_stack_api.client.scoring_functions.list() + } + + selected_scoring_function = st.selectbox( + "Select a scoring function", list(scoring_functions_info.keys()) + ) + st.json(scoring_functions_info[selected_scoring_function], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py new file mode 100644 index 000000000..18bbfc008 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/shields.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def shields(): + # Shields Section + st.header("Shields") + + shields_info = { + s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list() + } + + selected_shield = st.selectbox("Select a shield", list(shields_info.keys())) + st.json(shields_info[selected_shield]) diff --git a/llama_stack/distribution/ui/page/evaluations/__init__.py b/llama_stack/distribution/ui/page/evaluations/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
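Most of the inspect pages above (`datasets.py`, `eval_tasks.py`, `memory_banks.py`, `models.py`, `scoring_functions.py`, `shields.py`) share the same list-select-inspect shape; a condensed sketch of that pattern follows (illustrative only, with `fetch` standing in for the per-resource `llama_stack_api.client.*.list()` calls).

```python
# Sketch of the shared "list -> select -> st.json" layout used by the
# distribution inspect pages; `fetch` is a stand-in for calls such as
# llama_stack_api.client.models.list() or llama_stack_api.client.shields.list().
import streamlit as st


def inspect_resource(header: str, fetch) -> None:
    st.header(header)
    info = {item.identifier: item.to_dict() for item in fetch()}
    if not info:
        st.info(f"No {header.lower()} found")
        return
    selected = st.selectbox(f"Select from {header.lower()}", list(info.keys()))
    st.json(info[selected], expanded=True)
```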
diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py new file mode 100644 index 000000000..5ec47ed45 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +import pandas as pd +import streamlit as st + +from modules.api import llama_stack_api +from modules.utils import process_dataset + + +def application_evaluation_page(): + + st.set_page_config(page_title="Evaluations (Scoring)", page_icon="๐Ÿฆ™") + st.title("๐Ÿ“Š Evaluations (Scoring)") + + # File uploader + uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"]) + + if uploaded_file is None: + st.error("No file uploaded") + return + + # Process uploaded file + df = process_dataset(uploaded_file) + if df is None: + st.error("Error processing file") + return + + # Display dataset information + st.success("Dataset loaded successfully!") + + # Display dataframe preview + st.subheader("Dataset Preview") + st.dataframe(df) + + # Select Scoring Functions to Run Evaluation On + st.subheader("Select Scoring Functions") + scoring_functions = llama_stack_api.client.scoring_functions.list() + scoring_functions = {sf.identifier: sf for sf in scoring_functions} + scoring_functions_names = list(scoring_functions.keys()) + selected_scoring_functions = st.multiselect( + "Choose one or more scoring functions", + options=scoring_functions_names, + help="Choose one or more scoring functions.", + ) + + available_models = llama_stack_api.client.models.list() + available_models = [m.identifier for m in available_models] + + scoring_params = {} + if selected_scoring_functions: + st.write("Selected:") + for scoring_fn_id in selected_scoring_functions: + scoring_fn = scoring_functions[scoring_fn_id] + st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}") + new_params = None + if scoring_fn.params: + new_params = {} + for param_name, param_value in scoring_fn.params.to_dict().items(): + if param_name == "type": + new_params[param_name] = param_value + continue + + if param_name == "judge_model": + value = st.selectbox( + f"Select **{param_name}** for {scoring_fn_id}", + options=available_models, + index=0, + key=f"{scoring_fn_id}_{param_name}", + ) + new_params[param_name] = value + else: + value = st.text_area( + f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format", + value=json.dumps(param_value, indent=2), + height=80, + ) + try: + new_params[param_name] = json.loads(value) + except json.JSONDecodeError: + st.error( + f"Invalid JSON for **{param_name}** in {scoring_fn_id}" + ) + + st.json(new_params) + scoring_params[scoring_fn_id] = new_params + + # Add run evaluation button & slider + total_rows = len(df) + num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows) + + if st.button("Run Evaluation"): + progress_text = "Running evaluation..." 
+ progress_bar = st.progress(0, text=progress_text) + rows = df.to_dict(orient="records") + if num_rows < total_rows: + rows = rows[:num_rows] + + # Create separate containers for progress text and results + progress_text_container = st.empty() + results_container = st.empty() + output_res = {} + for i, r in enumerate(rows): + # Update progress + progress = i / len(rows) + progress_bar.progress(progress, text=progress_text) + + # Run evaluation for current row + score_res = llama_stack_api.run_scoring( + r, + scoring_function_ids=selected_scoring_functions, + scoring_params=scoring_params, + ) + + for k in r.keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(r[k]) + + for fn_id in selected_scoring_functions: + if fn_id not in output_res: + output_res[fn_id] = [] + output_res[fn_id].append(score_res.results[fn_id].score_rows[0]) + + # Display current row results using separate containers + progress_text_container.write( + f"Expand to see current processed result ({i+1}/{len(rows)})" + ) + results_container.json( + score_res.to_json(), + expanded=2, + ) + + progress_bar.progress(1.0, text="Evaluation complete!") + + # Display results in dataframe + if output_res: + output_df = pd.DataFrame(output_res) + st.subheader("Evaluation Results") + st.dataframe(output_df) + + +application_evaluation_page() diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py new file mode 100644 index 000000000..b8cc8bfa6 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +import pandas as pd + +import streamlit as st + +from modules.api import llama_stack_api + + +def select_eval_task_1(): + # Select Eval Tasks + st.subheader("1. Choose An Eval Task") + eval_tasks = llama_stack_api.client.eval_tasks.list() + eval_tasks = {et.identifier: et for et in eval_tasks} + eval_tasks_names = list(eval_tasks.keys()) + selected_eval_task = st.selectbox( + "Choose an eval task.", + options=eval_tasks_names, + help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.", + ) + with st.expander("View Eval Task"): + st.json(eval_tasks[selected_eval_task], expanded=True) + + st.session_state["selected_eval_task"] = selected_eval_task + st.session_state["eval_tasks"] = eval_tasks + if st.button("Confirm", key="confirm_1"): + st.session_state["selected_eval_task_1_next"] = True + + +def define_eval_candidate_2(): + if not st.session_state.get("selected_eval_task_1_next", None): + return + + st.subheader("2. Define Eval Candidate") + st.info( + """ + Define the configurations for the evaluation candidate model or agent used for generation. + Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig. 
+ """ + ) + with st.expander("Define Eval Candidate", expanded=True): + # Define Eval Candidate + candidate_type = st.radio("Candidate Type", ["model", "agent"]) + + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + + # Sampling Parameters + st.markdown("##### Sampling Parameters") + strategy = st.selectbox( + "Strategy", + ["greedy", "top_p", "top_k"], + index=0, + ) + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + max_tokens = st.slider( + "Max Tokens", + min_value=0, + max_value=4096, + value=512, + step=1, + help="The maximum number of tokens to generate", + ) + repetition_penalty = st.slider( + "Repetition Penalty", + min_value=1.0, + max_value=2.0, + value=1.0, + step=0.1, + help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", + ) + if candidate_type == "model": + eval_candidate = { + "type": "model", + "model": selected_model, + "sampling_params": { + "strategy": strategy, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "repetition_penalty": repetition_penalty, + }, + } + elif candidate_type == "agent": + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful AI assistant.", + help="Initial instructions given to the AI to set its behavior and context", + ) + tools_json = st.text_area( + "Tools Configuration (JSON)", + value=json.dumps( + [ + { + "type": "brave_search", + "engine": "brave", + "api_key": "ENTER_BRAVE_API_KEY_HERE", + } + ] + ), + help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.", + height=200, + ) + try: + tools = json.loads(tools_json) + except json.JSONDecodeError: + st.error("Invalid JSON format for tools configuration") + tools = [] + eval_candidate = { + "type": "agent", + "config": { + "model": selected_model, + "instructions": system_prompt, + "tools": tools, + "tool_choice": "auto", + "tool_prompt_format": "json", + "input_shields": [], + "output_shields": [], + "enable_session_persistence": False, + }, + } + st.session_state["eval_candidate"] = eval_candidate + + if st.button("Confirm", key="confirm_2"): + st.session_state["selected_eval_candidate_2_next"] = True + + +def run_evaluation_3(): + if not st.session_state.get("selected_eval_candidate_2_next", None): + return + + st.subheader("3. Run Evaluation") + # Add info box to explain configurations being used + st.info( + """ + Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button. 
+ """ + ) + selected_eval_task = st.session_state["selected_eval_task"] + eval_tasks = st.session_state["eval_tasks"] + eval_candidate = st.session_state["eval_candidate"] + + dataset_id = eval_tasks[selected_eval_task].dataset_id + rows = llama_stack_api.client.datasetio.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + total_rows = len(rows.rows) + # Add number of examples control + num_rows = st.number_input( + "Number of Examples to Evaluate", + min_value=1, + max_value=total_rows, + value=5, + help="Number of examples from the dataset to evaluate. ", + ) + + eval_task_config = { + "type": "benchmark", + "eval_candidate": eval_candidate, + "scoring_params": {}, + } + + with st.expander("View Evaluation Task", expanded=True): + st.json(eval_tasks[selected_eval_task], expanded=True) + with st.expander("View Evaluation Task Configuration", expanded=True): + st.json(eval_task_config, expanded=True) + + # Add run button and handle evaluation + if st.button("Run Evaluation"): + + progress_text = "Running evaluation..." + progress_bar = st.progress(0, text=progress_text) + rows = rows.rows + if num_rows < total_rows: + rows = rows[:num_rows] + + # Create separate containers for progress text and results + progress_text_container = st.empty() + results_container = st.empty() + output_res = {} + for i, r in enumerate(rows): + # Update progress + progress = i / len(rows) + progress_bar.progress(progress, text=progress_text) + # Run evaluation for current row + eval_res = llama_stack_api.client.eval.evaluate_rows( + task_id=selected_eval_task, + input_rows=[r], + scoring_functions=eval_tasks[selected_eval_task].scoring_functions, + task_config=eval_task_config, + ) + + for k in r.keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(r[k]) + + for k in eval_res.generations[0].keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(eval_res.generations[0][k]) + + for scoring_fn in eval_tasks[selected_eval_task].scoring_functions: + if scoring_fn not in output_res: + output_res[scoring_fn] = [] + output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) + + progress_text_container.write( + f"Expand to see current processed result ({i+1}/{len(rows)})" + ) + results_container.json(eval_res, expanded=2) + + progress_bar.progress(1.0, text="Evaluation complete!") + # Display results in dataframe + if output_res: + output_df = pd.DataFrame(output_res) + st.subheader("Evaluation Results") + st.dataframe(output_df) + + +def native_evaluation_page(): + + st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="๐Ÿฆ™") + st.title("๐Ÿ“Š Evaluations (Generation + Scoring)") + + select_eval_task_1() + define_eval_candidate_2() + run_evaluation_3() + + +native_evaluation_page() diff --git a/llama_stack/distribution/ui/page/playground/__init__.py b/llama_stack/distribution/ui/page/playground/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
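The three-step flow in `native_eval.py` above gates each stage on flags that the previous step's Confirm button writes into `st.session_state`; a stripped-down sketch of that gating pattern (standalone, no Llama Stack calls):

```python
# Minimal sketch of the confirm-and-gate pattern used by select_eval_task_1,
# define_eval_candidate_2, and run_evaluation_3: each step saves its output in
# st.session_state, and the next step returns early until the previous step's
# "Confirm" button has been pressed on some rerun.
import streamlit as st


def step_one():
    st.session_state["choice"] = st.selectbox("Choose an option", ["a", "b"])
    if st.button("Confirm", key="confirm_1"):
        st.session_state["step_1_done"] = True


def step_two():
    if not st.session_state.get("step_1_done"):
        return  # gate: nothing to do until step one is confirmed
    st.write(f"Running with option {st.session_state['choice']}")


step_one()
step_two()
```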
diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py new file mode 100644 index 000000000..157922d3b --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + +# Sidebar configurations +with st.sidebar: + st.header("Configuration") + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + + max_tokens = st.slider( + "Max Tokens", + min_value=0, + max_value=4096, + value=512, + step=1, + help="The maximum number of tokens to generate", + ) + + repetition_penalty = st.slider( + "Repetition Penalty", + min_value=1.0, + max_value=2.0, + value=1.0, + step=0.1, + help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", + ) + + stream = st.checkbox("Stream", value=True) + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful AI assistant.", + help="Initial instructions given to the AI to set its behavior and context", + ) + + # Add clear chat button to sidebar + if st.button("Clear Chat", use_container_width=True): + st.session_state.messages = [] + st.rerun() + + +# Main chat interface +st.title("๐Ÿฆ™ Chat") + + +# Initialize chat history +if "messages" not in st.session_state: + st.session_state.messages = [] + +# Display chat messages +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + +# Chat input +if prompt := st.chat_input("Example: What is Llama Stack?"): + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) + + # Display user message + with st.chat_message("user"): + st.markdown(prompt) + + # Display assistant response + with st.chat_message("assistant"): + message_placeholder = st.empty() + full_response = "" + + response = llama_stack_api.client.inference.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + model_id=selected_model, + stream=stream, + sampling_params={ + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "repetition_penalty": repetition_penalty, + }, + ) + + if stream: + for chunk in response: + if chunk.event.event_type == "progress": + full_response += chunk.event.delta + message_placeholder.markdown(full_response + "โ–Œ") + message_placeholder.markdown(full_response) + else: + full_response = response + message_placeholder.markdown(full_response.completion_message.content) + + st.session_state.messages.append( + {"role": "assistant", "content": full_response} + ) diff --git 
a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py new file mode 100644 index 000000000..ffcaf1afd --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.memory_insert_params import Document + +from modules.api import llama_stack_api +from modules.utils import data_url_from_file + + +def rag_chat_page(): + st.title("๐Ÿฆ™ RAG") + + with st.sidebar: + # File/Directory Upload Section + st.subheader("Upload Documents") + uploaded_files = st.file_uploader( + "Upload file(s) or directory", + accept_multiple_files=True, + type=["txt", "pdf", "doc", "docx"], # Add more file types as needed + ) + # Process uploaded files + if uploaded_files: + st.success(f"Successfully uploaded {len(uploaded_files)} files") + # Add memory bank name input field + memory_bank_name = st.text_input( + "Memory Bank Name", + value="rag_bank", + help="Enter a unique identifier for this memory bank", + ) + if st.button("Create Memory Bank"): + documents = [ + Document( + document_id=uploaded_file.name, + content=data_url_from_file(uploaded_file), + ) + for i, uploaded_file in enumerate(uploaded_files) + ] + + providers = llama_stack_api.client.providers.list() + llama_stack_api.client.memory_banks.register( + memory_bank_id=memory_bank_name, # Use the user-provided name + params={ + "embedding_model": "all-MiniLM-L6-v2", + "chunk_size_in_tokens": 512, + "overlap_size_in_tokens": 64, + }, + provider_id=providers["memory"][0].provider_id, + ) + + # insert documents using the custom bank name + llama_stack_api.client.memory.insert( + bank_id=memory_bank_name, # Use the user-provided name + documents=documents, + ) + st.success("Memory bank created successfully!") + + st.subheader("Configure Agent") + # select memory banks + memory_banks = llama_stack_api.client.memory_banks.list() + memory_banks = [bank.identifier for bank in memory_banks] + selected_memory_banks = st.multiselect( + "Select Memory Banks", + memory_banks, + ) + memory_bank_configs = [ + {"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks + ] + + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful assistant. ", + help="Initial instructions given to the AI to set its behavior and context", + ) + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. 
Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + + # Add clear chat button to sidebar + if st.button("Clear Chat", use_container_width=True): + st.session_state.messages = [] + st.rerun() + + # Chat Interface + if "messages" not in st.session_state: + st.session_state.messages = [] + + # Display chat history + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + selected_model = llama_stack_api.client.models.list()[0].identifier + + agent_config = AgentConfig( + model=selected_model, + instructions=system_prompt, + sampling_params={ + "strategy": "greedy", + "temperature": temperature, + "top_p": top_p, + }, + tools=[ + { + "type": "memory", + "memory_bank_configs": memory_bank_configs, + "query_generator_config": {"type": "default", "sep": " "}, + "max_tokens_in_context": 4096, + "max_chunks": 10, + } + ], + tool_choice="auto", + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + + agent = Agent(llama_stack_api.client, agent_config) + session_id = agent.create_session("rag-session") + + # Chat input + if prompt := st.chat_input("Ask a question about your documents"): + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) + + # Display user message + with st.chat_message("user"): + st.markdown(prompt) + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + session_id=session_id, + ) + + # Display assistant response + with st.chat_message("assistant"): + retrieval_message_placeholder = st.empty() + message_placeholder = st.empty() + full_response = "" + retrieval_response = "" + for log in EventLogger().log(response): + log.print() + if log.role == "memory_retrieval": + retrieval_response += log.content.replace("====", "").strip() + retrieval_message_placeholder.info(retrieval_response) + else: + full_response += log.content + message_placeholder.markdown(full_response + "โ–Œ") + message_placeholder.markdown(full_response) + + st.session_state.messages.append( + {"role": "assistant", "content": full_response} + ) + + +rag_chat_page() diff --git a/llama_stack/distribution/ui/requirements.txt b/llama_stack/distribution/ui/requirements.txt index c03959444..39f2b3d27 100644 --- a/llama_stack/distribution/ui/requirements.txt +++ b/llama_stack/distribution/ui/requirements.txt @@ -1,3 +1,4 @@ streamlit pandas llama-stack-client>=0.0.55 +streamlit-option-menu diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py index b00b9a7db..0b18bac01 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn llm_as_judge_base = ScoringFn( @@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn( return_type=NumberType(), provider_id="llm-as-judge", provider_resource_id="llm-as-judge-base", + params=LLMAsJudgeScoringFnParams( + judge_model="meta-llama/Llama-3.1-405B-Instruct", + prompt_template="Enter custom LLM as Judge Prompt Template", + ), )
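For context (not part of the patch), the new default `LLMAsJudgeScoringFnParams` can still be overridden per request; a hedged sketch of passing such params through `client.scoring.score`, mirroring the `scoring_params` dict that `app_eval.py` builds. The scoring-function identifier, the `type` value, the row fields, and the prompt template are assumptions for illustration only.

```python
# Hedged sketch: per-request override of the judge params, following the
# client.scoring.score(input_rows=..., scoring_functions=...) call used by
# LlamaStackApi.run_scoring above. Identifiers and field values are assumed.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

scoring_params = {
    "llm-as-judge-base": {                       # assumed scoring function id
        "type": "llm_as_judge",                  # assumed params type tag
        "judge_model": "meta-llama/Llama-3.1-405B-Instruct",
        "prompt_template": "Rate this answer from 1 to 5: {generated_answer}",
    }
}

response = client.scoring.score(
    input_rows=[{"input_query": "What is Llama Stack?", "generated_answer": "..."}],
    scoring_functions=scoring_params,
)
print(response.results["llm-as-judge-base"].score_rows[0])
```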