From d40e5274714f9cf385057676dde9aec019e310a4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 26 Dec 2024 18:10:23 -0800 Subject: [PATCH] bedrock --- .../remote/inference/bedrock/bedrock.py | 25 ++++++++++++++----- .../remote/inference/cerebras/cerebras.py | 22 +++++++++++++--- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index ddf59fda8..d340bbbea 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import * # noqa: F403 import json +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -13,6 +13,24 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.bedrock.client import create_bedrock_client + from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, @@ -29,11 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from llama_stack.apis.inference import * # noqa: F403 - -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.bedrock.client import create_bedrock_client - MODEL_ALIASES = [ build_model_alias( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 2ff213c2e..40457e1ae 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -4,17 +4,31 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import * # noqa: F403 - -from llama_models.datatypes import CoreModelId +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( build_model_alias,