Distributed Setting

HetSeq can be executed on a single GPU on a single node, on multiple GPUs on a single node, or on multiple GPUs across multiple nodes. The main logic is defined in train.py.

Control Parameters

--distributed-init-method: defines how the distributed processes are initialized, for example:

  • tcp://<IP address>:<port> (TCP example for multiple nodes), or

  • file:///hetseq/communicate.txt (shared file example for multiple nodes).
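
Both forms are URLs. The snippet below is a minimal sketch of what each form carries. `parse_init_method` is a hypothetical helper and is not part of HetSeq. In a typical PyTorch job the string is passed unchanged to `torch.distributed.init_process_group`. The address 10.0.0.1 is a placeholder, not a value from this document.

```python
from urllib.parse import urlparse


def parse_init_method(init_method: str) -> dict:
    """Split a --distributed-init-method value into its parts.

    Hypothetical helper for illustration; it is not part of HetSeq.
    """
    parsed = urlparse(init_method)
    if parsed.scheme == "tcp":
        # tcp://<IP address>:<port>: all nodes must use the same address and port.
        if parsed.hostname is None or parsed.port is None:
            raise ValueError("a TCP init method needs both an IP address and a port")
        return {"scheme": "tcp", "host": parsed.hostname, "port": parsed.port}
    if parsed.scheme == "file":
        # file:///path: a file on a file system that every node can access.
        if not parsed.path:
            raise ValueError("a file init method needs a path")
        return {"scheme": "file", "path": parsed.path}
    raise ValueError(f"unsupported init method: {init_method!r}")


# 10.0.0.1 is a placeholder address, not a value from this document.
print(parse_init_method("tcp://10.0.0.1:11111"))
print(parse_init_method("file:///hetseq/communicate.txt"))
```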

--distributed-world-size: the total number of GPUs used for training, summed over all nodes.

--distributed-gpus: the number of GPUs used on the current node.

--distributed-rank: the global rank (index) of the first GPU used on the current node.
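
The relationship between these flags can be sketched with `argparse`. This is a hypothetical re-creation for illustration only. The defaults and the helper `local_ranks` are assumptions, not HetSeq's own definitions. The sketch assumes, as the multi-node examples suggest, that a node with `--distributed-rank r` and `--distributed-gpus g` owns global ranks r through r + g - 1.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical re-creation of the control flags; defaults are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed-init-method", type=str, default=None)
    parser.add_argument("--distributed-world-size", type=int, default=1)
    parser.add_argument("--distributed-gpus", type=int, default=1)
    parser.add_argument("--distributed-rank", type=int, default=0)
    return parser


def local_ranks(args: argparse.Namespace) -> List[int]:
    """Global ranks owned by this node: one per local GPU, starting at --distributed-rank."""
    first = args.distributed_rank
    last = first + args.distributed_gpus  # exclusive upper bound
    if last > args.distributed_world_size:
        raise ValueError("this node's ranks exceed --distributed-world-size")
    return list(range(first, last))


# Flags of the 3rd node in the four-node example below.
args = build_parser().parse_args(
    ["--distributed-world-size", "16", "--distributed-gpus", "4", "--distributed-rank", "8"]
)
print(local_ranks(args))  # [8, 9, 10, 11]
```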

Different Distributed Settings

1. Single GPU:

$ --distributed-world-size 1 --device-id 1

2. Four GPUs on a single node:

$ --distributed-world-size 4

3. Four nodes with four GPUs each (16 GPUs in total). In the commands below, the tcp:// URL points to the IP address of the first node, and 11111 is the port number:

  • 1st node

$ --distributed-init-method tcp:// --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 0
  • 2nd node

$ --distributed-init-method tcp:// --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 4
  • 3rd node

$ --distributed-init-method tcp:// --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 8
  • 4th node

$ --distributed-init-method tcp:// --distributed-world-size 16 --distributed-gpus 4 --distributed-rank 12
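
The pattern in these commands is mechanical. Every node passes the same init method and world size. Each node's --distributed-rank equals the number of GPUs on the nodes before it (0, 4, 8 and 12 here). The sketch below generates the flags. `node_flags` is a hypothetical helper, and 10.0.0.1 is a placeholder address. The sketch also accepts nodes with different GPU counts, which assumes the same rank scheme carries over to that case; this section does not state that explicitly.

```python
from typing import List


def node_flags(init_method: str, gpus_per_node: List[int]) -> List[str]:
    """Build the HetSeq flags for every node of a multi-node job.

    Hypothetical helper. Node k starts at the running total of GPUs on
    nodes 0..k-1, so nodes may have different GPU counts.
    """
    world_size = sum(gpus_per_node)
    flags = []
    start_rank = 0
    for gpus in gpus_per_node:
        flags.append(
            f"--distributed-init-method {init_method} "
            f"--distributed-world-size {world_size} "
            f"--distributed-gpus {gpus} "
            f"--distributed-rank {start_rank}"
        )
        start_rank += gpus  # the next node starts after this node's GPUs
    return flags


# 10.0.0.1 is a placeholder for the first node's IP address.
for line in node_flags("tcp://10.0.0.1:11111", [4, 4, 4, 4]):
    print("$", line)
```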

Main Logic

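import random
import torch

if args.distributed_init_method is not None:
    assert args.distributed_gpus <= torch.cuda.device_count()

    if args.distributed_gpus > 1 and not args.distributed_no_spawn:
        # Launch one process per local GPU; start_rank is the original --distributed-rank.
        start_rank = args.distributed_rank
        args.distributed_rank = None  # assign automatically
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, start_rank),
            nprocs=args.distributed_gpus,
        )
    else:
        # Single process (one GPU, or spawning disabled).
        distributed_main(args.device_id, args)

elif args.distributed_world_size > 1:
    # Single node, multiple GPUs: rendezvous on a random local port.
    assert args.distributed_world_size <= torch.cuda.device_count()
    port = random.randint(10000, 20000)
    args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
    args.distributed_rank = None  # set based on device id
    torch.multiprocessing.spawn(
        fn=distributed_main,
        args=(args, ),
        nprocs=args.distributed_world_size,
    )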
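
You can mirror this branching in plain Python to see which processes each setting launches, without a GPU. The model below is hypothetical and is not HetSeq code. PyTorch's `spawn` passes each process its local index `i`. The model assumes that a process spawned in the multi-node branch then takes rank `start_rank + i`. It also models the single-node fallback's rank ("set based on device id") as `i`. The final non-distributed case is not shown in the excerpt and is likewise an assumption.

```python
from types import SimpleNamespace
from typing import List, Optional, Tuple


def plan_processes(args, device_count: int) -> Tuple[str, List[Optional[int]]]:
    """Return the launch mode and the global rank of each local process.

    Hypothetical model of the branching in train.py; no processes are started.
    """
    if args.distributed_init_method is not None:
        assert args.distributed_gpus <= device_count
        if args.distributed_gpus > 1 and not args.distributed_no_spawn:
            # Assumed: spawned process i takes rank start_rank + i.
            start_rank = args.distributed_rank
            return "spawn", [start_rank + i for i in range(args.distributed_gpus)]
        # One process on args.device_id, keeping the rank from the command line.
        return "single-process", [args.distributed_rank]
    if args.distributed_world_size > 1:
        assert args.distributed_world_size <= device_count
        # Assumed: rank equals the local device index on a single node.
        return "spawn-localhost", list(range(args.distributed_world_size))
    # Assumed: plain single-GPU training without a process group.
    return "non-distributed", [None]


# 2nd node of the four-node example: --distributed-rank 4, four local GPUs.
node2 = SimpleNamespace(
    distributed_init_method="tcp://10.0.0.1:11111",  # placeholder address
    distributed_world_size=16,
    distributed_gpus=4,
    distributed_rank=4,
    distributed_no_spawn=False,
)
print(plan_processes(node2, device_count=4))  # ('spawn', [4, 5, 6, 7])
```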