|
|
|
@ -20,7 +20,7 @@ def main():
|
|
|
|
|
parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
|
|
|
|
|
parser.add_argument("--task", type=str, default="cls", help="Training task type")
|
|
|
|
|
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
|
|
|
|
|
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
|
|
|
|
|
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
|
|
|
|
|
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
|
|
|
|
|
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
|
|
|
|
|
parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
|
|
|
|
@ -36,14 +36,18 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
args.n_processes = int(args.n_processes)
|
|
|
|
|
|
|
|
|
|
processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
|
|
|
|
|
pipe_recv, pipe_send = mp.Pipe(duplex=False)
|
|
|
|
|
processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
|
|
|
|
|
for proc in processes:
|
|
|
|
|
proc.start()
|
|
|
|
|
for proc in processes:
|
|
|
|
|
proc.join()
|
|
|
|
|
|
|
|
|
|
fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
|
|
|
|
|
logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")
|
|
|
|
|
|
|
|
|
|
def benchmark_training(process_idx, args):
|
|
|
|
|
|
|
|
|
|
def benchmark_training(process_idx, args, result_pipe):
|
|
|
|
|
if args.task == "cls":
|
|
|
|
|
model = AutoDistributedModelForSequenceClassification.from_pretrained(
|
|
|
|
|
args.model,
|
|
|
|
@ -96,7 +100,7 @@ def benchmark_training(process_idx, args):
|
|
|
|
|
bwd_speed = input_ids.numel() / np.mean(bwd_times)
|
|
|
|
|
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
|
|
|
|
|
|
|
|
|
|
logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
|
|
|
|
|
result_pipe.send((fwd_speed, bwd_speed))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|