|
|
|
@ -55,8 +55,8 @@ def train(accelerator, config):
|
|
|
|
|
|
|
|
|
|
with accelerator.main_process_first():
|
|
|
|
|
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint = config["gradient_checkpointing"]
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
|
|
|
|
use_cache=False if checkpoint else True,
|
|
|
|
@ -164,7 +164,7 @@ def train(accelerator, config):
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
unwrapped_model = accelerator.unwrap_model(model)
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
unwrapped_model.push_to_hub(config["save_name"], private=True)
|
|
|
|
|
unwrapped_model.push_to_hub(config["save_name"] + "_first_epoch", private=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|