You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
talk-codebase/talk_codebase/config.py

163 lines
3.8 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import gpt4all
import openai
import questionary
import yaml
from talk_codebase.consts import MODEL_TYPES
# Location of the per-user YAML settings file: <home>/.talk_codebase_config.yaml
config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml")


def get_config():
    """Load the saved configuration from ``config_path``.

    Returns:
        The mapping parsed from the YAML file, or an empty dict when the file
        does not exist or holds no YAML document (empty / comments only).
    """
    if not os.path.exists(config_path):
        return {}
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load yields None for an empty document; callers use .get() and
    # item assignment on the result, so normalise that to an empty dict.
    return {} if config is None else config
def save_config(config):
    """Serialise ``config`` as YAML into ``config_path``, replacing its contents.

    The config can contain the OpenAI API key (``configure_api_key`` stores it
    here), so on POSIX systems a newly created file gets owner-only
    permissions (0o600, further reduced by the umask). An already existing
    file keeps whatever permissions it had.
    """
    def _owner_only(path, flags):
        # open() calls this to obtain the file descriptor; the mode argument
        # only takes effect when the file is being created.
        return os.open(path, flags, 0o600)

    with open(config_path, "w", encoding="utf-8", opener=_owner_only) as f:
        yaml.dump(config, f)
def api_key_is_invalid(api_key):
    """Report whether ``api_key`` is unusable.

    An empty/falsy key is invalid without any network access. Otherwise the
    key is assigned to ``openai.api_key`` (a global side effect) and probed
    with ``openai.Engine.list()``; any exception raised by that probe, not
    only an authentication failure, makes the key count as invalid.
    """
    key_works = False
    if api_key:
        try:
            openai.api_key = api_key
            openai.Engine.list()
            key_works = True
        except Exception:
            key_works = False
    return not key_works
def get_gpt_models(openai):
    """List the IDs of models whose ID contains the substring ``"gpt"``.

    ``openai`` is the client module (or any object exposing
    ``Model.list()`` returning a mapping with a ``"data"`` list of entries
    that have an ``"id"``). If listing fails for any reason, an error line
    is printed and an empty list is returned.
    """
    try:
        response = openai.Model.list()
    except Exception:
        print("✘ Failed to retrieve model list")
        return []
    gpt_ids = []
    for entry in response['data']:
        model_id = entry['id']
        if 'gpt' in model_id:
            gpt_ids.append(model_id)
    return gpt_ids
def configure_model_name_openai(config):
    """Ask the user to pick an OpenAI GPT model and persist the choice.

    Skipped when the model type is not OpenAI or an ``openai_model_name`` is
    already set. The stored ``api_key`` is installed on the ``openai`` module
    before the model list is fetched. If no GPT models are returned or the
    user makes no selection, a message is printed and nothing is saved.
    """
    if config.get("model_type") != MODEL_TYPES["OPENAI"]:
        return
    if config.get("openai_model_name"):
        return

    openai.api_key = config.get("api_key")
    available = get_gpt_models(openai)
    if not available:
        print(" No GPT models available")
        return

    picker = questionary.select(
        "🤖 Select model name:",
        [{"name": name, "value": name} for name in available],
    )
    selected = picker.ask()
    if not selected:
        print("✘ No model selected")
        return

    config["openai_model_name"] = selected
    save_config(config)
    print("🤖 Model name saved!")
def remove_model_name_openai():
    """Set the saved ``openai_model_name`` to None and write the config back.

    ``configure_model_name_openai`` treats a falsy name as "not configured",
    so it will prompt again the next time it runs.
    """
    settings = get_config()
    settings["openai_model_name"] = None
    save_config(settings)
def configure_model_name_local(config):
    """Ask the user to pick a local GPT4All model file and persist the choice.

    Skipped when the model type is not local or a ``local_model_name`` is
    already set. Candidates come from ``gpt4all.GPT4All.list_models()``; the
    chosen entry's ``filename`` is stored under ``local_model_name``.

    Like ``configure_model_name_openai``, nothing is saved when no models are
    listed or the prompt is cancelled (``.ask()`` returns None).
    """
    if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
        return
    list_models = gpt4all.GPT4All.list_models() or []

    def get_model_info(model):
        # One-line picker label: "name | filename | filesize | parameters | quant | type".
        # Missing catalogue fields are shown as "?" instead of raising KeyError.
        return " | ".join(
            str(model.get(field, "?"))
            for field in ("name", "filename", "filesize", "parameters", "quant", "type")
        )

    choices = [
        {"name": get_model_info(model), "value": model["filename"]} for model in list_models
    ]
    if not choices:
        # questionary.select cannot be shown with an empty choice list.
        print("✘ No local models available")
        return
    model_name = questionary.select("🤖 Select model name:", choices).ask()
    if not model_name:
        print("✘ No model selected")
        return
    config["local_model_name"] = model_name
    save_config(config)
    print("🤖 Model name saved!")
def remove_model_name_local():
    """Set the saved ``local_model_name`` to None and write the config back.

    ``configure_model_name_local`` treats a falsy name as "not configured",
    so it will prompt again the next time it runs.
    """
    stored = get_config()
    stored["local_model_name"] = None
    save_config(stored)
def get_and_validate_api_key():
    """Read OpenAI API keys from stdin until one passes ``api_key_is_invalid``.

    Each rejected key prints an error and triggers another prompt. Because
    ``api_key_is_invalid`` treats any probe exception as invalid, this keeps
    prompting while the API is unreachable.
    """
    prompt = "🤖 Enter your OpenAI API key: "
    while True:
        candidate = input(prompt)
        if not api_key_is_invalid(candidate):
            return candidate
        print("✘ Invalid API key")
def configure_api_key(config):
    """Make sure an OpenAI model setup has a working API key saved.

    Only acts when the model type is OpenAI. If the stored ``api_key`` fails
    validation, the user is prompted for a new one, which is stored in
    ``config`` and saved to disk.
    """
    uses_openai = config.get("model_type") == MODEL_TYPES["OPENAI"]
    if uses_openai and api_key_is_invalid(config.get("api_key")):
        config["api_key"] = get_and_validate_api_key()
        save_config(config)
def remove_api_key():
    """Set the saved ``api_key`` to None and write the config back."""
    current = get_config()
    current["api_key"] = None
    save_config(current)
def remove_model_type():
    """Set the saved ``model_type`` to None and write the config back.

    ``configure_model_type`` treats a falsy value as "not chosen", so it will
    prompt again the next time it runs.
    """
    loaded = get_config()
    loaded["model_type"] = None
    save_config(loaded)
def configure_model_type(config):
    """Prompt for the model backend (Local or OpenAI) if none is set yet.

    Whatever ``.ask()`` returns is stored under ``model_type`` and the config
    is saved.
    """
    if not config.get("model_type"):
        options = [
            ("Local", MODEL_TYPES["LOCAL"]),
            ("OpenAI", MODEL_TYPES["OPENAI"]),
        ]
        question = questionary.select(
            "🤖 Select model type:",
            choices=[{"name": label, "value": value} for label, value in options],
        )
        config["model_type"] = question.ask()
        save_config(config)
# Configuration steps, each taking the config dict. Order matters:
# the model type is chosen first because the other steps check it, and the
# API key precedes the OpenAI model picker, which reads config["api_key"].
CONFIGURE_STEPS = [configure_model_type, configure_api_key, configure_model_name_openai, configure_model_name_local]