Simple example to demonstrate parameter server training pattern #705
Conversation
| # --------- Helper Methods --------------------
|
| # On the local node, call a method with first arg as the value held by the RRef. Other args are passed in as arguments to the function called.
shall we break long comments into multiple lines?
| return method(rref.local_value(), *args, **kwargs)
|
| # Syncrhnous RPC to run a method remotely and get a result. The method should be a class method corresponding to
| # Given an RRef, return the result of calling the passed in method on the value held by the RRef. This call is done on the remote node that owns the RRef. args and kwargs are passed into the method.
ditto
| def call_method(method, rref, *args, **kwargs):
|     return method(rref.local_value(), *args, **kwargs)
|
| # Syncrhnous RPC to run a method remotely and get a result. The method should be a class method corresponding to
Syncrhnous -> Synchronous
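For context, a minimal sketch of the two helpers this thread is discussing, assuming the example's `call_method`/`remote_method` structure (the comment wording itself is what is under review above):

```python
import torch.distributed.rpc as rpc

# On the local node, call `method` with the value held by the RRef as its
# first argument; remaining args/kwargs are forwarded to the method.
def call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

# Synchronous RPC helper: run `method` on the node that owns `rref`
# and return the result.
def remote_method(method, rref, *args, **kwargs):
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
```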
| print(loss)
| dist_autograd.backward([loss])
| param_rrefs = net.get_global_param_rrefs()
| opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
Do we have to create an opt per iteration? Can this be done before the for loop?
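A rough sketch of what that suggestion would look like, assuming the example's `net` and `train_loader` (the context-id-free `backward`/`step` calls match the API the quoted hunk uses):

```python
import torch.distributed.autograd as dist_autograd
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer

# Build the distributed optimizer once, before the training loop.
param_rrefs = net.get_global_param_rrefs()
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)

for data, target in train_loader:
    with dist_autograd.context() as cid:
        loss = F.nll_loss(net(data), target)
        # Newer releases require the context id here, i.e.
        # dist_autograd.backward(cid, [loss]) and opt.step(cid).
        dist_autograd.backward([loss])
        opt.step()
```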
mrshenli left a comment:
Thanks a lot for putting together this example!
| ])),
| batch_size=32, shuffle=True, )
| processes = []
| # Run num_trainers workers, plus 1 for the parameter serever.
serever -> server
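For reference, a sketch of the launch loop this hunk sits in, under the assumption that the example exposes `run_parameter_server` and `run_worker` entry points (these names are illustrative, not confirmed by the diff):

```python
import torch.multiprocessing as mp

world_size = num_trainers + 1  # num_trainers workers, plus 1 for the parameter server
processes = []

# Rank 0 hosts the parameter server.
server_proc = mp.Process(target=run_parameter_server, args=(0, world_size))
server_proc.start()
processes.append(server_proc)

# Ranks 1..num_trainers are trainers.
for rank in range(1, world_size):
    p = mp.Process(target=run_worker, args=(rank, world_size))
    p.start()
    processes.append(p)

for p in processes:
    p.join()
```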
| dist_autograd.backward([loss])
| param_rrefs = net.get_global_param_rrefs()
| opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
| opt.step()
This is hogwild, right?
Yes, should we use locks?
hogwild is good, as long as we clearly state it. :)
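To make the Hogwild! point concrete, a hypothetical sketch of the server side: every trainer holds RRefs to the same parameters and steps its own DistributedOptimizer without any locking, so updates interleave freely (the `Linear` layer is a stand-in for the example's actual network):

```python
import torch.nn as nn
from torch.distributed.rpc import RRef

class ParameterServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(784, 10)  # stand-in for the example's Net

    def forward(self, x):
        return self.model(x)

    def get_param_rrefs(self):
        # Each trainer builds its DistributedOptimizer from these RRefs;
        # concurrent opt.step() calls are not serialized, i.e. Hogwild!-style.
        return [RRef(p) for p in self.model.parameters()]
```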
mrshenli left a comment:
LGTM! Just need a few more minor edits before landing I think. Thanks @rohan-varma !
| "world_size", | ||
| type=int, | ||
| default=4, | ||
| help="Total number of participating processes. Should be the sum of master node and all training nodes, add 1 if creating training node on master.") |
Shall we break long lines into shorter ones? There are a few more below.
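One way to break that help string up, shown as a standalone sketch (the surrounding `ArgumentParser` setup is assumed from the example):

```python
import argparse

parser = argparse.ArgumentParser(description="Parameter-server RPC example")
parser.add_argument(
    "world_size",
    type=int,
    default=4,
    help="Total number of participating processes. Should be the sum of the "
         "master node and all training nodes; add 1 if creating a training "
         "node on the master.",
)
```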
| x = torch.flatten(x, 1)
| # need to put this on CUDA
| next_device = next(self.fc1.parameters()).device
| # print("In forward, changing device to {}".format(str(next_device)))
Is this a vestige of past debugging code?
| return method(rref.local_value(), *args, **kwargs)
|
| # Synchronous RPC to run a method remotely and get a result.
| # The method should be a class method corresponding to Given an RRef,
There seems to be some missing text between "corresponding to" and "Given".
Or
Given an -> a given, and then fix the following clause?
I'll remove this portion and start with "Given an RRef" since the former point is covered below.
| # construct it once
| param_server = ParameterServer(num_gpus=num_gpus)
| print(
|     "Returning parameter server with ID {}".format(
looks like we can fit this in one line with f"Returning parameter server with ID {id(param_server)}"?
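For context, a sketch of the lock-guarded "construct it once" pattern this hunk belongs to, with the suggested f-string applied (the `ParameterServer` class and `num_gpus` argument are assumed from the example):

```python
import threading

param_server = None
global_lock = threading.Lock()

def get_parameter_server(num_gpus=0):
    global param_server
    with global_lock:
        if param_server is None:
            # construct it once, so every trainer's RPC sees the same object
            param_server = ParameterServer(num_gpus=num_gpus)
            print(f"Returning parameter server with ID {id(param_server)}")
        return param_server
```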
@mrshenli Updated to address all comments, and also to explicitly move tensors in and out of GPU so that they work with the latest RPC, which disallows sending CUDA tensors. Could you take another look? Thanks!
mrshenli left a comment:
LGTM! @jlin27 shall we land this example?
Simple example to demonstrate parameter server training pattern

torch.distributed.rpc now enables parameter-server-style training in PyTorch with RPC-based APIs. This PR adds a simple parameter-server training example that launches several trainers and a single parameter server (PS) to train a single model on MNIST.
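A minimal sketch of the per-process RPC setup behind that description, assuming one parameter-server rank and several trainer ranks (the worker names and the training-loop call are illustrative):

```python
import os
import torch.distributed.rpc as rpc

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    if rank == 0:
        # Rank 0 hosts the parameter server and simply waits to serve RPCs.
        rpc.init_rpc("parameter_server", rank=rank, world_size=world_size)
    else:
        # Remaining ranks are trainers that reach the server via RPC/RRefs.
        rpc.init_rpc(f"trainer_{rank}", rank=rank, world_size=world_size)
        # run_training_loop(rank)  # illustrative; see the example's code
    rpc.shutdown()  # blocks until all RPC work across processes has finished
```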