Performance issue with models running on CPU #130
Problem
@norci mentioned in JuliaReinforcementLearning/ReinforcementLearningZoo.jl#87 (comment) that there may be some potential performance improvements for algorithms running on CPU only.
Currently the experience buffer uses CircularArrayBuffer to store data. During batch updates, we use the select_last_dim function to create a view. But according to the "Copying data is not always bad" section of the Julia performance tips, it may be faster to turn the view into an Array before feeding it into Flux models.
An initial investigation shows that converting the SubArray into an Array reduces the average time per step of the experiment E`JuliaRL_BasicDQN_CartPole` from ~0.00128 to ~0.00107. The improvement grows as the model becomes more complex.
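The effect can be explored outside the library with a minimal sketch that uses only Base Julia. Everything in it is an assumption made for illustration: a plain Array stands in for the CircularArrayBuffer, a view with sampled column indices stands in for the result of select_last_dim, a matrix multiply stands in for a Dense layer, and the sizes are arbitrary. Timings depend on the machine and are not meant to reproduce the numbers above.

```julia
# Stand-in for the experience buffer: 4 features x 10_000 stored states.
buffer = rand(Float32, 4, 10_000)
# Stand-in for the weights of a Dense layer.
W = rand(Float32, 128, 4)

# Non-contiguous batch indices, as random sampling would produce.
inds = rand(1:size(buffer, 2), 32)

batch_view = view(buffer, :, inds)   # SubArray, no data copied
batch_copy = Array(batch_view)       # contiguous Array, one copy

# Both inputs give the same result; only the memory layout differs.
@assert W * batch_view ≈ W * batch_copy

# Run n forward passes, applying `prepare` to the batch before each one,
# so that the cost of the copy is included in the measurement.
function forward(prepare, W, x, n)
    s = 0.0f0
    for _ in 1:n
        s += sum(W * prepare(x))
    end
    return s
end

# Warm up (compile) both paths before timing.
forward(identity, W, batch_view, 1)
forward(Array, W, batch_view, 1)

t_view = @elapsed forward(identity, W, batch_view, 10_000)
t_copy = @elapsed forward(Array, W, batch_view, 10_000)
println("view: ", t_view, " s, copy first: ", t_copy, " s")
```

The second measurement pays for a fresh copy on every pass, which is the situation a model would be in if the conversion were done per batch.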
Models on GPU are not affected
Note that models on GPU are not affected, since the SubArray is automatically converted to an Array first.
We also already force a SubArray of Array to be converted into a CuArray, rather than a SubArray of CuArray, here:
ReinforcementLearningCore.jl/src/utils/device.jl
Lines 16 to 32 in a94544c
    send_to_device(
        ::Val{:gpu},
        x::Union{
            SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
            Base.ReshapedArray{<:Any,<:Any,<:SubArray{<:Any,<:Any,<:CircularArrayBuffer}},
            SubArray{
                <:Any,
                <:Any,
                <:Base.ReshapedArray{
                    <:Any,
                    <:Any,
                    <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
                },
            },
            ElasticArray,
        },
    ) = CuArray(x)
Possible Solutions
- Nothing to change (but this needs to be documented somewhere)
Users need to manually add a layer to their models that converts the SubArray into an Array first when working on CPU-only devices.
- Automatically convert SubArray of Array into Array in the send_to_host function.
This is the easiest way. But I think it breaks the meaning of send_to_host. After all, the SubArray of Array is already on the CPU.
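Both options rely on the same conversion, which can be sketched as follows. This is only an illustration under stated assumptions: `to_dense` is a hypothetical helper, not an existing ReinforcementLearningCore function, and a plain matrix multiply stands in for a Flux model so that the example runs with Base Julia alone. The real buffers are CircularArrayBuffer-backed views, so the dispatch signature would need to be widened accordingly.

```julia
# Hypothetical helper (not part of ReinforcementLearningCore):
# copy a view of a plain Array into a contiguous Array,
# and pass every other input through unchanged.
to_dense(x::SubArray{<:Any,<:Any,<:Array}) = Array(x)
to_dense(x) = x

# Stand-in for a model: a single dense layer without bias.
W = rand(Float32, 2, 4)
model(x) = W * x

# Option 1: the user prepends the conversion to the model.
cpu_model = model ∘ to_dense

buffer = rand(Float32, 4, 100)
batch = view(buffer, :, [3, 17, 42])

@assert to_dense(batch) isa Array
@assert to_dense(buffer) === buffer        # already an Array: no copy
@assert cpu_model(batch) ≈ model(batch)    # same result either way

# Option 2 would apply the same conversion inside send_to_host,
# so that callers do not have to change their models.
```

With Flux, a helper like this could be placed as the first entry of a Chain, since a Chain accepts arbitrary callables as layers.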