From c3708859aa772b05688ee142133dcf4581b66878 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Mon, 26 Aug 2024 14:21:35 -0700
Subject: [PATCH] minor import fixes

---
 llama_toolchain/distribution/registry.py               |  1 +
 llama_toolchain/safety/meta_reference/shields/base.py  |  2 +-
 .../safety/meta_reference/shields/code_scanner.py      |  2 +-
 .../safety/meta_reference/shields/llama_guard.py       |  2 +-
 .../safety/meta_reference/shields/prompt_guard.py      |  2 +-
 tests/test_inference.py                                | 10 ++++++----
 tests/test_ollama_inference.py                         |  8 +++++---
 7 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index 296ee3103..fd9bc6e1b 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -49,6 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.inference: providers[Api.inference]["meta-ollama"],
             Api.safety: providers[Api.safety]["meta-reference"],
             Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+            Api.memory: remote_spec(Api.memory),
         },
     ),
     DistributionSpec(
diff --git a/llama_toolchain/safety/meta_reference/shields/base.py b/llama_toolchain/safety/meta_reference/shields/base.py
index 0432b8d3b..ed939212d 100644
--- a/llama_toolchain/safety/meta_reference/shields/base.py
+++ b/llama_toolchain/safety/meta_reference/shields/base.py
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List
 
 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
diff --git a/llama_toolchain/safety/meta_reference/shields/code_scanner.py b/llama_toolchain/safety/meta_reference/shields/code_scanner.py
index f78260ff1..564d15a53 100644
--- a/llama_toolchain/safety/meta_reference/shields/code_scanner.py
+++ b/llama_toolchain/safety/meta_reference/shields/code_scanner.py
@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
diff --git a/llama_toolchain/safety/meta_reference/shields/llama_guard.py b/llama_toolchain/safety/meta_reference/shields/llama_guard.py
index a78b8127d..fe04baa00 100644
--- a/llama_toolchain/safety/meta_reference/shields/llama_guard.py
+++ b/llama_toolchain/safety/meta_reference/shields/llama_guard.py
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 SAFE_RESPONSE = "safe"
 _INSTANCE = None
diff --git a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
index b9f5dd5a5..a1097a6f7 100644
--- a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
+++ b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py
@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 0a772d26e..277cf7e8a 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -18,9 +18,11 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 
 from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
 from llama_toolchain.inference.meta_reference.inference import get_provider_impl
 
@@ -221,12 +223,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(
             events[-1].event_type, ChatCompletionResponseEventType.complete
         )
-        self.assertEqual(events[-1].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
             events[-2].event_type, ChatCompletionResponseEventType.progress
         )
-        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
         self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
 
     async def test_multi_turn(self):
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index 8319cab3d..f5b172e69 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -14,8 +14,10 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 
 from llama_toolchain.inference.ollama.config import OllamaImplConfig
 from llama_toolchain.inference.ollama.ollama import get_provider_impl
@@ -62,7 +64,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
-
+        print(response.completion_message.content)
         self.assertTrue("Paris" in response.completion_message.content)
         self.assertEqual(
             response.completion_message.stop_reason, StopReason.end_of_turn