@mspronesti
Contributor

@mspronesti mspronesti commented Dec 29, 2022

Hi,
this PR integrates the DPM Solver Multistep Scheduler into this Rust version, as requested in #17.

A couple of disclaimers:

  • the implementation contains all the features of the Python version except the optional thresholding described in this paper. I plan to add it in a future PR, if you agree.

  • I don't like the initialization of model_outputs. Unfortunately tch::Tensor doesn't implement the Copy trait, so I couldn't find a leaner way. Feel free to leave a suggestion if you have a better idea!

  • unlike the other two schedulers, the step method requires self to be mutable, as I need to update the model_outputs attribute.
    Therefore, to run the examples, make sure to create a mut scheduler:

     let mut scheduler = DPMSolverMultistepScheduler::new(n_steps, Default::default());

    and to call to_owned when iterating over the timesteps (e.g. here):

     for (timestep_index, &timestep) in scheduler.timesteps().to_owned().iter().enumerate() 
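
The two points above (building model_outputs without Copy, and calling to_owned before a loop that calls a &mut self method) can be illustrated with a minimal sketch. It is not the code in this PR: ToyTensor, ToyScheduler, the solver_order parameter and run are hypothetical stand-ins, and ToyTensor is deliberately neither Copy nor Clone purely for illustration.

```rust
// Hypothetical stand-in for a tensor type: owns heap data and is
// deliberately neither Copy nor Clone (an assumption for illustration).
#[derive(Debug, Default, PartialEq)]
struct ToyTensor(Vec<f64>);

// Toy scheduler, not the implementation from this PR.
struct ToyScheduler {
    timesteps: Vec<usize>,
    model_outputs: Vec<ToyTensor>,
}

impl ToyScheduler {
    fn new(n_steps: usize, solver_order: usize) -> Self {
        Self {
            // Descending timesteps, just for illustration.
            timesteps: (0..n_steps).rev().collect(),
            // `vec![ToyTensor::default(); n]` would need Clone and
            // `[ToyTensor::default(); N]` would need Copy; `repeat_with`
            // calls the constructor once per slot instead.
            model_outputs: std::iter::repeat_with(ToyTensor::default)
                .take(solver_order)
                .collect(),
        }
    }

    fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }

    // Needs `&mut self` because it updates the stored model outputs.
    fn step(&mut self, model_output: ToyTensor, timestep: usize) -> usize {
        // Drop the oldest output and store the newest in the last slot.
        self.model_outputs.rotate_left(1);
        if let Some(last) = self.model_outputs.last_mut() {
            *last = model_output;
        }
        timestep
    }
}

fn run(n_steps: usize) -> Vec<usize> {
    let mut scheduler = ToyScheduler::new(n_steps, 2);
    let mut visited = Vec::new();
    // Without `to_owned()`, the slice returned by `timesteps()` would keep
    // `scheduler` immutably borrowed for the whole loop, so the `&mut self`
    // call to `step` would be rejected by the borrow checker.
    for (_timestep_index, &timestep) in scheduler.timesteps().to_owned().iter().enumerate() {
        visited.push(scheduler.step(ToyTensor(vec![timestep as f64]), timestep));
    }
    visited
}

fn main() {
    println!("{:?}", run(4));
}
```

The copy made by to_owned is small (one integer per timestep), which is why it seems an acceptable price for keeping step mutable.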

Cheers 😄

Owner

@LaurentMazare LaurentMazare left a comment

Thanks for the PR, looks neat!
Could you give a bit more detail on how this was tested, and on whether it's expected to be fully in line with the Python version?

     }

     pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
         #[rustfmt::skip]

Owner

Very cosmetic, but could you avoid overriding the rustfmt setting? That's never done in this crate or in the tch crate, and even if the result might look weird at some point (and I would personally dislike some of the formatting), it makes things much easier when someone else edits the code.

Contributor Author

Sure, I did it to leave the chain of calls on multiple lines, otherwise rustfmt would force me to put everything on one line. If you don't like it, I will fix this in the next commit :)

Owner

Well, I really don't mind the format one way or another; I'm just opinionated towards avoiding manual formatting and so going with whatever rustfmt decides (and again, I might certainly dislike some of the outputs, but it's great not to have to think about manual formatting anymore :) ).

Contributor Author

Is there any way to keep this format without the action failing?
Otherwise I will just accept the format enforced by rustfmt, it ain't a big deal after all :)

Owner

Yeah, I would just go with what rustfmt suggests.

Contributor Author

Sure :)

Owner

Sorry, maybe I missed something but the rustfmt skip seems to still be there? (once it has been removed/reformatted, happy to merge)

Contributor Author

@mspronesti mspronesti Dec 31, 2022

I just committed. Hope it contains everything we discussed. Let me know if I missed anything :)

@mspronesti
Contributor Author

mspronesti commented Dec 31, 2022

I just noticed I missed your first question. I "tested" the implementation using the snippet I shared with you in my last PR (changing the scheduler, clearly :)).
The codebase should be almost 1:1 with the Python version (or, at least, that's what I tried to do).

@LaurentMazare
Owner

Good, could you maybe add a few permalinks to the Python version in the scheduler code? This is nice for people who want to see how the code was converted.

@mspronesti
Contributor Author

mspronesti commented Dec 31, 2022

Sure, is there a specific part you suggest I "annotate" with a permalink?
For instance, here I could replace this comment with a permalink to the Python version :)
The rest of the code seems (to me) fairly similar to the Python implementation, including the comments and the variable names.

* remove rustfmt:skip
* add permalinks to the python implementation
* re-run rustfmt
@mspronesti
Contributor Author

mspronesti commented Dec 31, 2022

Not sure why one of the checks failed. It looks like the classic error one gets when .../cargo/index is corrupted and just needs to be deleted and re-created, or perhaps a network error. Maybe re-running the action solves it?

@LaurentMazare LaurentMazare merged commit f5f54ac into LaurentMazare:main Dec 31, 2022
@LaurentMazare
Owner

Rerunning did indeed fix this test. I've merged the changes, thanks for the PR!
