Conversation

@rohan-varma (Contributor) commented Feb 7, 2020

torch.distributed.rpc now enables parameter-server-style training in PyTorch via RPC-based APIs. This PR adds a simple parameter-server training example that launches several trainers and a single parameter server (PS) to train a single model on MNIST.
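For context, the launch pattern behind this example is roughly the following sketch (not the PR's exact code; `run_training_loop` is a hypothetical placeholder): rank 0 hosts the parameter server and every other rank runs a trainer.

```python
import os
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    if rank == 0:
        # The parameter server does no active work of its own; it just serves
        # RPCs from the trainers until shutdown() unblocks.
        rpc.init_rpc("parameter_server", rank=rank, world_size=world_size)
    else:
        rpc.init_rpc(f"trainer_{rank}", rank=rank, world_size=world_size)
        # run_training_loop(rank) would go here (hypothetical trainer entry point).
    rpc.shutdown()  # blocks until every RPC participant has finished

if __name__ == "__main__":
    world_size = 4  # 1 parameter server + 3 trainers
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
```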


# --------- Helper Methods --------------------

# On the local node, call a method with first arg as the value held by the RRef. Other args are passed in as arguments to the function called.
Contributor:

shall we break long comments into multiple lines?

return method(rref.local_value(), *args, **kwargs)

# Syncrhnous RPC to run a method remotely and get a result. The method should be a class method corresponding to
# Given an RRef, return the result of calling the passed in method on the value held by the RRef. This call is done on the remote node that owns the RRef. args and kwargs are passed into the method.
Contributor:

ditto

def call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

# Syncrhnous RPC to run a method remotely and get a result. The method should be a class method corresponding to
Contributor:

Syncrhnous -> Synchronous

print(loss)
dist_autograd.backward([loss])
param_rrefs = net.get_global_param_rrefs()
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
Contributor:

Do we have to create an opt per iteration? Can this be done before the for loop?
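A sketch of what hoisting the optimizer out of the loop might look like (illustrative, not the final code; `net`, `train_loader`, and `get_global_param_rrefs` are the PR's names, and newer PyTorch releases additionally require passing the distributed autograd context id to `backward()` and `step()`):

```python
import torch.nn.functional as F
import torch.distributed.autograd as dist_autograd
from torch import optim
from torch.distributed.optim import DistributedOptimizer

# Build the DistributedOptimizer once, before the training loop, and reuse it.
param_rrefs = net.get_global_param_rrefs()
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)

for data, target in train_loader:
    # Each iteration still needs its own distributed autograd context.
    with dist_autograd.context() as context_id:
        loss = F.nll_loss(net(data), target)
        dist_autograd.backward([loss])  # newer releases: dist_autograd.backward(context_id, [loss])
        opt.step()                      # newer releases: opt.step(context_id)
```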

@mrshenli (Contributor) left a comment:

Thanks a lot for putting together this example!

])),
batch_size=32, shuffle=True, )
processes = []
# Run num_trainers workers, plus 1 for the parameter serever.
Contributor:

serever -> server

dist_autograd.backward([loss])
param_rrefs = net.get_global_param_rrefs()
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
opt.step()
Contributor:

This is hogwild, right?

Contributor Author:

Yes, should we use locks?

Contributor:

hogwild is good, as long as we clearly state it. :)
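To spell out what Hogwild means here: the trainers apply their SGD updates to the same parameters hosted on the parameter server without any locking, accepting occasionally overlapping writes in exchange for throughput. A toy illustration of that scheme (the threads and tensors below are stand-ins, not the example's RPC machinery):

```python
import threading
import torch

shared_w = torch.zeros(10)  # stands in for the PS-hosted parameters

def trainer(steps=100, lr=0.03):
    for _ in range(steps):
        grad = torch.randn(10)    # stands in for a gradient from dist_autograd
        shared_w.sub_(lr * grad)  # unsynchronized in-place update: Hogwild-style

threads = [threading.Thread(target=trainer) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(shared_w)
```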

@mrshenli (Contributor) left a comment:

LGTM! Just need a few more minor edits before landing I think. Thanks @rohan-varma !

"world_size",
type=int,
default=4,
help="Total number of participating processes. Should be the sum of master node and all training nodes, add 1 if creating training node on master.")
Contributor:

Shall we break long lines into shorter ones? There are a few more below.
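One way to take this suggestion for the argument above (illustrative; assumes the surrounding argparse parser is named `parser`) is to rely on implicit string concatenation:

```python
parser.add_argument(
    "world_size",
    type=int,
    default=4,
    help="Total number of participating processes. Should be the sum of the "
         "master node and all training nodes; add 1 if a training node also "
         "runs on the master.")
```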

x = torch.flatten(x, 1)
# need to put this on CUDA
next_device = next(self.fc1.parameters()).device
# print("In forward, changing device to {}".format(str(next_device)))
Contributor:

This is a vestige of past debugging code?

return method(rref.local_value(), *args, **kwargs)

# Synchronous RPC to run a method remotely and get a result.
# The method should be a class method corresponding to Given an RRef,
Contributor:

There seems to be some missing text between "corresponding to" and "Given".

Or

Given an -> a given, and then fix the following clause?

Contributor Author:

I'll remove this portion and start with "Given an RRef" since the former point is covered below.
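Roughly what the revised helpers end up looking like after dropping the duplicated fragment (a sketch following the example's pattern; the final comment wording may differ):

```python
import torch.distributed.rpc as rpc

# On the local node, call a method with the first arg being the value held by
# the RRef; the remaining args are passed through to the method.
def call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

# Given an RRef, return the result of calling the passed-in method on the value
# held by the RRef. The call runs on the remote node that owns the RRef; args
# and kwargs are forwarded to the method.
def remote_method(method, rref, *args, **kwargs):
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
```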

# construct it once
param_server = ParameterServer(num_gpus=num_gpus)
print(
"Returning parameter server with ID {}".format(
Contributor:

looks like we can fit this in one line with f"Returning parameter server with ID {id(param_server)}"?
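For reference, the singleton accessor this snippet comes from looks roughly like the following (a sketch; `ParameterServer` is the PR's class, the lock name is an assumption, and the f-string follows the suggestion above):

```python
import threading

param_server = None
global_lock = threading.Lock()

def get_parameter_server(num_gpus=0):
    global param_server
    # Construct the ParameterServer exactly once, even if several trainers race here.
    with global_lock:
        if param_server is None:
            param_server = ParameterServer(num_gpus=num_gpus)
        print(f"Returning parameter server with ID {id(param_server)}")
        return param_server
```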

@rohan-varma (Contributor Author) commented:

@mrshenli Updated to address all comments, and also to explicitly move tensors in and out of GPU so that they work with the latest RPC, which disallows sending CUDA tensors. Could you take another look? Thanks!
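The GPU handling mentioned here boils down to the pattern below (a sketch; `ps_forward`, `model`, and `input_device` are illustrative names, not the PR's exact code): since RPC refused to serialize CUDA tensors at the time, the server computes on its GPU but only ever receives and returns CPU tensors.

```python
import torch

def ps_forward(model, inp, input_device):
    # Inputs arrive over RPC as CPU tensors; move them to the server's device.
    out = model(inp.to(input_device))
    # Responses also travel over RPC, so move the result back to CPU first.
    return out.to("cpu")
```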

@mrshenli (Contributor) left a comment:

LGTM! @jlin27 shall we land this example?

@rohan-varma changed the title from "[WIP] Simple example to demonstrate parameter server training pattern" to "Simple example to demonstrate parameter server training pattern" on Mar 21, 2020
@jlin27 merged commit 8a5b379 into pytorch:master on Mar 23, 2020
YinZhengxun pushed a commit to YinZhengxun/mt-exercise-02 that referenced this pull request Mar 30, 2025
Simple example to demonstrate parameter server training pattern