Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
minor import fixes
Commit: c3708859aa (parent: dc433f6c90)
7 changed files with 16 additions and 11 deletions
@@ -49,6 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.inference: providers[Api.inference]["meta-ollama"],
             Api.safety: providers[Api.safety]["meta-reference"],
             Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+            Api.memory: remote_spec(Api.memory),
         },
     ),
     DistributionSpec(
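The hunk above registers the memory API in an existing distribution's provider map, routing it to a remote provider rather than a named in-process implementation. Read schematically (only the mapping itself comes from the hunk; the `provider_specs` name is a placeholder, not necessarily the repo's actual field name):

    # Sketch of the mapping being extended: each Api key resolves either to a
    # concrete provider pulled from the local registry, or to remote_spec(...),
    # which declares that the API is served by a remote endpoint.
    provider_specs = {
        Api.inference: providers[Api.inference]["meta-ollama"],
        Api.safety: providers[Api.safety]["meta-reference"],
        Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
        Api.memory: remote_spec(Api.memory),  # new: memory served remotely
    }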
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from typing import List
 
 from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
@@ -8,7 +8,7 @@ from codeshield.cs import CodeShield
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -14,7 +14,7 @@ from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_toolchain.safety.api.datatypes import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
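The four shield-module hunks above make the same one-line substitution: the star import now targets the `llama_toolchain.safety.api` package instead of its `datatypes` submodule. For that to be a pure import fix, the package initializer has to re-export the datatypes symbols; a minimal sketch of that pattern, assuming an `__init__.py` along these lines (not part of this commit):

    # llama_toolchain/safety/api/__init__.py -- hypothetical sketch
    # Re-export the submodule's public names so that
    # `from llama_toolchain.safety.api import *` provides everything that
    # `from llama_toolchain.safety.api.datatypes import *` used to.
    from .datatypes import *  # noqa: F401, F403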
@@ -18,9 +18,11 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
 
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
 from llama_toolchain.inference.meta_reference.inference import get_provider_impl
 
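Here the test collapses two submodule imports into a single grouped import from `llama_toolchain.inference.api`; the Ollama test further down gets the identical treatment. As with the safety package, this only works if the package root aggregates its submodules, roughly like the sketch below (hypothetical, not part of this diff):

    # llama_toolchain/inference/api/__init__.py -- hypothetical sketch
    # Surface both the endpoint request types and the streaming datatypes at
    # the package level, so callers can write a single import for
    # ChatCompletionRequest and ChatCompletionResponseEventType.
    from .datatypes import *  # noqa: F401, F403
    from .endpoints import *  # noqa: F401, F403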
@@ -221,12 +223,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(
             events[-1].event_type, ChatCompletionResponseEventType.complete
         )
-        self.assertEqual(events[-1].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
         # last but one event should be eom with tool call
         self.assertEqual(
             events[-2].event_type, ChatCompletionResponseEventType.progress
         )
-        self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
+        self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
         self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
 
     async def test_multi_turn(self):
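These assertions flip the expected stop reason on the streamed events from `StopReason.end_of_message` to `StopReason.end_of_turn`. For orientation, the `events` list they inspect is presumably collected from the streaming iterator earlier in the test, roughly as in this sketch (the chunk/event attribute layout is an assumption inferred from the fields referenced above):

    # Hypothetical reconstruction of how `events` is built before the
    # assertions: drain the streaming chat completion and keep every event.
    events = []
    async for chunk in iterator:
        events.append(chunk.event)
    # events[-2]: progress event carrying the tool-call delta
    # events[-1]: final event marking completion and its stop_reason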
@@ -14,8 +14,10 @@ from llama_models.llama3.api.datatypes import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
-from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
+from llama_toolchain.inference.api import (
+    ChatCompletionRequest,
+    ChatCompletionResponseEventType,
+)
 from llama_toolchain.inference.ollama.config import OllamaImplConfig
 from llama_toolchain.inference.ollama.ollama import get_provider_impl
 
@@ -62,7 +64,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         iterator = self.api.chat_completion(request)
         async for r in iterator:
            response = r
+        print(response.completion_message.content)
         self.assertTrue("Paris" in response.completion_message.content)
         self.assertEqual(
             response.completion_message.stop_reason, StopReason.end_of_turn