diff --git a/GPT/query.py b/GPT/query.py
index 6cf8fea..642d40d 100644
--- a/GPT/query.py
+++ b/GPT/query.py
@@ -65,12 +65,15 @@ def get_stream_prompt(query, prompt_file, isQuestion, info_file=None):
     openai.api_key = API_KEY
     if isQuestion:
         data = util.read_json(INFO.BRAIN_DATA)
-        result = GPT.gpt_tools.search_chunks(query, data, count=1)
-        my_info = util.read_file(info_file)
-        prompt = util.read_file(prompt_file)
-        prompt = prompt.replace('<>', result[0]['content'])
-        prompt = prompt.replace('<>', query)
-        prompt = prompt.replace('<>', my_info)
+        if data:
+            result = GPT.gpt_tools.search_chunks(query, data, count=1)
+            my_info = util.read_file(info_file)
+            prompt = util.read_file(prompt_file)
+            prompt = prompt.replace('<>', result[0]['content'])
+            prompt = prompt.replace('<>', query)
+            prompt = prompt.replace('<>', my_info)
+        else:
+            prompt = ''
     else:
         chunk = textwrap.wrap(query, 10000)[0]
         prompt = util.read_file(prompt_file).replace('<>', chunk)
diff --git a/Seanium_Brain.py b/Seanium_Brain.py
index e4d7605..a91ec4d 100644
--- a/Seanium_Brain.py
+++ b/Seanium_Brain.py
@@ -157,7 +157,12 @@ with body:
     else:
         max_model_token = 2048
 
-    st.markdown(f'Prompt token: `{st_tool.predict_token(query, prompt_core)}/{max_model_token}`')
+    tokens, isTokenZero = st_tool.predict_token(query, prompt_core)
+    token_panel = st.empty()
+    if isTokenZero:
+        token_panel.markdown('Prompt token: `Not Available`')
+    else:
+        token_panel.markdown(f'Prompt token: `{tokens}/{max_model_token}`')
 
     if send:
         st_tool.execute_brain(query, param,
diff --git a/streamlit_toolkit/tools.py b/streamlit_toolkit/tools.py
index 43d78cf..f34df7e 100644
--- a/streamlit_toolkit/tools.py
+++ b/streamlit_toolkit/tools.py
@@ -19,13 +19,14 @@ SESSION_TIME = st.session_state['SESSION_TIME']
 CURRENT_LOG_FILE = f'{INFO.LOG_PATH}/log_{SESSION_TIME}.log'
 
 
-def predict_token(query: str, prompt_core: GPT.model.prompt_core) -> int:
+def predict_token(query: str, prompt_core: GPT.model.prompt_core) -> (int, bool):
     """predict how many tokens to generate"""
     llm = OpenAI()
-    token = llm.get_num_tokens(GPT.query.get_stream_prompt(query, prompt_file=prompt_core.question,
-                                                           isQuestion=True,
-                                                           info_file=prompt_core.my_info))
-    return token
+    prompt = GPT.query.get_stream_prompt(query, prompt_file=prompt_core.question,
+                                         isQuestion=True,
+                                         info_file=prompt_core.my_info)
+    token = llm.get_num_tokens(prompt)
+    return token, token == 0
 
 
 def create_log():
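
Reviewer note: the change above makes predict_token return a (token_count, is_zero) pair instead of a bare int, so the UI can show "Not Available" when the brain data is empty and get_stream_prompt falls back to an empty prompt. A minimal self-contained sketch of that contract, not part of the patch, follows; the word-count tokenizer is a stand-in for OpenAI().get_num_tokens, and predict_token_sketch is a hypothetical name used only for illustration.

# Illustrative sketch only. fake_get_num_tokens stands in for the real
# OpenAI().get_num_tokens call used in streamlit_toolkit/tools.py.
def fake_get_num_tokens(prompt: str) -> int:
    return len(prompt.split())

def predict_token_sketch(prompt: str) -> tuple[int, bool]:
    token = fake_get_num_tokens(prompt)
    # An empty prompt (no brain data yet) counts as 0 tokens; the flag lets the
    # caller render "Not Available" instead of "0/<max_model_token>".
    return token, token == 0

tokens, is_token_zero = predict_token_sketch('')
print('Prompt token: Not Available' if is_token_zero else f'Prompt token: {tokens}')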