MPI Message Passing Interface Example
This is a version of the ra25 example that uses MPI to distribute computation across multiple processors (procs). See Wiki MPI for more info.
N completely separate instances of the same simulation program run in parallel and communicate weight changes and trial-level log data amongst themselves. Each proc trains on its own subset of the training patterns in each epoch, so dividing the patterns across procs is the most difficult part of making this work. Synchronizing the weight changes and etable data takes just a few simple method calls.
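The weight-sharing step can be pictured with a small in-process sketch. It assumes a plain `wts += dwt` update, which is a simplification of what WtFmDWt actually does. The hypothetical allReduceSum stands in for the sum all-reduce that MPIWtFmDWt performs with Comm.AllReduceF32 and mpi.OpSum, and dividing by the number of procs presumably corresponds to passing mpi.WorldSize() to SetDWts. No MPI library is used here.

```go
package main

import "fmt"

// allReduceSum is a hypothetical stand-in for an MPI sum all-reduce: every
// proc ends up with the element-wise sum of all procs' weight-change vectors.
// Here the "procs" are just slices in one process.
func allReduceSum(local [][]float32) []float32 {
	sum := make([]float32, len(local[0]))
	for _, dw := range local {
		for i, v := range dw {
			sum[i] += v
		}
	}
	return sum
}

// applyDWts adds the summed changes divided by nprocs to the weights, so each
// proc applies the average of all procs' weight changes (simplified update).
func applyDWts(wts, sum []float32, nprocs int) {
	for i := range wts {
		wts[i] += sum[i] / float32(nprocs)
	}
}

func main() {
	// Hypothetical weight changes for a 3-synapse network on 2 procs.
	local := [][]float32{
		{0.2, -0.4, 0.0},
		{0.4, 0.0, 0.6},
	}
	sum := allReduceSum(local)
	// Every proc starts from the same weights and applies the same summed
	// changes, so all copies of the network remain identical.
	for rank := range local {
		wts := []float32{0.5, 0.5, 0.5}
		applyDWts(wts, sum, len(local))
		fmt.Printf("rank %d weights: %v\n", rank, wts)
	}
}
```

Because every proc receives the same sum, no proc ever needs to send its weights, only its changes.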
Speedups approach linear because synchronization is relatively infrequent compared to the computation, especially for larger networks, which do more computation per trial. The biggest cost in MPI is latency: sending one huge list of weight changes infrequently is much faster overall than sending smaller amounts of data more frequently.
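To see why latency dominates, here is a back-of-the-envelope sketch using the common latency-bandwidth cost model: each message pays a fixed latency, plus its size divided by the bandwidth. The latency and bandwidth figures are assumptions chosen for illustration, not measurements, and commTime is a hypothetical helper.

```go
package main

import "fmt"

// commTime estimates transfer time in seconds under a simple
// latency-bandwidth model: every message pays a fixed latency, and every
// byte pays 1/bytesPerSec. Real MPI collectives are more complex than this.
func commTime(nmsgs, totalBytes int, latency, bytesPerSec float64) float64 {
	return float64(nmsgs)*latency + float64(totalBytes)/bytesPerSec
}

func main() {
	const (
		latency     = 10e-6       // assumed: 10 microseconds per message
		bytesPerSec = 1e9         // assumed: 1 GB/s link
		nBytes      = 4 * 1000000 // one million float32 weight changes
	)
	// Same total data, sent as one big message vs. 1000 small ones.
	fmt.Printf("1 message:     %.4f s\n", commTime(1, nBytes, latency, bytesPerSec))
	fmt.Printf("1000 messages: %.4f s\n", commTime(1000, nBytes, latency, bytesPerSec))
}
```

With these assumed numbers, splitting the same data into 1000 messages adds roughly 10 ms of pure latency (999 extra messages at 10 µs each), while the bandwidth term stays the same.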
You can only use MPI when running in nogui mode with command-line args; otherwise you'd get multiple copies of the GUI running.
Building and running
To build with actual mpi support, you must do:
$ go build -tags mpi
Otherwise, it builds with a dummy version of mpi that doesn't actually do anything, which is convenient for supporting both MPI and non-MPI builds in one codebase. Always ensure that your code does something reasonable when mpi.WorldSize() == 1, since that is what the dummy code returns. You should also use a UseMPI flag, set by the -mpi command-line arg, to do different things depending on whether MPI is in use: e.g., don't try to aggregate DWts when not using MPI, as it wastes a lot of time and accomplishes nothing.
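A minimal sketch of that guard, assuming worldSize is a hypothetical stand-in for mpi.WorldSize() in a build without `-tags mpi`, where the dummy reports a single proc. The needSync name is also hypothetical.

```go
package main

import "fmt"

// worldSize is a hypothetical stand-in for mpi.WorldSize() in a dummy
// (non-mpi-tagged) build, which reports a single proc.
func worldSize() int { return 1 }

// needSync reports whether weight changes must be summed across procs.
// With -mpi off, or with only one proc, the local DWts are already complete,
// so an all-reduce would cost time and change nothing.
func needSync(useMPI bool, nprocs int) bool {
	return useMPI && nprocs > 1
}

func main() {
	useMPI := true // e.g., -mpi was passed, but the binary was built without -tags mpi
	if needSync(useMPI, worldSize()) {
		fmt.Println("sum DWts across procs, then update weights")
	} else {
		fmt.Println("update weights from local DWts")
	}
}
```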
To run, do something like this:
$ mpirun -np 2 ./mpi -mpi
For this example, the number of processors must evenly divide 24 (the number of patterns used in ra25), e.g., 2, 3, 4, 6, or 8.
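A quick way to list every valid proc count for a given number of patterns (validProcCounts is a hypothetical helper, not part of the example):

```go
package main

import "fmt"

// validProcCounts returns every proc count that splits npats patterns
// evenly, so each proc gets exactly npats/n patterns per epoch.
func validProcCounts(npats int) []int {
	var ns []int
	for n := 1; n <= npats; n++ {
		if npats%n == 0 {
			ns = append(ns, n)
		}
	}
	return ns
}

func main() {
	fmt.Println(validProcCounts(24)) // [1 2 3 4 6 8 12 24]
}
```

For 24 patterns this also admits 12 and 24 procs, although with 24 each proc would train on a single pattern per epoch.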
General tips for MPI usage
- MOST IMPORTANT: all procs must remain completely synchronized in terms of when they call MPI functions, because these functions block until all procs have called the same function. The default behavior of setting a saved random number seed for all procs should ensure this, but you also need to make sure that the same random permutation of item lists, etc., takes place across all nodes. The empi.FixedTable environment does this for the case of a table with a set of patterns.
- Instead of aggregating epoch-level stats directly on the Sim, as the basic ra25 example does, you need to record trial-level data in an etable (TrnTrlLog), synchronize it across all procs at the end of the epoch, and compute the aggregate stats from that data. This is how the testing trial -> epoch process already works in ra25.
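Keeping procs in lockstep can be sketched as follows, assuming each rank owns a contiguous block of patterns (as allocated in ConfigEnv) and every rank shuffles its block with an identically seeded RNG. The localOrder helper is hypothetical and not taken from empi.FixedTable; it only illustrates the property that matters: every rank draws the same random numbers in the same order.

```go
package main

import (
	"fmt"
	"math/rand"
)

// localOrder returns the global pattern indexes that rank visits this epoch,
// in shuffled order. Each rank owns a contiguous block of npats/nprocs
// patterns (assumes an even split). All ranks shuffle with the same seed, so
// they draw identical random numbers and stay in lockstep.
func localOrder(seed int64, npats, rank, nprocs int) []int {
	per := npats / nprocs
	rnd := rand.New(rand.NewSource(seed))
	perm := rnd.Perm(per) // same permutation of local positions on every rank
	order := make([]int, per)
	for i, p := range perm {
		order[i] = rank*per + p
	}
	return order
}

func main() {
	// Hypothetical setup: 24 patterns on 4 procs, shared seed 42.
	for rank := 0; rank < 4; rank++ {
		fmt.Printf("rank %d visits %v\n", rank, localOrder(42, 24, rank, 4))
	}
}
```

If one rank used a different seed, or consumed an extra random number, it would still run, but its sequence of MPI calls could drift out of step with the others.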
Key Diffs from ra25
Here are the main diffs that transform the ra25.go example into this mpi version:
- Search for mpi in the code (case insensitive); most of the changes have that in or near them.
- Most of the changes are at the bottom of the file.
main() Config() call
At the top of the file, it can be important to configure TheSim after MPI has been initialized, if Config does anything differently under MPI. Thus, you should move the TheSim.Config() call into CmdArgs:
func main() {
	TheSim.New() // note: not running Config here -- done in CmdArgs for mpi / nogui
	if len(os.Args) > 1 {
		TheSim.CmdArgs() // simple assumption is that any args = no gui -- could add explicit arg if you want
	} else {
		TheSim.Config() // for GUI case, config then run.
		gimain.Main(func() { // this starts gui -- requires valid OpenGL display connection (e.g., X11)
			guirun()
		})
	}
}
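The required ordering can be sketched with a toy type that records which setup steps run, assuming the -mpi flag described above. The sim type and its methods are hypothetical stand-ins for Sim, MPIInit, Config, and CmdArgs.

```go
package main

import (
	"flag"
	"fmt"
)

// sim is a hypothetical stand-in for the Sim type; it records the order in
// which setup steps run so the ordering can be checked.
type sim struct {
	useMPI bool
	steps  []string
}

func (s *sim) mpiInit() { s.steps = append(s.steps, "MPIInit") }
func (s *sim) config()  { s.steps = append(s.steps, "Config") }

// cmdArgs parses the command line first, initializes MPI only if -mpi was
// given, and calls config last, so config can rely on MPI already being set up.
func (s *sim) cmdArgs(args []string) error {
	fs := flag.NewFlagSet("sim", flag.ContinueOnError)
	fs.BoolVar(&s.useMPI, "mpi", false, "use MPI to distribute computation")
	if err := fs.Parse(args); err != nil {
		return err
	}
	if s.useMPI {
		s.mpiInit()
	}
	s.config()
	return nil
}

func main() {
	s := &sim{}
	if err := s.cmdArgs([]string{"-mpi"}); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(s.steps) // [MPIInit Config]
}
```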
Sim struct
A few other fields are added, but they just follow the existing pattern. These are the uniquely MPI-specific ones, at the end of the Sim struct type:
	UseMPI  bool      `view:"-" desc:"if true, use MPI to distribute computation across nodes"`
	Comm    *mpi.Comm `view:"-" desc:"mpi communicator"`
	AllDWts []float32 `view:"-" desc:"buffer of all dwt weight changes -- for mpi sharing"`
	SumDWts []float32 `view:"-" desc:"buffer of MPI summed dwt weight changes"`
ThetaCyc
Now call the MPI version of WtFmDWt, which sums weight changes across procs:
if train {
	ss.MPIWtFmDWt() // special MPI version
}
Allocating Patterns Across Nodes
In ConfigEnv, non-overlapping subsets of input patterns are allocated to different nodes, so that across all nodes each epoch still covers the same full set of input patterns as a single-processor run:
ss.TrainEnv.Table = etable.NewIdxView(ss.Pats)
if ss.UseMPI {
	st, ed, _ := empi.AllocN(ss.Pats.Rows)
	ss.TrainEnv.Table.Idxs = ss.TrainEnv.Table.Idxs[st:ed]
}
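The split that empi.AllocN computes is presumably an equal, contiguous partition along these lines. Here allocN is a hypothetical re-implementation for illustration, not the library's code, and it takes the rank and proc count as explicit arguments rather than querying MPI.

```go
package main

import "fmt"

// allocN splits n items into equal, contiguous, non-overlapping [st, ed)
// ranges, one per rank. It reports an error when nprocs does not divide n
// evenly, in which case some items would be dropped.
func allocN(n, rank, nprocs int) (st, ed int, err error) {
	if n%nprocs != 0 {
		err = fmt.Errorf("allocN: %d items not evenly divisible by %d procs", n, nprocs)
	}
	per := n / nprocs
	st = rank * per
	ed = st + per
	return st, ed, err
}

func main() {
	for rank := 0; rank < 4; rank++ {
		st, ed, _ := allocN(24, rank, 4)
		fmt.Printf("rank %d: patterns [%d, %d)\n", rank, st, ed)
	}
}
```

Each rank then keeps only its own slice of the index view, as the ConfigEnv code does with `Idxs[st:ed]`.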
Logging
The elog system supports gathering all of the rows of the Trial-level logs from each of the processors into a combined table, which is then used for aggregating stats at the Epoch level. To let all the standard infrastructure work the same way as in the non-MPI case, the gathered table is set as the Trial log. This means that after aggregating the Trial data at the Epoch level, we need to reset the number of rows back to the original number present on each processor; otherwise the table grows exponentially, multiplying by the number of procs every epoch! If the Trial data is always accumulated by adding rows and resetting back to 0 at the end of the epoch, then you would just do that as usual.
Here's the relevant code in the Log() method:
var ntrow int
if ss.UseMPI && time == elog.Epoch { // Must gather data for trial level if doing epoch level
	ntrow = ss.Logs.Table(mode, elog.Trial).Rows // remember per-proc row count before gathering
	ss.Logs.MPIGatherTableRows(mode, elog.Trial, ss.Comm)
}
...
if ss.UseMPI && time == elog.Epoch { // Must reset rows back to original number pre-gather!
	dt := ss.Logs.Table(mode, elog.Trial)
	dt.SetNumRows(ntrow)
}
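A toy row-count model shows why the reset matters. It assumes each proc's gathered table replaces its Trial log every epoch and that trials overwrite rows by index. The rowsAfterEpochs helper is hypothetical and tracks only row counts, not log contents.

```go
package main

import "fmt"

// rowsAfterEpochs simulates the Trial-log row count on one proc when the
// gathered table replaces the local one each epoch. With reset=true the
// count is cut back to ntrow after epoch-level aggregation.
func rowsAfterEpochs(ntrow, nprocs, epochs int, reset bool) int {
	rows := ntrow
	for e := 0; e < epochs; e++ {
		gathered := rows * nprocs // every proc contributes its current rows
		// ... epoch-level stats would be computed over the gathered rows here ...
		if reset {
			rows = ntrow
		} else {
			rows = gathered
		}
	}
	return rows
}

func main() {
	// Hypothetical: 6 trials per proc, 4 procs, 3 epochs.
	fmt.Println("with reset:   ", rowsAfterEpochs(6, 4, 3, true))  // 6
	fmt.Println("without reset:", rowsAfterEpochs(6, 4, 3, false)) // 384
}
```

Without the reset, the row count is multiplied by the number of procs every epoch (6, 24, 96, 384, ...).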
LogFileName
The new version adds the rank of the processor to the file name when it is not the root (rank 0). It is sometimes useful, for debugging or for full stats, to record a log for each processor instead of just the root one, which is the default.
// LogFileName returns default log file name
func (ss *Sim) LogFileName(lognm string) string {
	nm := ss.Net.Nm + "_" + ss.RunName() + "_" + lognm
	if mpi.WorldRank() > 0 {
		nm += fmt.Sprintf("_%d", mpi.WorldRank())
	}
	nm += ".csv"
	return nm
}
CmdArgs
At the end, CmdArgs sets the UseMPI flag based on the -mpi arg, and it contains quite a bit of MPI-specific logic that we don't reproduce here; see the ra25.go code and look for mpi.
We use mpi.Printf instead of fmt.Printf so that messages are printed only on the root node, and you don't get a bunch of duplicated messages.
MPI Code
The main MPI-specific code is at the end of the file, reproduced here for easy reference. NOTE: always use the code in ra25.go as the copy-paste source, as it may contain small changes that are tracked more closely there than here.
// MPIInit initializes MPI
func (ss *Sim) MPIInit() {
	mpi.Init()
	var err error
	ss.Comm, err = mpi.NewComm(nil) // use all procs
	if err != nil {
		log.Println(err)
		ss.UseMPI = false
	} else {
		mpi.Printf("MPI running on %d procs\n", mpi.WorldSize())
	}
}
// MPIFinalize finalizes MPI
func (ss *Sim) MPIFinalize() {
	if ss.UseMPI {
		mpi.Finalize()
	}
}
// CollectDWts collects the weight changes from all synapses into AllDWts
// includes all other long adapting factors too: DTrgAvg, ActAvg, etc
func (ss *Sim) CollectDWts(net *axon.Network) {
	net.CollectDWts(&ss.AllDWts)
}
// MPIWtFmDWt updates weights from weight changes, using MPI to integrate
// DWt changes across parallel nodes, each of which are learning on different
// sequences of inputs.
func (ss *Sim) MPIWtFmDWt() {
	if ss.UseMPI {
		ss.CollectDWts(ss.Net)
		ndw := len(ss.AllDWts)
		if len(ss.SumDWts) != ndw {
			ss.SumDWts = make([]float32, ndw)
		}
		ss.Comm.AllReduceF32(mpi.OpSum, ss.SumDWts, ss.AllDWts)
		ss.Net.SetDWts(ss.SumDWts, mpi.WorldSize())
	}
	ss.Net.WtFmDWt(&ss.Time)
}