minor import fixes

Hardik Shah 2024-08-26 14:21:35 -07:00
parent dc433f6c90
commit c3708859aa
7 changed files with 16 additions and 11 deletions


@@ -49,6 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.inference: providers[Api.inference]["meta-ollama"],
                 Api.safety: providers[Api.safety]["meta-reference"],
                 Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+                Api.memory: remote_spec(Api.memory),
             },
         ),
         DistributionSpec(
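Note on the hunk above: the new entry wires the memory API into this ollama-based distribution through remote_spec, so memory requests are routed to a remote provider rather than an in-process implementation. Pieced together from the context lines, the provider map now reads roughly as follows (only the dict entries appear in the diff; formatting and the enclosing DistributionSpec fields are assumed):

# Provider map reconstructed from the hunk above; everything except the
# layout comes from the diff itself.
{
    Api.inference: providers[Api.inference]["meta-ollama"],
    Api.safety: providers[Api.safety]["meta-reference"],
    Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
    Api.memory: remote_spec(Api.memory),  # new: memory served by a remote provider
}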


@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List
 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
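Note on this and the next three hunks: the shield modules stop importing llama_toolchain.safety.api.datatypes directly and pull the same names from the llama_toolchain.safety.api package instead. That only resolves if the package __init__.py re-exports its submodules; the file is not part of this diff, but a minimal sketch of the re-export it implies would be:

# llama_toolchain/safety/api/__init__.py -- assumed sketch, not shown in this
# diff: re-export the submodules so wildcard imports of the package resolve
# the same names (ShieldResponse, OnViolationAction, ...) as before.
from .datatypes import *  # noqa: F401 F403
from .endpoints import *  # noqa: F401 F403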


@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint
 from .base import ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 class CodeScannerShield(TextShield):


@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 SAFE_RESPONSE = "safe"
 _INSTANCE = None


@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 class PromptGuardShield(TextShield):


@@ -18,9 +18,11 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
 from llama_toolchain.inference.meta_reference.inference import get_provider_impl
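The same consolidation happens here for the inference API: two single-name imports from the datatypes and endpoints submodules collapse into one package-level import. As with the safety API, this presumes that llama_toolchain/inference/api/__init__.py re-exports both submodules, roughly along these lines (assumed, not shown in this diff):

# llama_toolchain/inference/api/__init__.py -- assumed sketch.
# ChatCompletionResponseEventType comes from datatypes and
# ChatCompletionRequest from endpoints; re-exporting both lets callers
# import either name straight from the package.
from .datatypes import *  # noqa: F401 F403
from .endpoints import *  # noqa: F401 F403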
@ -221,12 +223,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual( self.assertEqual(
events[-1].event_type, ChatCompletionResponseEventType.complete events[-1].event_type, ChatCompletionResponseEventType.complete
) )
self.assertEqual(events[-1].stop_reason, StopReason.end_of_message) self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
# last but one event should be eom with tool call # last but one event should be eom with tool call
self.assertEqual( self.assertEqual(
events[-2].event_type, ChatCompletionResponseEventType.progress events[-2].event_type, ChatCompletionResponseEventType.progress
) )
self.assertEqual(events[-2].stop_reason, StopReason.end_of_message) self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point") self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
async def test_multi_turn(self): async def test_multi_turn(self):
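The updated assertions expect StopReason.end_of_turn both on the final complete event and on the preceding progress event that carries the get_boiling_point tool call. For orientation, the events list these assertions inspect is typically collected from the streaming iterator along these lines (a sketch; the chunk.event shape is an assumption, not shown in this hunk):

# Sketch of how `events` is usually gathered before the assertions above;
# chat_completion yields stream chunks whose .event carries an event_type,
# an optional delta (tool-call pieces) and, on the last chunks, a stop_reason.
events = []
async for chunk in self.api.chat_completion(request):
    events.append(chunk.event)

# events[-2]: progress event with delta.content.tool_name == "get_boiling_point"
# events[-1]: complete event with stop_reason == StopReason.end_of_turn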


@@ -14,8 +14,10 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.ollama.config import OllamaImplConfig
 from llama_toolchain.inference.ollama.ollama import get_provider_impl
@@ -62,7 +64,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         iterator = self.api.chat_completion(request)
         async for r in iterator:
             response = r
+        print(response.completion_message.content)
         self.assertTrue("Paris" in response.completion_message.content)
         self.assertEqual(
             response.completion_message.stop_reason, StopReason.end_of_turn
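This last hunk exercises the plain, non-tool path: the iterator yields a single non-streaming response whose completion_message is expected to mention Paris and stop with end_of_turn, and the added print simply echoes the model output while the test runs. A rough sketch of the kind of request that would drive these assertions, using the consolidated import above (field names and values are illustrative assumptions; only ChatCompletionRequest, UserMessage and the iteration pattern are visible in the hunks):

# Illustrative only -- the real test body is not part of this diff.
request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",  # assumed model identifier
    messages=[UserMessage(content="What is the capital of France?")],  # assumed prompt
    stream=False,  # a single response object, not a stream of event chunks
)
iterator = self.api.chat_completion(request)
async for r in iterator:
    response = r  # non-streaming: exactly one response is yielded
print(response.completion_message.content)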