mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 14:59:47 +00:00
Tests pass with Ollama now
This commit is contained in:
parent
a9a041a1de
commit
e51154964f
27 changed files with 83 additions and 65 deletions
|
|
@ -29,11 +29,13 @@ from llama_stack.apis.common.deployment_types import * # noqa: F403
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Attachment(BaseModel):
|
||||
content: InterleavedTextMedia | URL
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
|
|
@ -102,20 +104,20 @@ class _MemoryBankConfigCommon(BaseModel):
|
|||
|
||||
|
||||
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
type: Literal["vector"] = "vector"
|
||||
|
||||
|
||||
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
type: Literal["keyvalue"] = "keyvalue"
|
||||
keys: List[str] # what keys to focus on
|
||||
|
||||
|
||||
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
type: Literal["keyword"] = "keyword"
|
||||
|
||||
|
||||
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
type: Literal["graph"] = "graph"
|
||||
entities: List[str] # what entities to focus on
|
||||
|
||||
|
||||
|
|
@ -230,7 +232,7 @@ class MemoryRetrievalStep(StepCommon):
|
|||
StepType.memory_retrieval.value
|
||||
)
|
||||
memory_bank_ids: List[str]
|
||||
inserted_context: InterleavedTextMedia
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
@json_schema_type
|
||||
class BatchCompletionRequest(BaseModel):
|
||||
model: str
|
||||
content_batch: List[InterleavedTextMedia]
|
||||
content_batch: List[InterleavedContent]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ class BatchInference(Protocol):
|
|||
async def batch_completion(
|
||||
self,
|
||||
model: str,
|
||||
content_batch: List[InterleavedTextMedia],
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse: ...
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from typing import Optional
|
|||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingMetric(BaseModel):
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.agents import AgentConfig
|
|||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ class CompletionResponseStreamChunk(BaseModel):
|
|||
@json_schema_type
|
||||
class BatchCompletionRequest(BaseModel):
|
||||
model: str
|
||||
content_batch: List[InterleavedTextMedia]
|
||||
content_batch: List[InterleavedContent]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
|
|
|||
|
|
@ -8,27 +8,27 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import URL
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBank
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankDocument(BaseModel):
|
||||
document_id: str
|
||||
content: InterleavedTextMedia | URL
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
content: InterleavedTextMedia
|
||||
content: InterleavedContent
|
||||
token_count: int
|
||||
document_id: str
|
||||
|
||||
|
|
@ -62,6 +62,6 @@ class Memory(Protocol):
|
|||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
|
|
|||
|
|
@ -5,16 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue