Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
Remove testing code
parent a8a860ea1f
commit d8c4e7da4b

1 changed file with 5 additions and 8 deletions
@@ -21,14 +21,11 @@ from .config import OpenAIImplConfig
 
 
 class OpenAIInferenceAdapter(Inference):
+    max_tokens: int
+    model_id: str
+
     def __init__(self, config: OpenAIImplConfig) -> None:
         self.config = config
-
-        # For testing purposes
-        # This model's maximum context length is 6144 tokens.
-        self.max_tokens = 6144
-        self.model_id = "mistral-7b-instruct"
-
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
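Read together with the context lines, the constructor after this commit reduces to roughly the sketch below. It is assembled only from the hunk above; the Inference, Tokenizer, ChatFormat, and OpenAIImplConfig names come from the diff itself, and the comment about where the attributes get populated is an assumption, not something this commit shows.

class OpenAIInferenceAdapter(Inference):
    # Declared but no longer hard-coded here; presumably populated elsewhere
    # (for example by a subclass). That is an assumption: the commit itself
    # only removes the hard-coded test values.
    max_tokens: int
    model_id: str

    def __init__(self, config: OpenAIImplConfig) -> None:
        self.config = config
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)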
@@ -66,7 +63,7 @@ class OpenAIInferenceAdapter(Inference):
 
     def resolve_openai_model(self, model_name: str) -> str:
         # TODO: This should be overriden by other classes
-        return self.model_id
+        return model_name
 
     def get_openai_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
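With return model_name in place, the base adapter no longer pins the hard-coded model, and the TODO above signals that subclasses are meant to override the lookup. A hypothetical override might look like the following; the subclass name and the alias table are illustrative assumptions, not part of the repository.

class ExampleProviderAdapter(OpenAIInferenceAdapter):
    # Hypothetical alias table mapping requested model identifiers to the
    # names the backing provider expects; real values depend on the provider.
    MODEL_ALIASES = {
        "example-model-id": "provider-model-name",
    }

    def resolve_openai_model(self, model_name: str) -> str:
        # Fall back to the requested name, matching the base class behaviour.
        return self.MODEL_ALIASES.get(model_name, model_name)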
@@ -106,7 +103,7 @@ class OpenAIInferenceAdapter(Inference):
         model_input = self.formatter.encode_dialog_prompt(messages)
 
         input_tokens = len(model_input.tokens)
-        # TODO: There is a potential bug here
+        # TODO: There is a potential bug here to be investigated
         # max_new_tokens = min(
         #     request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
         #     self.max_tokens - input_tokens - 1,
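The amended TODO still does not say what the suspected bug is. Purely as an illustration, and not a claim about the actual issue, one failure mode this kind of budget formula can have in Python: "x or y" treats an explicit max_tokens of 0 as missing, and the result can go negative once the prompt fills the context window. The helper and numbers below are made up; only the formula and the 6144-token limit come from the diff.

# Standalone illustration of the commented-out formula above.
def candidate_max_new_tokens(requested, max_tokens, input_tokens):
    return min(requested or (max_tokens - input_tokens), max_tokens - input_tokens - 1)

print(candidate_max_new_tokens(None, 6144, 6000))  # 143: normal case
print(candidate_max_new_tokens(0, 6144, 6000))     # 143: an explicit 0 is silently ignored
print(candidate_max_new_tokens(256, 6144, 6150))   # -7: prompt exceeds the window, no error raised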