diff --git a/docs/source/_static/img/replaybuffer_traj.png b/docs/source/_static/img/replaybuffer_traj.png new file mode 100644 index 00000000000..64773ee8f78 Binary files /dev/null and b/docs/source/_static/img/replaybuffer_traj.png differ diff --git a/docs/source/_static/js/theme.js b/docs/source/_static/js/theme.js index 219443ee11e..297154d9ed7 100644 --- a/docs/source/_static/js/theme.js +++ b/docs/source/_static/js/theme.js @@ -692,7 +692,7 @@ window.sideMenus = { } }; -},{}],11:[function(require,module,exports){ +},{}],"pytorch-sphinx-theme":[function(require,module,exports){ var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); // Sphinx theme nav state @@ -1125,3824 +1125,4 @@ $(window).scroll(function () { }); -},{"jquery":"jquery"}],"pytorch-sphinx-theme":[function(require,module,exports){ -require=(function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i wait) { - if (timeout) { - clearTimeout(timeout); - timeout = null; - } - previous = now; - result = func.apply(context, args); - if (!timeout) context = args = null; - } else if (!timeout && options.trailing !== false) { - timeout = setTimeout(later, remaining); - } - return result; - }; - }, - - closest: function (el, selector) { - var matchesFn; - - // find vendor prefix - ['matches','webkitMatchesSelector','mozMatchesSelector','msMatchesSelector','oMatchesSelector'].some(function(fn) { - if (typeof document.body[fn] == 'function') { - matchesFn = fn; - return true; - } - return false; - }); - - var parent; - - // traverse parents - while (el) { - parent = el.parentElement; - if (parent && parent[matchesFn](selector)) { - return parent; - } - el = parent; - } - - return null; - }, - - // Modified from https://stackoverflow.com/a/18953277 - offset: function(elem) { - if (!elem) { - return; - } - - rect = elem.getBoundingClientRect(); - - // Make sure element is not hidden (display: none) or disconnected - if (rect.width || rect.height || elem.getClientRects().length) { - var doc = elem.ownerDocument; - var docElem = doc.documentElement; - - return { - top: rect.top + window.pageYOffset - docElem.clientTop, - left: rect.left + window.pageXOffset - docElem.clientLeft - }; - } - }, - - headersHeight: function() { - if (document.getElementById("pytorch-left-menu").classList.contains("make-fixed")) { - return document.getElementById("pytorch-page-level-bar").offsetHeight; - } else { - return document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight; - } - }, - - windowHeight: function() { - return window.innerHeight || - document.documentElement.clientHeight || - document.body.clientHeight; - } -} - -},{}],2:[function(require,module,exports){ -var cookieBanner = { - init: function() { - cookieBanner.bind(); - - var cookieExists = cookieBanner.cookieExists(); - - if (!cookieExists) { - cookieBanner.setCookie(); - cookieBanner.showCookieNotice(); - } - }, - - bind: function() { - $(".close-button").on("click", cookieBanner.hideCookieNotice); - }, - - cookieExists: function() { - var cookie = localStorage.getItem("returningPytorchUser"); - - if (cookie) { - return true; - } else { - return 
false; - } - }, - - setCookie: function() { - localStorage.setItem("returningPytorchUser", true); - }, - - showCookieNotice: function() { - $(".cookie-banner-wrapper").addClass("is-visible"); - }, - - hideCookieNotice: function() { - $(".cookie-banner-wrapper").removeClass("is-visible"); - } -}; - -$(function() { - cookieBanner.init(); -}); - -},{}],3:[function(require,module,exports){ -window.filterTags = { - bind: function() { - var options = { - valueNames: [{ data: ["tags"] }], - page: "6", - pagination: true - }; - - var tutorialList = new List("tutorial-cards", options); - - function filterSelectedTags(cardTags, selectedTags) { - return cardTags.some(function(tag) { - return selectedTags.some(function(selectedTag) { - return selectedTag == tag; - }); - }); - } - - function updateList() { - var selectedTags = []; - - $(".selected").each(function() { - selectedTags.push($(this).data("tag")); - }); - - tutorialList.filter(function(item) { - var cardTags; - - if (item.values().tags == null) { - cardTags = [""]; - } else { - cardTags = item.values().tags.split(","); - } - - if (selectedTags.length == 0) { - return true; - } else { - return filterSelectedTags(cardTags, selectedTags); - } - }); - } - - $(".filter-btn").on("click", function() { - if ($(this).data("tag") == "all") { - $(this).addClass("all-tag-selected"); - $(".filter").removeClass("selected"); - } else { - $(this).toggleClass("selected"); - $("[data-tag='all']").removeClass("all-tag-selected"); - } - - // If no tags are selected then highlight the 'All' tag - - if (!$(".selected")[0]) { - $("[data-tag='all']").addClass("all-tag-selected"); - } - - updateList(); - }); - } -}; - -},{}],4:[function(require,module,exports){ -// Modified from https://stackoverflow.com/a/32396543 -window.highlightNavigation = { - navigationListItems: document.querySelectorAll("#pytorch-right-menu li"), - sections: document.querySelectorAll(".pytorch-article .section"), - sectionIdTonavigationLink: {}, - - bind: function() { - if (!sideMenus.displayRightMenu) { - return; - }; - - for (var i = 0; i < highlightNavigation.sections.length; i++) { - var id = highlightNavigation.sections[i].id; - highlightNavigation.sectionIdTonavigationLink[id] = - document.querySelectorAll('#pytorch-right-menu li a[href="#' + id + '"]')[0]; - } - - $(window).scroll(utilities.throttle(highlightNavigation.highlight, 100)); - }, - - highlight: function() { - var rightMenu = document.getElementById("pytorch-right-menu"); - - // If right menu is not on the screen don't bother - if (rightMenu.offsetWidth === 0 && rightMenu.offsetHeight === 0) { - return; - } - - var scrollPosition = utilities.scrollTop(); - var OFFSET_TOP_PADDING = 25; - var offset = document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight + - OFFSET_TOP_PADDING; - - var sections = highlightNavigation.sections; - - for (var i = (sections.length - 1); i >= 0; i--) { - var currentSection = sections[i]; - var sectionTop = utilities.offset(currentSection).top; - - if (scrollPosition >= sectionTop - offset) { - var navigationLink = highlightNavigation.sectionIdTonavigationLink[currentSection.id]; - var navigationListItem = utilities.closest(navigationLink, "li"); - - if (navigationListItem && !navigationListItem.classList.contains("active")) { - for (var i = 0; i < highlightNavigation.navigationListItems.length; i++) { - var el = highlightNavigation.navigationListItems[i]; - if (el.classList.contains("active")) { - el.classList.remove("active"); - } 
- } - - navigationListItem.classList.add("active"); - - // Scroll to active item. Not a requested feature but we could revive it. Needs work. - - // var menuTop = $("#pytorch-right-menu").position().top; - // var itemTop = navigationListItem.getBoundingClientRect().top; - // var TOP_PADDING = 20 - // var newActiveTop = $("#pytorch-side-scroll-right").scrollTop() + itemTop - menuTop - TOP_PADDING; - - // $("#pytorch-side-scroll-right").animate({ - // scrollTop: newActiveTop - // }, 100); - } - - break; - } - } - } -}; - -},{}],5:[function(require,module,exports){ -window.mainMenuDropdown = { - bind: function() { - $("[data-toggle='ecosystem-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - $("[data-toggle='resources-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - function toggleDropdown(menuToggle) { - var showMenuClass = "show-menu"; - var menuClass = "." + menuToggle + "-menu"; - - if ($(menuClass).hasClass(showMenuClass)) { - $(menuClass).removeClass(showMenuClass); - } else { - $("[data-toggle=" + menuToggle + "].show-menu").removeClass( - showMenuClass - ); - $(menuClass).addClass(showMenuClass); - } - } - } -}; - -},{}],6:[function(require,module,exports){ -window.mobileMenu = { - bind: function() { - $("[data-behavior='open-mobile-menu']").on('click', function(e) { - e.preventDefault(); - $(".mobile-main-menu").addClass("open"); - $("body").addClass('no-scroll'); - - mobileMenu.listenForResize(); - }); - - $("[data-behavior='close-mobile-menu']").on('click', function(e) { - e.preventDefault(); - mobileMenu.close(); - }); - }, - - listenForResize: function() { - $(window).on('resize.ForMobileMenu', function() { - if ($(this).width() > 768) { - mobileMenu.close(); - } - }); - }, - - close: function() { - $(".mobile-main-menu").removeClass("open"); - $("body").removeClass('no-scroll'); - $(window).off('resize.ForMobileMenu'); - } -}; - -},{}],7:[function(require,module,exports){ -window.mobileTOC = { - bind: function() { - $("[data-behavior='toggle-table-of-contents']").on("click", function(e) { - e.preventDefault(); - - var $parent = $(this).parent(); - - if ($parent.hasClass("is-open")) { - $parent.removeClass("is-open"); - $(".pytorch-left-menu").slideUp(200, function() { - $(this).css({display: ""}); - }); - } else { - $parent.addClass("is-open"); - $(".pytorch-left-menu").slideDown(200); - } - }); - } -} - -},{}],8:[function(require,module,exports){ -window.pytorchAnchors = { - bind: function() { - // Replace Sphinx-generated anchors with anchorjs ones - $(".headerlink").text(""); - - window.anchors.add(".pytorch-article .headerlink"); - - $(".anchorjs-link").each(function() { - var $headerLink = $(this).closest(".headerlink"); - var href = $headerLink.attr("href"); - var clone = this.outerHTML; - - $clone = $(clone).attr("href", href); - $headerLink.before($clone); - $headerLink.remove(); - }); - } -}; - -},{}],9:[function(require,module,exports){ -// Modified from https://stackoverflow.com/a/13067009 -// Going for a JS solution to scrolling to an anchor so we can benefit from -// less hacky css and smooth scrolling. - -window.scrollToAnchor = { - bind: function() { - var document = window.document; - var history = window.history; - var location = window.location - var HISTORY_SUPPORT = !!(history && history.pushState); - - var anchorScrolls = { - ANCHOR_REGEX: /^#[^ ]+$/, - offsetHeightPx: function() { - var OFFSET_HEIGHT_PADDING = 20; - // TODO: this is a little janky. 
We should try to not rely on JS for this - return utilities.headersHeight() + OFFSET_HEIGHT_PADDING; - }, - - /** - * Establish events, and fix initial scroll position if a hash is provided. - */ - init: function() { - this.scrollToCurrent(); - // This interferes with clicks below it, causing a double fire - // $(window).on('hashchange', $.proxy(this, 'scrollToCurrent')); - $('body').on('click', 'a', $.proxy(this, 'delegateAnchors')); - $('body').on('click', '#pytorch-right-menu li span', $.proxy(this, 'delegateSpans')); - }, - - /** - * Return the offset amount to deduct from the normal scroll position. - * Modify as appropriate to allow for dynamic calculations - */ - getFixedOffset: function() { - return this.offsetHeightPx(); - }, - - /** - * If the provided href is an anchor which resolves to an element on the - * page, scroll to it. - * @param {String} href - * @return {Boolean} - Was the href an anchor. - */ - scrollIfAnchor: function(href, pushToHistory) { - var match, anchorOffset; - - if(!this.ANCHOR_REGEX.test(href)) { - return false; - } - - match = document.getElementById(href.slice(1)); - - if(match) { - var anchorOffset = $(match).offset().top - this.getFixedOffset(); - - $('html, body').scrollTop(anchorOffset); - - // Add the state to history as-per normal anchor links - if(HISTORY_SUPPORT && pushToHistory) { - history.pushState({}, document.title, location.pathname + href); - } - } - - return !!match; - }, - - /** - * Attempt to scroll to the current location's hash. - */ - scrollToCurrent: function(e) { - if(this.scrollIfAnchor(window.location.hash) && e) { - e.preventDefault(); - } - }, - - delegateSpans: function(e) { - var elem = utilities.closest(e.target, "a"); - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - }, - - /** - * If the click event's target was an anchor, fix the scroll position. - */ - delegateAnchors: function(e) { - var elem = e.target; - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - } - }; - - $(document).ready($.proxy(anchorScrolls, 'init')); - } -}; - -},{}],10:[function(require,module,exports){ -window.sideMenus = { - rightMenuIsOnScreen: function() { - return document.getElementById("pytorch-content-right").offsetParent !== null; - }, - - isFixedToBottom: false, - - bind: function() { - sideMenus.handleLeftMenu(); - - var rightMenuLinks = document.querySelectorAll("#pytorch-right-menu li"); - var rightMenuHasLinks = rightMenuLinks.length > 1; - - if (!rightMenuHasLinks) { - for (var i = 0; i < rightMenuLinks.length; i++) { - rightMenuLinks[i].style.display = "none"; - } - } - - if (rightMenuHasLinks) { - // Don't show the Shortcuts menu title text unless there are menu items - document.getElementById("pytorch-shortcuts-wrapper").style.display = "block"; - - // We are hiding the titles of the pages in the right side menu but there are a few - // pages that include other pages in the right side menu (see 'torch.nn' in the docs) - // so if we exclude those it looks confusing. 
Here we add a 'title-link' class to these - // links so we can exclude them from normal right side menu link operations - var titleLinks = document.querySelectorAll( - "#pytorch-right-menu #pytorch-side-scroll-right \ - > ul > li > a.reference.internal" - ); - - for (var i = 0; i < titleLinks.length; i++) { - var link = titleLinks[i]; - - link.classList.add("title-link"); - - if ( - link.nextElementSibling && - link.nextElementSibling.tagName === "UL" && - link.nextElementSibling.children.length > 0 - ) { - link.classList.add("has-children"); - } - } - - // Add + expansion signifiers to normal right menu links that have sub menus - var menuLinks = document.querySelectorAll( - "#pytorch-right-menu ul li ul li a.reference.internal" - ); - - for (var i = 0; i < menuLinks.length; i++) { - if ( - menuLinks[i].nextElementSibling && - menuLinks[i].nextElementSibling.tagName === "UL" - ) { - menuLinks[i].classList.add("not-expanded"); - } - } - - // If a hash is present on page load recursively expand menu items leading to selected item - var linkWithHash = - document.querySelector( - "#pytorch-right-menu a[href=\"" + window.location.hash + "\"]" - ); - - if (linkWithHash) { - // Expand immediate sibling list if present - if ( - linkWithHash.nextElementSibling && - linkWithHash.nextElementSibling.tagName === "UL" && - linkWithHash.nextElementSibling.children.length > 0 - ) { - linkWithHash.nextElementSibling.style.display = "block"; - linkWithHash.classList.add("expanded"); - } - - // Expand ancestor lists if any - sideMenus.expandClosestUnexpandedParentList(linkWithHash); - } - - // Bind click events on right menu links - $("#pytorch-right-menu a.reference.internal").on("click", function() { - if (this.classList.contains("expanded")) { - this.nextElementSibling.style.display = "none"; - this.classList.remove("expanded"); - this.classList.add("not-expanded"); - } else if (this.classList.contains("not-expanded")) { - this.nextElementSibling.style.display = "block"; - this.classList.remove("not-expanded"); - this.classList.add("expanded"); - } - }); - - sideMenus.handleRightMenu(); - } - - $(window).on('resize scroll', function(e) { - sideMenus.handleNavBar(); - - sideMenus.handleLeftMenu(); - - if (sideMenus.rightMenuIsOnScreen()) { - sideMenus.handleRightMenu(); - } - }); - }, - - leftMenuIsFixed: function() { - return document.getElementById("pytorch-left-menu").classList.contains("make-fixed"); - }, - - handleNavBar: function() { - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - // If we are scrolled past the main navigation header fix the sub menu bar to top of page - if (utilities.scrollTop() >= mainHeaderHeight) { - document.getElementById("pytorch-left-menu").classList.add("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.add("left-menu-is-fixed"); - } else { - document.getElementById("pytorch-left-menu").classList.remove("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.remove("left-menu-is-fixed"); - } - }, - - expandClosestUnexpandedParentList: function (el) { - var closestParentList = utilities.closest(el, "ul"); - - if (closestParentList) { - var closestParentLink = closestParentList.previousElementSibling; - var closestParentLinkExists = closestParentLink && - closestParentLink.tagName === "A" && - closestParentLink.classList.contains("reference"); - - if (closestParentLinkExists) { - // Don't add expansion class to any title links - if (closestParentLink.classList.contains("title-link")) { - 
return; - } - - closestParentList.style.display = "block"; - closestParentLink.classList.remove("not-expanded"); - closestParentLink.classList.add("expanded"); - sideMenus.expandClosestUnexpandedParentList(closestParentLink); - } - } - }, - - handleLeftMenu: function () { - var windowHeight = utilities.windowHeight(); - var topOfFooterRelativeToWindow = document.getElementById("docs-tutorials-resources").getBoundingClientRect().top; - - if (topOfFooterRelativeToWindow >= windowHeight) { - document.getElementById("pytorch-left-menu").style.height = "100%"; - } else { - var howManyPixelsOfTheFooterAreInTheWindow = windowHeight - topOfFooterRelativeToWindow; - var leftMenuDifference = howManyPixelsOfTheFooterAreInTheWindow; - document.getElementById("pytorch-left-menu").style.height = (windowHeight - leftMenuDifference) + "px"; - } - }, - - handleRightMenu: function() { - var rightMenuWrapper = document.getElementById("pytorch-content-right"); - var rightMenu = document.getElementById("pytorch-right-menu"); - var rightMenuList = rightMenu.getElementsByTagName("ul")[0]; - var article = document.getElementById("pytorch-article"); - var articleHeight = article.offsetHeight; - var articleBottom = utilities.offset(article).top + articleHeight; - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - if (utilities.scrollTop() < mainHeaderHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = 0; - rightMenu.classList.remove("scrolling-fixed"); - rightMenu.classList.remove("scrolling-absolute"); - } else { - if (rightMenu.classList.contains("scrolling-fixed")) { - var rightMenuBottom = - utilities.offset(rightMenuList).top + rightMenuList.offsetHeight; - - if (rightMenuBottom >= articleBottom) { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = utilities.scrollTop() - mainHeaderHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - rightMenu.classList.remove("scrolling-fixed"); - } - } else { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = - articleBottom - mainHeaderHeight - rightMenuList.offsetHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - } - - if (utilities.scrollTop() < articleBottom - rightMenuList.offsetHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = ""; - rightMenu.classList.remove("scrolling-absolute"); - rightMenu.classList.add("scrolling-fixed"); - } - } - - var rightMenuSideScroll = document.getElementById("pytorch-side-scroll-right"); - var sideScrollFromWindowTop = rightMenuSideScroll.getBoundingClientRect().top; - - rightMenuSideScroll.style.height = utilities.windowHeight() - sideScrollFromWindowTop + "px"; - } -}; - -},{}],11:[function(require,module,exports){ -var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); - -// Sphinx theme nav state -function ThemeNav () { - - var nav = { - navBar: null, - win: null, - winScroll: false, - winResize: false, - linkScroll: false, - winPosition: 0, - winHeight: null, - docHeight: null, - isRunning: false - }; - - nav.enable = function (withStickyNav) { - var self = this; - - // TODO this can likely be removed once the theme javascript is broken - // out from the RTD assets. This just ensures old projects that are - // calling `enable()` get the sticky menu on by default. All other cals - // to `enable` should include an argument for enabling the sticky menu. 
- if (typeof(withStickyNav) == 'undefined') { - withStickyNav = true; - } - - if (self.isRunning) { - // Only allow enabling nav logic once - return; - } - - self.isRunning = true; - jQuery(function ($) { - self.init($); - - self.reset(); - self.win.on('hashchange', self.reset); - - if (withStickyNav) { - // Set scroll monitor - self.win.on('scroll', function () { - if (!self.linkScroll) { - if (!self.winScroll) { - self.winScroll = true; - requestAnimationFrame(function() { self.onScroll(); }); - } - } - }); - } - - // Set resize monitor - self.win.on('resize', function () { - if (!self.winResize) { - self.winResize = true; - requestAnimationFrame(function() { self.onResize(); }); - } - }); - - self.onResize(); - }); - - }; - - // TODO remove this with a split in theme and Read the Docs JS logic as - // well, it's only here to support 0.3.0 installs of our theme. - nav.enableSticky = function() { - this.enable(true); - }; - - nav.init = function ($) { - var doc = $(document), - self = this; - - this.navBar = $('div.pytorch-side-scroll:first'); - this.win = $(window); - - // Set up javascript UX bits - $(document) - // Shift nav in mobile when clicking the menu. - .on('click', "[data-toggle='pytorch-left-menu-nav-top']", function() { - $("[data-toggle='wy-nav-shift']").toggleClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - }) - - // Nav menu link click operations - .on('click', ".pytorch-menu-vertical .current ul li a", function() { - var target = $(this); - // Close menu when you click a link. - $("[data-toggle='wy-nav-shift']").removeClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - // Handle dynamic display of l3 and l4 nav lists - self.toggleCurrent(target); - self.hashChange(); - }) - .on('click', "[data-toggle='rst-current-version']", function() { - $("[data-toggle='rst-versions']").toggleClass("shift-up"); - }) - - // Make tables responsive - $("table.docutils:not(.field-list,.footnote,.citation)") - .wrap("
"); - - // Add extra class to responsive tables that contain - // footnotes or citations so that we can target them for styling - $("table.docutils.footnote") - .wrap("
"); - $("table.docutils.citation") - .wrap("
"); - - // Add expand links to all parents of nested ul - $('.pytorch-menu-vertical ul').not('.simple').siblings('a').each(function () { - var link = $(this); - expand = $(''); - expand.on('click', function (ev) { - self.toggleCurrent(link); - ev.stopPropagation(); - return false; - }); - link.prepend(expand); - }); - }; - - nav.reset = function () { - // Get anchor from URL and open up nested nav - var anchor = encodeURI(window.location.hash) || '#'; - - try { - var vmenu = $('.pytorch-menu-vertical'); - var link = vmenu.find('[href="' + anchor + '"]'); - if (link.length === 0) { - // this link was not found in the sidebar. - // Find associated id element, then its closest section - // in the document and try with that one. - var id_elt = $('.document [id="' + anchor.substring(1) + '"]'); - var closest_section = id_elt.closest('div.section'); - link = vmenu.find('[href="#' + closest_section.attr("id") + '"]'); - if (link.length === 0) { - // still not found in the sidebar. fall back to main section - link = vmenu.find('[href="#"]'); - } - } - // If we found a matching link then reset current and re-apply - // otherwise retain the existing match - if (link.length > 0) { - $('.pytorch-menu-vertical .current').removeClass('current'); - link.addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l1').parent().addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l2').addClass('current'); - link.closest('li.toctree-l3').addClass('current'); - link.closest('li.toctree-l4').addClass('current'); - } - } - catch (err) { - console.log("Error expanding nav for anchor", err); - } - - }; - - nav.onScroll = function () { - this.winScroll = false; - var newWinPosition = this.win.scrollTop(), - winBottom = newWinPosition + this.winHeight, - navPosition = this.navBar.scrollTop(), - newNavPosition = navPosition + (newWinPosition - this.winPosition); - if (newWinPosition < 0 || winBottom > this.docHeight) { - return; - } - this.navBar.scrollTop(newNavPosition); - this.winPosition = newWinPosition; - }; - - nav.onResize = function () { - this.winResize = false; - this.winHeight = this.win.height(); - this.docHeight = $(document).height(); - }; - - nav.hashChange = function () { - this.linkScroll = true; - this.win.one('hashchange', function () { - this.linkScroll = false; - }); - }; - - nav.toggleCurrent = function (elem) { - var parent_li = elem.closest('li'); - parent_li.siblings('li.current').removeClass('current'); - parent_li.siblings().find('li.current').removeClass('current'); - parent_li.find('> ul li.current').removeClass('current'); - parent_li.toggleClass('current'); - } - - return nav; -}; - -module.exports.ThemeNav = ThemeNav(); - -if (typeof(window) != 'undefined') { - window.SphinxRtdTheme = { - Navigation: module.exports.ThemeNav, - // TODO remove this once static assets are split up between the theme - // and Read the Docs. For now, this patches 0.3.0 to be backwards - // compatible with a pre-0.3.0 layout.html - StickyNav: module.exports.ThemeNav, - }; -} - - -// requestAnimationFrame polyfill by Erik Möller. 
fixes from Paul Irish and Tino Zijdel -// https://gist.github.com/paulirish/1579671 -// MIT license - -(function() { - var lastTime = 0; - var vendors = ['ms', 'moz', 'webkit', 'o']; - for(var x = 0; x < vendors.length && !window.requestAnimationFrame; ++x) { - window.requestAnimationFrame = window[vendors[x]+'RequestAnimationFrame']; - window.cancelAnimationFrame = window[vendors[x]+'CancelAnimationFrame'] - || window[vendors[x]+'CancelRequestAnimationFrame']; - } - - if (!window.requestAnimationFrame) - window.requestAnimationFrame = function(callback, element) { - var currTime = new Date().getTime(); - var timeToCall = Math.max(0, 16 - (currTime - lastTime)); - var id = window.setTimeout(function() { callback(currTime + timeToCall); }, - timeToCall); - lastTime = currTime + timeToCall; - return id; - }; - - if (!window.cancelAnimationFrame) - window.cancelAnimationFrame = function(id) { - clearTimeout(id); - }; -}()); - -$(".sphx-glr-thumbcontainer").removeAttr("tooltip"); -$("table").removeAttr("border"); - -// This code replaces the default sphinx gallery download buttons -// with the 3 download buttons at the top of the page - -var downloadNote = $(".sphx-glr-download-link-note.admonition.note"); -if (downloadNote.length >= 1) { - var tutorialUrlArray = $("#tutorial-type").text().split('/'); - tutorialUrlArray[0] = tutorialUrlArray[0] + "/sphinx-tutorials" - - var githubLink = "https://github.com/pytorch/rl/blob/main/" + tutorialUrlArray.join("/") + ".py", - notebookLink = $(".reference.download")[1].href, - notebookDownloadPath = notebookLink.split('_downloads')[1], - colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/_downloads" + notebookDownloadPath; - - $("#google-colab-link").wrap(""); - $("#download-notebook-link").wrap(""); - $("#github-view-link").wrap(""); -} else { - $(".pytorch-call-to-action-links").hide(); -} - -//This code handles the Expand/Hide toggle for the Docs/Tutorials left nav items - -$(document).ready(function() { - var caption = "#pytorch-left-menu p.caption"; - var collapseAdded = $(this).not("checked"); - $(caption).each(function () { - var menuName = this.innerText.replace(/[^\w\s]/gi, "").trim(); - $(this).find("span").addClass("checked"); - if (collapsedSections.includes(menuName) == true && collapseAdded && sessionStorage.getItem(menuName) !== "expand" || sessionStorage.getItem(menuName) == "collapse") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - $(this).next("ul").hide(); - } else if (collapsedSections.includes(menuName) == false && collapseAdded || sessionStorage.getItem(menuName) == "expand") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - } - }); - - $(".expand-menu").on("click", function () { - $(this).prev(".hide-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "collapse") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "expand"); - toggleList(this); - }); - - $(".hide-menu").on("click", function () { - $(this).next(".expand-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "expand") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "collapse"); - toggleList(this); - }); - - function toggleList(menuCommand) { - $(menuCommand).toggle(); - } -}); - -// 
Build an array from each tag that's present - -var tagList = $(".tutorials-card-container").map(function() { - return $(this).data("tags").split(",").map(function(item) { - return item.trim(); - }); -}).get(); - -function unique(value, index, self) { - return self.indexOf(value) == index && value != "" - } - -// Only return unique tags - -var tags = tagList.sort().filter(unique); - -// Add filter buttons to the top of the page for each tag - -function createTagMenu() { - tags.forEach(function(item){ - $(".tutorial-filter-menu").append("
" + item + "
") - }) -}; - -createTagMenu(); - -// Remove hyphens if they are present in the filter buttons - -$(".tags").each(function(){ - var tags = $(this).text().split(","); - tags.forEach(function(tag, i ) { - tags[i] = tags[i].replace(/-/, ' ') - }) - $(this).html(tags.join(", ")); -}); - -// Remove hyphens if they are present in the card body - -$(".tutorial-filter").each(function(){ - var tag = $(this).text(); - $(this).html(tag.replace(/-/, ' ')) -}) - -// Remove any empty p tags that Sphinx adds - -$("#tutorial-cards p").each(function(index, item) { - if(!$(item).text().trim()) { - $(item).remove(); - } -}); - -// Jump back to top on pagination click - -$(document).on("click", ".page", function() { - $('html, body').animate( - {scrollTop: $("#dropdown-filter-tags").position().top}, - 'slow' - ); -}); - -var link = $("a[href='intermediate/speech_command_recognition_with_torchaudio.html']"); - -if (link.text() == "SyntaxError") { - console.log("There is an issue with the intermediate/speech_command_recognition_with_torchaudio.html menu item."); - link.text("Speech Command Recognition with torchaudio"); -} - -$(".stars-outer > i").hover(function() { - $(this).prevAll().addBack().toggleClass("fas star-fill"); -}); - -$(".stars-outer > i").on("click", function() { - $(this).prevAll().each(function() { - $(this).addBack().addClass("fas star-fill"); - }); - - $(".stars-outer > i").each(function() { - $(this).unbind("mouseenter mouseleave").css({ - "pointer-events": "none" - }); - }); -}) - -$("#pytorch-side-scroll-right li a").on("click", function (e) { - var href = $(this).attr("href"); - $('html, body').stop().animate({ - scrollTop: $(href).offset().top - 100 - }, 850); - e.preventDefault; -}); - -var lastId, - topMenu = $("#pytorch-side-scroll-right"), - topMenuHeight = topMenu.outerHeight() + 1, - // All sidenav items - menuItems = topMenu.find("a"), - // Anchors for menu items - scrollItems = menuItems.map(function () { - var item = $(this).attr("href"); - if (item.length) { - return item; - } - }); - -$(window).scroll(function () { - var fromTop = $(this).scrollTop() + topMenuHeight; - var article = ".section"; - - $(article).each(function (i) { - var offsetScroll = $(this).offset().top - $(window).scrollTop(); - if ( - offsetScroll <= topMenuHeight + 200 && - offsetScroll >= topMenuHeight - 200 && - scrollItems[i] == "#" + $(this).attr("id") && - $(".hidden:visible") - ) { - $(menuItems).removeClass("side-scroll-highlight"); - $(menuItems[i]).addClass("side-scroll-highlight"); - } - }); -}); - - -},{"jquery":"jquery"}],"pytorch-sphinx-theme":[function(require,module,exports){ -require=(function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i wait) { - if (timeout) { - clearTimeout(timeout); - timeout = null; - } - previous = now; - result = func.apply(context, args); - if (!timeout) context = args = null; - } else if (!timeout && options.trailing !== false) { - timeout = setTimeout(later, remaining); - } - return result; - }; - }, - - closest: function (el, selector) { - var matchesFn; - - // find vendor prefix - ['matches','webkitMatchesSelector','mozMatchesSelector','msMatchesSelector','oMatchesSelector'].some(function(fn) { - if (typeof 
document.body[fn] == 'function') { - matchesFn = fn; - return true; - } - return false; - }); - - var parent; - - // traverse parents - while (el) { - parent = el.parentElement; - if (parent && parent[matchesFn](selector)) { - return parent; - } - el = parent; - } - - return null; - }, - - // Modified from https://stackoverflow.com/a/18953277 - offset: function(elem) { - if (!elem) { - return; - } - - rect = elem.getBoundingClientRect(); - - // Make sure element is not hidden (display: none) or disconnected - if (rect.width || rect.height || elem.getClientRects().length) { - var doc = elem.ownerDocument; - var docElem = doc.documentElement; - - return { - top: rect.top + window.pageYOffset - docElem.clientTop, - left: rect.left + window.pageXOffset - docElem.clientLeft - }; - } - }, - - headersHeight: function() { - if (document.getElementById("pytorch-left-menu").classList.contains("make-fixed")) { - return document.getElementById("pytorch-page-level-bar").offsetHeight; - } else { - return document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight; - } - }, - - windowHeight: function() { - return window.innerHeight || - document.documentElement.clientHeight || - document.body.clientHeight; - } -} - -},{}],2:[function(require,module,exports){ -var cookieBanner = { - init: function() { - cookieBanner.bind(); - - var cookieExists = cookieBanner.cookieExists(); - - if (!cookieExists) { - cookieBanner.setCookie(); - cookieBanner.showCookieNotice(); - } - }, - - bind: function() { - $(".close-button").on("click", cookieBanner.hideCookieNotice); - }, - - cookieExists: function() { - var cookie = localStorage.getItem("returningPytorchUser"); - - if (cookie) { - return true; - } else { - return false; - } - }, - - setCookie: function() { - localStorage.setItem("returningPytorchUser", true); - }, - - showCookieNotice: function() { - $(".cookie-banner-wrapper").addClass("is-visible"); - }, - - hideCookieNotice: function() { - $(".cookie-banner-wrapper").removeClass("is-visible"); - } -}; - -$(function() { - cookieBanner.init(); -}); - -},{}],3:[function(require,module,exports){ -window.filterTags = { - bind: function() { - var options = { - valueNames: [{ data: ["tags"] }], - page: "6", - pagination: true - }; - - var tutorialList = new List("tutorial-cards", options); - - function filterSelectedTags(cardTags, selectedTags) { - return cardTags.some(function(tag) { - return selectedTags.some(function(selectedTag) { - return selectedTag == tag; - }); - }); - } - - function updateList() { - var selectedTags = []; - - $(".selected").each(function() { - selectedTags.push($(this).data("tag")); - }); - - tutorialList.filter(function(item) { - var cardTags; - - if (item.values().tags == null) { - cardTags = [""]; - } else { - cardTags = item.values().tags.split(","); - } - - if (selectedTags.length == 0) { - return true; - } else { - return filterSelectedTags(cardTags, selectedTags); - } - }); - } - - $(".filter-btn").on("click", function() { - if ($(this).data("tag") == "all") { - $(this).addClass("all-tag-selected"); - $(".filter").removeClass("selected"); - } else { - $(this).toggleClass("selected"); - $("[data-tag='all']").removeClass("all-tag-selected"); - } - - // If no tags are selected then highlight the 'All' tag - - if (!$(".selected")[0]) { - $("[data-tag='all']").addClass("all-tag-selected"); - } - - updateList(); - }); - } -}; - -},{}],4:[function(require,module,exports){ -// Modified from https://stackoverflow.com/a/32396543 
-window.highlightNavigation = { - navigationListItems: document.querySelectorAll("#pytorch-right-menu li"), - sections: document.querySelectorAll(".pytorch-article .section"), - sectionIdTonavigationLink: {}, - - bind: function() { - if (!sideMenus.displayRightMenu) { - return; - }; - - for (var i = 0; i < highlightNavigation.sections.length; i++) { - var id = highlightNavigation.sections[i].id; - highlightNavigation.sectionIdTonavigationLink[id] = - document.querySelectorAll('#pytorch-right-menu li a[href="#' + id + '"]')[0]; - } - - $(window).scroll(utilities.throttle(highlightNavigation.highlight, 100)); - }, - - highlight: function() { - var rightMenu = document.getElementById("pytorch-right-menu"); - - // If right menu is not on the screen don't bother - if (rightMenu.offsetWidth === 0 && rightMenu.offsetHeight === 0) { - return; - } - - var scrollPosition = utilities.scrollTop(); - var OFFSET_TOP_PADDING = 25; - var offset = document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight + - OFFSET_TOP_PADDING; - - var sections = highlightNavigation.sections; - - for (var i = (sections.length - 1); i >= 0; i--) { - var currentSection = sections[i]; - var sectionTop = utilities.offset(currentSection).top; - - if (scrollPosition >= sectionTop - offset) { - var navigationLink = highlightNavigation.sectionIdTonavigationLink[currentSection.id]; - var navigationListItem = utilities.closest(navigationLink, "li"); - - if (navigationListItem && !navigationListItem.classList.contains("active")) { - for (var i = 0; i < highlightNavigation.navigationListItems.length; i++) { - var el = highlightNavigation.navigationListItems[i]; - if (el.classList.contains("active")) { - el.classList.remove("active"); - } - } - - navigationListItem.classList.add("active"); - - // Scroll to active item. Not a requested feature but we could revive it. Needs work. - - // var menuTop = $("#pytorch-right-menu").position().top; - // var itemTop = navigationListItem.getBoundingClientRect().top; - // var TOP_PADDING = 20 - // var newActiveTop = $("#pytorch-side-scroll-right").scrollTop() + itemTop - menuTop - TOP_PADDING; - - // $("#pytorch-side-scroll-right").animate({ - // scrollTop: newActiveTop - // }, 100); - } - - break; - } - } - } -}; - -},{}],5:[function(require,module,exports){ -window.mainMenuDropdown = { - bind: function() { - $("[data-toggle='ecosystem-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - $("[data-toggle='resources-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - function toggleDropdown(menuToggle) { - var showMenuClass = "show-menu"; - var menuClass = "." 
+ menuToggle + "-menu"; - - if ($(menuClass).hasClass(showMenuClass)) { - $(menuClass).removeClass(showMenuClass); - } else { - $("[data-toggle=" + menuToggle + "].show-menu").removeClass( - showMenuClass - ); - $(menuClass).addClass(showMenuClass); - } - } - } -}; - -},{}],6:[function(require,module,exports){ -window.mobileMenu = { - bind: function() { - $("[data-behavior='open-mobile-menu']").on('click', function(e) { - e.preventDefault(); - $(".mobile-main-menu").addClass("open"); - $("body").addClass('no-scroll'); - - mobileMenu.listenForResize(); - }); - - $("[data-behavior='close-mobile-menu']").on('click', function(e) { - e.preventDefault(); - mobileMenu.close(); - }); - }, - - listenForResize: function() { - $(window).on('resize.ForMobileMenu', function() { - if ($(this).width() > 768) { - mobileMenu.close(); - } - }); - }, - - close: function() { - $(".mobile-main-menu").removeClass("open"); - $("body").removeClass('no-scroll'); - $(window).off('resize.ForMobileMenu'); - } -}; - -},{}],7:[function(require,module,exports){ -window.mobileTOC = { - bind: function() { - $("[data-behavior='toggle-table-of-contents']").on("click", function(e) { - e.preventDefault(); - - var $parent = $(this).parent(); - - if ($parent.hasClass("is-open")) { - $parent.removeClass("is-open"); - $(".pytorch-left-menu").slideUp(200, function() { - $(this).css({display: ""}); - }); - } else { - $parent.addClass("is-open"); - $(".pytorch-left-menu").slideDown(200); - } - }); - } -} - -},{}],8:[function(require,module,exports){ -window.pytorchAnchors = { - bind: function() { - // Replace Sphinx-generated anchors with anchorjs ones - $(".headerlink").text(""); - - window.anchors.add(".pytorch-article .headerlink"); - - $(".anchorjs-link").each(function() { - var $headerLink = $(this).closest(".headerlink"); - var href = $headerLink.attr("href"); - var clone = this.outerHTML; - - $clone = $(clone).attr("href", href); - $headerLink.before($clone); - $headerLink.remove(); - }); - } -}; - -},{}],9:[function(require,module,exports){ -// Modified from https://stackoverflow.com/a/13067009 -// Going for a JS solution to scrolling to an anchor so we can benefit from -// less hacky css and smooth scrolling. - -window.scrollToAnchor = { - bind: function() { - var document = window.document; - var history = window.history; - var location = window.location - var HISTORY_SUPPORT = !!(history && history.pushState); - - var anchorScrolls = { - ANCHOR_REGEX: /^#[^ ]+$/, - offsetHeightPx: function() { - var OFFSET_HEIGHT_PADDING = 20; - // TODO: this is a little janky. We should try to not rely on JS for this - return utilities.headersHeight() + OFFSET_HEIGHT_PADDING; - }, - - /** - * Establish events, and fix initial scroll position if a hash is provided. - */ - init: function() { - this.scrollToCurrent(); - // This interferes with clicks below it, causing a double fire - // $(window).on('hashchange', $.proxy(this, 'scrollToCurrent')); - $('body').on('click', 'a', $.proxy(this, 'delegateAnchors')); - $('body').on('click', '#pytorch-right-menu li span', $.proxy(this, 'delegateSpans')); - }, - - /** - * Return the offset amount to deduct from the normal scroll position. - * Modify as appropriate to allow for dynamic calculations - */ - getFixedOffset: function() { - return this.offsetHeightPx(); - }, - - /** - * If the provided href is an anchor which resolves to an element on the - * page, scroll to it. - * @param {String} href - * @return {Boolean} - Was the href an anchor. 
- */ - scrollIfAnchor: function(href, pushToHistory) { - var match, anchorOffset; - - if(!this.ANCHOR_REGEX.test(href)) { - return false; - } - - match = document.getElementById(href.slice(1)); - - if(match) { - var anchorOffset = $(match).offset().top - this.getFixedOffset(); - - $('html, body').scrollTop(anchorOffset); - - // Add the state to history as-per normal anchor links - if(HISTORY_SUPPORT && pushToHistory) { - history.pushState({}, document.title, location.pathname + href); - } - } - - return !!match; - }, - - /** - * Attempt to scroll to the current location's hash. - */ - scrollToCurrent: function(e) { - if(this.scrollIfAnchor(window.location.hash) && e) { - e.preventDefault(); - } - }, - - delegateSpans: function(e) { - var elem = utilities.closest(e.target, "a"); - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - }, - - /** - * If the click event's target was an anchor, fix the scroll position. - */ - delegateAnchors: function(e) { - var elem = e.target; - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - } - }; - - $(document).ready($.proxy(anchorScrolls, 'init')); - } -}; - -},{}],10:[function(require,module,exports){ -window.sideMenus = { - rightMenuIsOnScreen: function() { - return document.getElementById("pytorch-content-right").offsetParent !== null; - }, - - isFixedToBottom: false, - - bind: function() { - sideMenus.handleLeftMenu(); - - var rightMenuLinks = document.querySelectorAll("#pytorch-right-menu li"); - var rightMenuHasLinks = rightMenuLinks.length > 1; - - if (!rightMenuHasLinks) { - for (var i = 0; i < rightMenuLinks.length; i++) { - rightMenuLinks[i].style.display = "none"; - } - } - - if (rightMenuHasLinks) { - // Don't show the Shortcuts menu title text unless there are menu items - document.getElementById("pytorch-shortcuts-wrapper").style.display = "block"; - - // We are hiding the titles of the pages in the right side menu but there are a few - // pages that include other pages in the right side menu (see 'torch.nn' in the docs) - // so if we exclude those it looks confusing. 
Here we add a 'title-link' class to these - // links so we can exclude them from normal right side menu link operations - var titleLinks = document.querySelectorAll( - "#pytorch-right-menu #pytorch-side-scroll-right \ - > ul > li > a.reference.internal" - ); - - for (var i = 0; i < titleLinks.length; i++) { - var link = titleLinks[i]; - - link.classList.add("title-link"); - - if ( - link.nextElementSibling && - link.nextElementSibling.tagName === "UL" && - link.nextElementSibling.children.length > 0 - ) { - link.classList.add("has-children"); - } - } - - // Add + expansion signifiers to normal right menu links that have sub menus - var menuLinks = document.querySelectorAll( - "#pytorch-right-menu ul li ul li a.reference.internal" - ); - - for (var i = 0; i < menuLinks.length; i++) { - if ( - menuLinks[i].nextElementSibling && - menuLinks[i].nextElementSibling.tagName === "UL" - ) { - menuLinks[i].classList.add("not-expanded"); - } - } - - // If a hash is present on page load recursively expand menu items leading to selected item - var linkWithHash = - document.querySelector( - "#pytorch-right-menu a[href=\"" + window.location.hash + "\"]" - ); - - if (linkWithHash) { - // Expand immediate sibling list if present - if ( - linkWithHash.nextElementSibling && - linkWithHash.nextElementSibling.tagName === "UL" && - linkWithHash.nextElementSibling.children.length > 0 - ) { - linkWithHash.nextElementSibling.style.display = "block"; - linkWithHash.classList.add("expanded"); - } - - // Expand ancestor lists if any - sideMenus.expandClosestUnexpandedParentList(linkWithHash); - } - - // Bind click events on right menu links - $("#pytorch-right-menu a.reference.internal").on("click", function() { - if (this.classList.contains("expanded")) { - this.nextElementSibling.style.display = "none"; - this.classList.remove("expanded"); - this.classList.add("not-expanded"); - } else if (this.classList.contains("not-expanded")) { - this.nextElementSibling.style.display = "block"; - this.classList.remove("not-expanded"); - this.classList.add("expanded"); - } - }); - - sideMenus.handleRightMenu(); - } - - $(window).on('resize scroll', function(e) { - sideMenus.handleNavBar(); - - sideMenus.handleLeftMenu(); - - if (sideMenus.rightMenuIsOnScreen()) { - sideMenus.handleRightMenu(); - } - }); - }, - - leftMenuIsFixed: function() { - return document.getElementById("pytorch-left-menu").classList.contains("make-fixed"); - }, - - handleNavBar: function() { - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - // If we are scrolled past the main navigation header fix the sub menu bar to top of page - if (utilities.scrollTop() >= mainHeaderHeight) { - document.getElementById("pytorch-left-menu").classList.add("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.add("left-menu-is-fixed"); - } else { - document.getElementById("pytorch-left-menu").classList.remove("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.remove("left-menu-is-fixed"); - } - }, - - expandClosestUnexpandedParentList: function (el) { - var closestParentList = utilities.closest(el, "ul"); - - if (closestParentList) { - var closestParentLink = closestParentList.previousElementSibling; - var closestParentLinkExists = closestParentLink && - closestParentLink.tagName === "A" && - closestParentLink.classList.contains("reference"); - - if (closestParentLinkExists) { - // Don't add expansion class to any title links - if (closestParentLink.classList.contains("title-link")) { - 
return; - } - - closestParentList.style.display = "block"; - closestParentLink.classList.remove("not-expanded"); - closestParentLink.classList.add("expanded"); - sideMenus.expandClosestUnexpandedParentList(closestParentLink); - } - } - }, - - handleLeftMenu: function () { - var windowHeight = utilities.windowHeight(); - var topOfFooterRelativeToWindow = document.getElementById("docs-tutorials-resources").getBoundingClientRect().top; - - if (topOfFooterRelativeToWindow >= windowHeight) { - document.getElementById("pytorch-left-menu").style.height = "100%"; - } else { - var howManyPixelsOfTheFooterAreInTheWindow = windowHeight - topOfFooterRelativeToWindow; - var leftMenuDifference = howManyPixelsOfTheFooterAreInTheWindow; - document.getElementById("pytorch-left-menu").style.height = (windowHeight - leftMenuDifference) + "px"; - } - }, - - handleRightMenu: function() { - var rightMenuWrapper = document.getElementById("pytorch-content-right"); - var rightMenu = document.getElementById("pytorch-right-menu"); - var rightMenuList = rightMenu.getElementsByTagName("ul")[0]; - var article = document.getElementById("pytorch-article"); - var articleHeight = article.offsetHeight; - var articleBottom = utilities.offset(article).top + articleHeight; - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - if (utilities.scrollTop() < mainHeaderHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = 0; - rightMenu.classList.remove("scrolling-fixed"); - rightMenu.classList.remove("scrolling-absolute"); - } else { - if (rightMenu.classList.contains("scrolling-fixed")) { - var rightMenuBottom = - utilities.offset(rightMenuList).top + rightMenuList.offsetHeight; - - if (rightMenuBottom >= articleBottom) { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = utilities.scrollTop() - mainHeaderHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - rightMenu.classList.remove("scrolling-fixed"); - } - } else { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = - articleBottom - mainHeaderHeight - rightMenuList.offsetHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - } - - if (utilities.scrollTop() < articleBottom - rightMenuList.offsetHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = ""; - rightMenu.classList.remove("scrolling-absolute"); - rightMenu.classList.add("scrolling-fixed"); - } - } - - var rightMenuSideScroll = document.getElementById("pytorch-side-scroll-right"); - var sideScrollFromWindowTop = rightMenuSideScroll.getBoundingClientRect().top; - - rightMenuSideScroll.style.height = utilities.windowHeight() - sideScrollFromWindowTop + "px"; - } -}; - -},{}],11:[function(require,module,exports){ -var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); - -// Sphinx theme nav state -function ThemeNav () { - - var nav = { - navBar: null, - win: null, - winScroll: false, - winResize: false, - linkScroll: false, - winPosition: 0, - winHeight: null, - docHeight: null, - isRunning: false - }; - - nav.enable = function (withStickyNav) { - var self = this; - - // TODO this can likely be removed once the theme javascript is broken - // out from the RTD assets. This just ensures old projects that are - // calling `enable()` get the sticky menu on by default. All other cals - // to `enable` should include an argument for enabling the sticky menu. 
- if (typeof(withStickyNav) == 'undefined') { - withStickyNav = true; - } - - if (self.isRunning) { - // Only allow enabling nav logic once - return; - } - - self.isRunning = true; - jQuery(function ($) { - self.init($); - - self.reset(); - self.win.on('hashchange', self.reset); - - if (withStickyNav) { - // Set scroll monitor - self.win.on('scroll', function () { - if (!self.linkScroll) { - if (!self.winScroll) { - self.winScroll = true; - requestAnimationFrame(function() { self.onScroll(); }); - } - } - }); - } - - // Set resize monitor - self.win.on('resize', function () { - if (!self.winResize) { - self.winResize = true; - requestAnimationFrame(function() { self.onResize(); }); - } - }); - - self.onResize(); - }); - - }; - - // TODO remove this with a split in theme and Read the Docs JS logic as - // well, it's only here to support 0.3.0 installs of our theme. - nav.enableSticky = function() { - this.enable(true); - }; - - nav.init = function ($) { - var doc = $(document), - self = this; - - this.navBar = $('div.pytorch-side-scroll:first'); - this.win = $(window); - - // Set up javascript UX bits - $(document) - // Shift nav in mobile when clicking the menu. - .on('click', "[data-toggle='pytorch-left-menu-nav-top']", function() { - $("[data-toggle='wy-nav-shift']").toggleClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - }) - - // Nav menu link click operations - .on('click', ".pytorch-menu-vertical .current ul li a", function() { - var target = $(this); - // Close menu when you click a link. - $("[data-toggle='wy-nav-shift']").removeClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - // Handle dynamic display of l3 and l4 nav lists - self.toggleCurrent(target); - self.hashChange(); - }) - .on('click', "[data-toggle='rst-current-version']", function() { - $("[data-toggle='rst-versions']").toggleClass("shift-up"); - }) - - // Make tables responsive - $("table.docutils:not(.field-list,.footnote,.citation)") - .wrap("
"); - - // Add extra class to responsive tables that contain - // footnotes or citations so that we can target them for styling - $("table.docutils.footnote") - .wrap("
"); - $("table.docutils.citation") - .wrap("
"); - - // Add expand links to all parents of nested ul - $('.pytorch-menu-vertical ul').not('.simple').siblings('a').each(function () { - var link = $(this); - expand = $(''); - expand.on('click', function (ev) { - self.toggleCurrent(link); - ev.stopPropagation(); - return false; - }); - link.prepend(expand); - }); - }; - - nav.reset = function () { - // Get anchor from URL and open up nested nav - var anchor = encodeURI(window.location.hash) || '#'; - - try { - var vmenu = $('.pytorch-menu-vertical'); - var link = vmenu.find('[href="' + anchor + '"]'); - if (link.length === 0) { - // this link was not found in the sidebar. - // Find associated id element, then its closest section - // in the document and try with that one. - var id_elt = $('.document [id="' + anchor.substring(1) + '"]'); - var closest_section = id_elt.closest('div.section'); - link = vmenu.find('[href="#' + closest_section.attr("id") + '"]'); - if (link.length === 0) { - // still not found in the sidebar. fall back to main section - link = vmenu.find('[href="#"]'); - } - } - // If we found a matching link then reset current and re-apply - // otherwise retain the existing match - if (link.length > 0) { - $('.pytorch-menu-vertical .current').removeClass('current'); - link.addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l1').parent().addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l2').addClass('current'); - link.closest('li.toctree-l3').addClass('current'); - link.closest('li.toctree-l4').addClass('current'); - } - } - catch (err) { - console.log("Error expanding nav for anchor", err); - } - - }; - - nav.onScroll = function () { - this.winScroll = false; - var newWinPosition = this.win.scrollTop(), - winBottom = newWinPosition + this.winHeight, - navPosition = this.navBar.scrollTop(), - newNavPosition = navPosition + (newWinPosition - this.winPosition); - if (newWinPosition < 0 || winBottom > this.docHeight) { - return; - } - this.navBar.scrollTop(newNavPosition); - this.winPosition = newWinPosition; - }; - - nav.onResize = function () { - this.winResize = false; - this.winHeight = this.win.height(); - this.docHeight = $(document).height(); - }; - - nav.hashChange = function () { - this.linkScroll = true; - this.win.one('hashchange', function () { - this.linkScroll = false; - }); - }; - - nav.toggleCurrent = function (elem) { - var parent_li = elem.closest('li'); - parent_li.siblings('li.current').removeClass('current'); - parent_li.siblings().find('li.current').removeClass('current'); - parent_li.find('> ul li.current').removeClass('current'); - parent_li.toggleClass('current'); - } - - return nav; -}; - -module.exports.ThemeNav = ThemeNav(); - -if (typeof(window) != 'undefined') { - window.SphinxRtdTheme = { - Navigation: module.exports.ThemeNav, - // TODO remove this once static assets are split up between the theme - // and Read the Docs. For now, this patches 0.3.0 to be backwards - // compatible with a pre-0.3.0 layout.html - StickyNav: module.exports.ThemeNav, - }; -} - - -// requestAnimationFrame polyfill by Erik Möller. 
fixes from Paul Irish and Tino Zijdel -// https://gist.github.com/paulirish/1579671 -// MIT license - -(function() { - var lastTime = 0; - var vendors = ['ms', 'moz', 'webkit', 'o']; - for(var x = 0; x < vendors.length && !window.requestAnimationFrame; ++x) { - window.requestAnimationFrame = window[vendors[x]+'RequestAnimationFrame']; - window.cancelAnimationFrame = window[vendors[x]+'CancelAnimationFrame'] - || window[vendors[x]+'CancelRequestAnimationFrame']; - } - - if (!window.requestAnimationFrame) - window.requestAnimationFrame = function(callback, element) { - var currTime = new Date().getTime(); - var timeToCall = Math.max(0, 16 - (currTime - lastTime)); - var id = window.setTimeout(function() { callback(currTime + timeToCall); }, - timeToCall); - lastTime = currTime + timeToCall; - return id; - }; - - if (!window.cancelAnimationFrame) - window.cancelAnimationFrame = function(id) { - clearTimeout(id); - }; -}()); - -$(".sphx-glr-thumbcontainer").removeAttr("tooltip"); -$("table").removeAttr("border"); - -// This code replaces the default sphinx gallery download buttons -// with the 3 download buttons at the top of the page - -var downloadNote = $(".sphx-glr-download-link-note.admonition.note"); -if (downloadNote.length >= 1) { - var tutorialUrlArray = $("#tutorial-type").text().split('/'); - tutorialUrlArray[0] = tutorialUrlArray[0] + "/sphinx-tutorials" - - var githubLink = "https://github.com/pytorch/rl/blob/main/" + tutorialUrlArray.join("/") + ".py", - notebookLink = $(".reference.download")[1].href, - notebookDownloadPath = notebookLink.split('_downloads')[1], - colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/_downloads" + notebookDownloadPath; - - $("#google-colab-link").wrap("
"); - $("#download-notebook-link").wrap(""); - $("#github-view-link").wrap(""); -} else { - $(".pytorch-call-to-action-links").hide(); -} - -//This code handles the Expand/Hide toggle for the Docs/Tutorials left nav items - -$(document).ready(function() { - var caption = "#pytorch-left-menu p.caption"; - var collapseAdded = $(this).not("checked"); - $(caption).each(function () { - var menuName = this.innerText.replace(/[^\w\s]/gi, "").trim(); - $(this).find("span").addClass("checked"); - if (collapsedSections.includes(menuName) == true && collapseAdded && sessionStorage.getItem(menuName) !== "expand" || sessionStorage.getItem(menuName) == "collapse") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - $(this).next("ul").hide(); - } else if (collapsedSections.includes(menuName) == false && collapseAdded || sessionStorage.getItem(menuName) == "expand") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - } - }); - - $(".expand-menu").on("click", function () { - $(this).prev(".hide-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "collapse") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "expand"); - toggleList(this); - }); - - $(".hide-menu").on("click", function () { - $(this).next(".expand-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "expand") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "collapse"); - toggleList(this); - }); - - function toggleList(menuCommand) { - $(menuCommand).toggle(); - } -}); - -// Build an array from each tag that's present - -var tagList = $(".tutorials-card-container").map(function() { - return $(this).data("tags").split(",").map(function(item) { - return item.trim(); - }); -}).get(); - -function unique(value, index, self) { - return self.indexOf(value) == index && value != "" - } - -// Only return unique tags - -var tags = tagList.sort().filter(unique); - -// Add filter buttons to the top of the page for each tag - -function createTagMenu() { - tags.forEach(function(item){ - $(".tutorial-filter-menu").append("
" + item + "
") - }) -}; - -createTagMenu(); - -// Remove hyphens if they are present in the filter buttons - -$(".tags").each(function(){ - var tags = $(this).text().split(","); - tags.forEach(function(tag, i ) { - tags[i] = tags[i].replace(/-/, ' ') - }) - $(this).html(tags.join(", ")); -}); - -// Remove hyphens if they are present in the card body - -$(".tutorial-filter").each(function(){ - var tag = $(this).text(); - $(this).html(tag.replace(/-/, ' ')) -}) - -// Remove any empty p tags that Sphinx adds - -$("#tutorial-cards p").each(function(index, item) { - if(!$(item).text().trim()) { - $(item).remove(); - } -}); - -// Jump back to top on pagination click - -$(document).on("click", ".page", function() { - $('html, body').animate( - {scrollTop: $("#dropdown-filter-tags").position().top}, - 'slow' - ); -}); - -var link = $("a[href='intermediate/speech_command_recognition_with_torchaudio.html']"); - -if (link.text() == "SyntaxError") { - console.log("There is an issue with the intermediate/speech_command_recognition_with_torchaudio.html menu item."); - link.text("Speech Command Recognition with torchaudio"); -} - -$(".stars-outer > i").hover(function() { - $(this).prevAll().addBack().toggleClass("fas star-fill"); -}); - -$(".stars-outer > i").on("click", function() { - $(this).prevAll().each(function() { - $(this).addBack().addClass("fas star-fill"); - }); - - $(".stars-outer > i").each(function() { - $(this).unbind("mouseenter mouseleave").css({ - "pointer-events": "none" - }); - }); -}) - -$("#pytorch-side-scroll-right li a").on("click", function (e) { - var href = $(this).attr("href"); - $('html, body').stop().animate({ - scrollTop: $(href).offset().top - 100 - }, 850); - e.preventDefault; -}); - -var lastId, - topMenu = $("#pytorch-side-scroll-right"), - topMenuHeight = topMenu.outerHeight() + 1, - // All sidenav items - menuItems = topMenu.find("a"), - // Anchors for menu items - scrollItems = menuItems.map(function () { - var item = $(this).attr("href"); - if (item.length) { - return item; - } - }); - -$(window).scroll(function () { - var fromTop = $(this).scrollTop() + topMenuHeight; - var article = ".section"; - - $(article).each(function (i) { - var offsetScroll = $(this).offset().top - $(window).scrollTop(); - if ( - offsetScroll <= topMenuHeight + 200 && - offsetScroll >= topMenuHeight - 200 && - scrollItems[i] == "#" + $(this).attr("id") && - $(".hidden:visible") - ) { - $(menuItems).removeClass("side-scroll-highlight"); - $(menuItems[i]).addClass("side-scroll-highlight"); - } - }); -}); - - -},{"jquery":"jquery"}],"pytorch-sphinx-theme":[function(require,module,exports){ -require=(function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i wait) { - if (timeout) { - clearTimeout(timeout); - timeout = null; - } - previous = now; - result = func.apply(context, args); - if (!timeout) context = args = null; - } else if (!timeout && options.trailing !== false) { - timeout = setTimeout(later, remaining); - } - return result; - }; - }, - - closest: function (el, selector) { - var matchesFn; - - // find vendor prefix - ['matches','webkitMatchesSelector','mozMatchesSelector','msMatchesSelector','oMatchesSelector'].some(function(fn) { - if (typeof 
document.body[fn] == 'function') { - matchesFn = fn; - return true; - } - return false; - }); - - var parent; - - // traverse parents - while (el) { - parent = el.parentElement; - if (parent && parent[matchesFn](selector)) { - return parent; - } - el = parent; - } - - return null; - }, - - // Modified from https://stackoverflow.com/a/18953277 - offset: function(elem) { - if (!elem) { - return; - } - - rect = elem.getBoundingClientRect(); - - // Make sure element is not hidden (display: none) or disconnected - if (rect.width || rect.height || elem.getClientRects().length) { - var doc = elem.ownerDocument; - var docElem = doc.documentElement; - - return { - top: rect.top + window.pageYOffset - docElem.clientTop, - left: rect.left + window.pageXOffset - docElem.clientLeft - }; - } - }, - - headersHeight: function() { - if (document.getElementById("pytorch-left-menu").classList.contains("make-fixed")) { - return document.getElementById("pytorch-page-level-bar").offsetHeight; - } else { - return document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight; - } - }, - - windowHeight: function() { - return window.innerHeight || - document.documentElement.clientHeight || - document.body.clientHeight; - } - } - - },{}],2:[function(require,module,exports){ - var cookieBanner = { - init: function() { - cookieBanner.bind(); - - var cookieExists = cookieBanner.cookieExists(); - - if (!cookieExists) { - cookieBanner.setCookie(); - cookieBanner.showCookieNotice(); - } - }, - - bind: function() { - $(".close-button").on("click", cookieBanner.hideCookieNotice); - }, - - cookieExists: function() { - var cookie = localStorage.getItem("returningPytorchUser"); - - if (cookie) { - return true; - } else { - return false; - } - }, - - setCookie: function() { - localStorage.setItem("returningPytorchUser", true); - }, - - showCookieNotice: function() { - $(".cookie-banner-wrapper").addClass("is-visible"); - }, - - hideCookieNotice: function() { - $(".cookie-banner-wrapper").removeClass("is-visible"); - } - }; - - $(function() { - cookieBanner.init(); - }); - - },{}],3:[function(require,module,exports){ - window.filterTags = { - bind: function() { - var options = { - valueNames: [{ data: ["tags"] }], - page: "6", - pagination: true - }; - - var tutorialList = new List("tutorial-cards", options); - - function filterSelectedTags(cardTags, selectedTags) { - return cardTags.some(function(tag) { - return selectedTags.some(function(selectedTag) { - return selectedTag == tag; - }); - }); - } - - function updateList() { - var selectedTags = []; - - $(".selected").each(function() { - selectedTags.push($(this).data("tag")); - }); - - tutorialList.filter(function(item) { - var cardTags; - - if (item.values().tags == null) { - cardTags = [""]; - } else { - cardTags = item.values().tags.split(","); - } - - if (selectedTags.length == 0) { - return true; - } else { - return filterSelectedTags(cardTags, selectedTags); - } - }); - } - - $(".filter-btn").on("click", function() { - if ($(this).data("tag") == "all") { - $(this).addClass("all-tag-selected"); - $(".filter").removeClass("selected"); - } else { - $(this).toggleClass("selected"); - $("[data-tag='all']").removeClass("all-tag-selected"); - } - - // If no tags are selected then highlight the 'All' tag - - if (!$(".selected")[0]) { - $("[data-tag='all']").addClass("all-tag-selected"); - } - - updateList(); - }); - } - }; - - },{}],4:[function(require,module,exports){ - // Modified from 
https://stackoverflow.com/a/32396543 - window.highlightNavigation = { - navigationListItems: document.querySelectorAll("#pytorch-right-menu li"), - sections: document.querySelectorAll(".pytorch-article .section"), - sectionIdTonavigationLink: {}, - - bind: function() { - if (!sideMenus.displayRightMenu) { - return; - }; - - for (var i = 0; i < highlightNavigation.sections.length; i++) { - var id = highlightNavigation.sections[i].id; - highlightNavigation.sectionIdTonavigationLink[id] = - document.querySelectorAll('#pytorch-right-menu li a[href="#' + id + '"]')[0]; - } - - $(window).scroll(utilities.throttle(highlightNavigation.highlight, 100)); - }, - - highlight: function() { - var rightMenu = document.getElementById("pytorch-right-menu"); - - // If right menu is not on the screen don't bother - if (rightMenu.offsetWidth === 0 && rightMenu.offsetHeight === 0) { - return; - } - - var scrollPosition = utilities.scrollTop(); - var OFFSET_TOP_PADDING = 25; - var offset = document.getElementById("header-holder").offsetHeight + - document.getElementById("pytorch-page-level-bar").offsetHeight + - OFFSET_TOP_PADDING; - - var sections = highlightNavigation.sections; - - for (var i = (sections.length - 1); i >= 0; i--) { - var currentSection = sections[i]; - var sectionTop = utilities.offset(currentSection).top; - - if (scrollPosition >= sectionTop - offset) { - var navigationLink = highlightNavigation.sectionIdTonavigationLink[currentSection.id]; - var navigationListItem = utilities.closest(navigationLink, "li"); - - if (navigationListItem && !navigationListItem.classList.contains("active")) { - for (var i = 0; i < highlightNavigation.navigationListItems.length; i++) { - var el = highlightNavigation.navigationListItems[i]; - if (el.classList.contains("active")) { - el.classList.remove("active"); - } - } - - navigationListItem.classList.add("active"); - - // Scroll to active item. Not a requested feature but we could revive it. Needs work. - - // var menuTop = $("#pytorch-right-menu").position().top; - // var itemTop = navigationListItem.getBoundingClientRect().top; - // var TOP_PADDING = 20 - // var newActiveTop = $("#pytorch-side-scroll-right").scrollTop() + itemTop - menuTop - TOP_PADDING; - - // $("#pytorch-side-scroll-right").animate({ - // scrollTop: newActiveTop - // }, 100); - } - - break; - } - } - } - }; - - },{}],5:[function(require,module,exports){ - window.mainMenuDropdown = { - bind: function() { - $("[data-toggle='ecosystem-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - $("[data-toggle='resources-dropdown']").on("click", function() { - toggleDropdown($(this).attr("data-toggle")); - }); - - function toggleDropdown(menuToggle) { - var showMenuClass = "show-menu"; - var menuClass = "." 
+ menuToggle + "-menu"; - - if ($(menuClass).hasClass(showMenuClass)) { - $(menuClass).removeClass(showMenuClass); - } else { - $("[data-toggle=" + menuToggle + "].show-menu").removeClass( - showMenuClass - ); - $(menuClass).addClass(showMenuClass); - } - } - } - }; - - },{}],6:[function(require,module,exports){ - window.mobileMenu = { - bind: function() { - $("[data-behavior='open-mobile-menu']").on('click', function(e) { - e.preventDefault(); - $(".mobile-main-menu").addClass("open"); - $("body").addClass('no-scroll'); - - mobileMenu.listenForResize(); - }); - - $("[data-behavior='close-mobile-menu']").on('click', function(e) { - e.preventDefault(); - mobileMenu.close(); - }); - }, - - listenForResize: function() { - $(window).on('resize.ForMobileMenu', function() { - if ($(this).width() > 768) { - mobileMenu.close(); - } - }); - }, - - close: function() { - $(".mobile-main-menu").removeClass("open"); - $("body").removeClass('no-scroll'); - $(window).off('resize.ForMobileMenu'); - } - }; - - },{}],7:[function(require,module,exports){ - window.mobileTOC = { - bind: function() { - $("[data-behavior='toggle-table-of-contents']").on("click", function(e) { - e.preventDefault(); - - var $parent = $(this).parent(); - - if ($parent.hasClass("is-open")) { - $parent.removeClass("is-open"); - $(".pytorch-left-menu").slideUp(200, function() { - $(this).css({display: ""}); - }); - } else { - $parent.addClass("is-open"); - $(".pytorch-left-menu").slideDown(200); - } - }); - } - } - - },{}],8:[function(require,module,exports){ - window.pytorchAnchors = { - bind: function() { - // Replace Sphinx-generated anchors with anchorjs ones - $(".headerlink").text(""); - - window.anchors.add(".pytorch-article .headerlink"); - - $(".anchorjs-link").each(function() { - var $headerLink = $(this).closest(".headerlink"); - var href = $headerLink.attr("href"); - var clone = this.outerHTML; - - $clone = $(clone).attr("href", href); - $headerLink.before($clone); - $headerLink.remove(); - }); - } - }; - - },{}],9:[function(require,module,exports){ - // Modified from https://stackoverflow.com/a/13067009 - // Going for a JS solution to scrolling to an anchor so we can benefit from - // less hacky css and smooth scrolling. - - window.scrollToAnchor = { - bind: function() { - var document = window.document; - var history = window.history; - var location = window.location - var HISTORY_SUPPORT = !!(history && history.pushState); - - var anchorScrolls = { - ANCHOR_REGEX: /^#[^ ]+$/, - offsetHeightPx: function() { - var OFFSET_HEIGHT_PADDING = 20; - // TODO: this is a little janky. We should try to not rely on JS for this - return utilities.headersHeight() + OFFSET_HEIGHT_PADDING; - }, - - /** - * Establish events, and fix initial scroll position if a hash is provided. - */ - init: function() { - this.scrollToCurrent(); - // This interferes with clicks below it, causing a double fire - // $(window).on('hashchange', $.proxy(this, 'scrollToCurrent')); - $('body').on('click', 'a', $.proxy(this, 'delegateAnchors')); - $('body').on('click', '#pytorch-right-menu li span', $.proxy(this, 'delegateSpans')); - }, - - /** - * Return the offset amount to deduct from the normal scroll position. - * Modify as appropriate to allow for dynamic calculations - */ - getFixedOffset: function() { - return this.offsetHeightPx(); - }, - - /** - * If the provided href is an anchor which resolves to an element on the - * page, scroll to it. - * @param {String} href - * @return {Boolean} - Was the href an anchor. 
- */ - scrollIfAnchor: function(href, pushToHistory) { - var match, anchorOffset; - - if(!this.ANCHOR_REGEX.test(href)) { - return false; - } - - match = document.getElementById(href.slice(1)); - - if(match) { - var anchorOffset = $(match).offset().top - this.getFixedOffset(); - - $('html, body').scrollTop(anchorOffset); - - // Add the state to history as-per normal anchor links - if(HISTORY_SUPPORT && pushToHistory) { - history.pushState({}, document.title, location.pathname + href); - } - } - - return !!match; - }, - - /** - * Attempt to scroll to the current location's hash. - */ - scrollToCurrent: function(e) { - if(this.scrollIfAnchor(window.location.hash) && e) { - e.preventDefault(); - } - }, - - delegateSpans: function(e) { - var elem = utilities.closest(e.target, "a"); - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - }, - - /** - * If the click event's target was an anchor, fix the scroll position. - */ - delegateAnchors: function(e) { - var elem = e.target; - - if(this.scrollIfAnchor(elem.getAttribute('href'), true)) { - e.preventDefault(); - } - } - }; - - $(document).ready($.proxy(anchorScrolls, 'init')); - } - }; - - },{}],10:[function(require,module,exports){ - window.sideMenus = { - rightMenuIsOnScreen: function() { - return document.getElementById("pytorch-content-right").offsetParent !== null; - }, - - isFixedToBottom: false, - - bind: function() { - sideMenus.handleLeftMenu(); - - var rightMenuLinks = document.querySelectorAll("#pytorch-right-menu li"); - var rightMenuHasLinks = rightMenuLinks.length > 1; - - if (!rightMenuHasLinks) { - for (var i = 0; i < rightMenuLinks.length; i++) { - rightMenuLinks[i].style.display = "none"; - } - } - - if (rightMenuHasLinks) { - // Don't show the Shortcuts menu title text unless there are menu items - document.getElementById("pytorch-shortcuts-wrapper").style.display = "block"; - - // We are hiding the titles of the pages in the right side menu but there are a few - // pages that include other pages in the right side menu (see 'torch.nn' in the docs) - // so if we exclude those it looks confusing. 
Here we add a 'title-link' class to these - // links so we can exclude them from normal right side menu link operations - var titleLinks = document.querySelectorAll( - "#pytorch-right-menu #pytorch-side-scroll-right \ - > ul > li > a.reference.internal" - ); - - for (var i = 0; i < titleLinks.length; i++) { - var link = titleLinks[i]; - - link.classList.add("title-link"); - - if ( - link.nextElementSibling && - link.nextElementSibling.tagName === "UL" && - link.nextElementSibling.children.length > 0 - ) { - link.classList.add("has-children"); - } - } - - // Add + expansion signifiers to normal right menu links that have sub menus - var menuLinks = document.querySelectorAll( - "#pytorch-right-menu ul li ul li a.reference.internal" - ); - - for (var i = 0; i < menuLinks.length; i++) { - if ( - menuLinks[i].nextElementSibling && - menuLinks[i].nextElementSibling.tagName === "UL" - ) { - menuLinks[i].classList.add("not-expanded"); - } - } - - // If a hash is present on page load recursively expand menu items leading to selected item - var linkWithHash = - document.querySelector( - "#pytorch-right-menu a[href=\"" + window.location.hash + "\"]" - ); - - if (linkWithHash) { - // Expand immediate sibling list if present - if ( - linkWithHash.nextElementSibling && - linkWithHash.nextElementSibling.tagName === "UL" && - linkWithHash.nextElementSibling.children.length > 0 - ) { - linkWithHash.nextElementSibling.style.display = "block"; - linkWithHash.classList.add("expanded"); - } - - // Expand ancestor lists if any - sideMenus.expandClosestUnexpandedParentList(linkWithHash); - } - - // Bind click events on right menu links - $("#pytorch-right-menu a.reference.internal").on("click", function() { - if (this.classList.contains("expanded")) { - this.nextElementSibling.style.display = "none"; - this.classList.remove("expanded"); - this.classList.add("not-expanded"); - } else if (this.classList.contains("not-expanded")) { - this.nextElementSibling.style.display = "block"; - this.classList.remove("not-expanded"); - this.classList.add("expanded"); - } - }); - - sideMenus.handleRightMenu(); - } - - $(window).on('resize scroll', function(e) { - sideMenus.handleNavBar(); - - sideMenus.handleLeftMenu(); - - if (sideMenus.rightMenuIsOnScreen()) { - sideMenus.handleRightMenu(); - } - }); - }, - - leftMenuIsFixed: function() { - return document.getElementById("pytorch-left-menu").classList.contains("make-fixed"); - }, - - handleNavBar: function() { - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - // If we are scrolled past the main navigation header fix the sub menu bar to top of page - if (utilities.scrollTop() >= mainHeaderHeight) { - document.getElementById("pytorch-left-menu").classList.add("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.add("left-menu-is-fixed"); - } else { - document.getElementById("pytorch-left-menu").classList.remove("make-fixed"); - document.getElementById("pytorch-page-level-bar").classList.remove("left-menu-is-fixed"); - } - }, - - expandClosestUnexpandedParentList: function (el) { - var closestParentList = utilities.closest(el, "ul"); - - if (closestParentList) { - var closestParentLink = closestParentList.previousElementSibling; - var closestParentLinkExists = closestParentLink && - closestParentLink.tagName === "A" && - closestParentLink.classList.contains("reference"); - - if (closestParentLinkExists) { - // Don't add expansion class to any title links - if (closestParentLink.classList.contains("title-link")) { - 
return; - } - - closestParentList.style.display = "block"; - closestParentLink.classList.remove("not-expanded"); - closestParentLink.classList.add("expanded"); - sideMenus.expandClosestUnexpandedParentList(closestParentLink); - } - } - }, - - handleLeftMenu: function () { - var windowHeight = utilities.windowHeight(); - var topOfFooterRelativeToWindow = document.getElementById("docs-tutorials-resources").getBoundingClientRect().top; - - if (topOfFooterRelativeToWindow >= windowHeight) { - document.getElementById("pytorch-left-menu").style.height = "100%"; - } else { - var howManyPixelsOfTheFooterAreInTheWindow = windowHeight - topOfFooterRelativeToWindow; - var leftMenuDifference = howManyPixelsOfTheFooterAreInTheWindow; - document.getElementById("pytorch-left-menu").style.height = (windowHeight - leftMenuDifference) + "px"; - } - }, - - handleRightMenu: function() { - var rightMenuWrapper = document.getElementById("pytorch-content-right"); - var rightMenu = document.getElementById("pytorch-right-menu"); - var rightMenuList = rightMenu.getElementsByTagName("ul")[0]; - var article = document.getElementById("pytorch-article"); - var articleHeight = article.offsetHeight; - var articleBottom = utilities.offset(article).top + articleHeight; - var mainHeaderHeight = document.getElementById('header-holder').offsetHeight; - - if (utilities.scrollTop() < mainHeaderHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = 0; - rightMenu.classList.remove("scrolling-fixed"); - rightMenu.classList.remove("scrolling-absolute"); - } else { - if (rightMenu.classList.contains("scrolling-fixed")) { - var rightMenuBottom = - utilities.offset(rightMenuList).top + rightMenuList.offsetHeight; - - if (rightMenuBottom >= articleBottom) { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = utilities.scrollTop() - mainHeaderHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - rightMenu.classList.remove("scrolling-fixed"); - } - } else { - rightMenuWrapper.style.height = articleHeight + mainHeaderHeight + "px"; - rightMenu.style.top = - articleBottom - mainHeaderHeight - rightMenuList.offsetHeight + "px"; - rightMenu.classList.add("scrolling-absolute"); - } - - if (utilities.scrollTop() < articleBottom - rightMenuList.offsetHeight) { - rightMenuWrapper.style.height = "100%"; - rightMenu.style.top = ""; - rightMenu.classList.remove("scrolling-absolute"); - rightMenu.classList.add("scrolling-fixed"); - } - } - - var rightMenuSideScroll = document.getElementById("pytorch-side-scroll-right"); - var sideScrollFromWindowTop = rightMenuSideScroll.getBoundingClientRect().top; - - rightMenuSideScroll.style.height = utilities.windowHeight() - sideScrollFromWindowTop + "px"; - } - }; - - },{}],11:[function(require,module,exports){ - var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); - - // Sphinx theme nav state - function ThemeNav () { - - var nav = { - navBar: null, - win: null, - winScroll: false, - winResize: false, - linkScroll: false, - winPosition: 0, - winHeight: null, - docHeight: null, - isRunning: false - }; - - nav.enable = function (withStickyNav) { - var self = this; - - // TODO this can likely be removed once the theme javascript is broken - // out from the RTD assets. This just ensures old projects that are - // calling `enable()` get the sticky menu on by default. All other cals - // to `enable` should include an argument for enabling the sticky menu. 
- if (typeof(withStickyNav) == 'undefined') { - withStickyNav = true; - } - - if (self.isRunning) { - // Only allow enabling nav logic once - return; - } - - self.isRunning = true; - jQuery(function ($) { - self.init($); - - self.reset(); - self.win.on('hashchange', self.reset); - - if (withStickyNav) { - // Set scroll monitor - self.win.on('scroll', function () { - if (!self.linkScroll) { - if (!self.winScroll) { - self.winScroll = true; - requestAnimationFrame(function() { self.onScroll(); }); - } - } - }); - } - - // Set resize monitor - self.win.on('resize', function () { - if (!self.winResize) { - self.winResize = true; - requestAnimationFrame(function() { self.onResize(); }); - } - }); - - self.onResize(); - }); - - }; - - // TODO remove this with a split in theme and Read the Docs JS logic as - // well, it's only here to support 0.3.0 installs of our theme. - nav.enableSticky = function() { - this.enable(true); - }; - - nav.init = function ($) { - var doc = $(document), - self = this; - - this.navBar = $('div.pytorch-side-scroll:first'); - this.win = $(window); - - // Set up javascript UX bits - $(document) - // Shift nav in mobile when clicking the menu. - .on('click', "[data-toggle='pytorch-left-menu-nav-top']", function() { - $("[data-toggle='wy-nav-shift']").toggleClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - }) - - // Nav menu link click operations - .on('click', ".pytorch-menu-vertical .current ul li a", function() { - var target = $(this); - // Close menu when you click a link. - $("[data-toggle='wy-nav-shift']").removeClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - // Handle dynamic display of l3 and l4 nav lists - self.toggleCurrent(target); - self.hashChange(); - }) - .on('click', "[data-toggle='rst-current-version']", function() { - $("[data-toggle='rst-versions']").toggleClass("shift-up"); - }) - - // Make tables responsive - $("table.docutils:not(.field-list,.footnote,.citation)") - .wrap("
"); - - // Add extra class to responsive tables that contain - // footnotes or citations so that we can target them for styling - $("table.docutils.footnote") - .wrap("
"); - $("table.docutils.citation") - .wrap("
"); - - // Add expand links to all parents of nested ul - $('.pytorch-menu-vertical ul').not('.simple').siblings('a').each(function () { - var link = $(this); - expand = $(''); - expand.on('click', function (ev) { - self.toggleCurrent(link); - ev.stopPropagation(); - return false; - }); - link.prepend(expand); - }); - }; - - nav.reset = function () { - // Get anchor from URL and open up nested nav - var anchor = encodeURI(window.location.hash) || '#'; - - try { - var vmenu = $('.pytorch-menu-vertical'); - var link = vmenu.find('[href="' + anchor + '"]'); - if (link.length === 0) { - // this link was not found in the sidebar. - // Find associated id element, then its closest section - // in the document and try with that one. - var id_elt = $('.document [id="' + anchor.substring(1) + '"]'); - var closest_section = id_elt.closest('div.section'); - link = vmenu.find('[href="#' + closest_section.attr("id") + '"]'); - if (link.length === 0) { - // still not found in the sidebar. fall back to main section - link = vmenu.find('[href="#"]'); - } - } - // If we found a matching link then reset current and re-apply - // otherwise retain the existing match - if (link.length > 0) { - $('.pytorch-menu-vertical .current').removeClass('current'); - link.addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l1').parent().addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l2').addClass('current'); - link.closest('li.toctree-l3').addClass('current'); - link.closest('li.toctree-l4').addClass('current'); - } - } - catch (err) { - console.log("Error expanding nav for anchor", err); - } - - }; - - nav.onScroll = function () { - this.winScroll = false; - var newWinPosition = this.win.scrollTop(), - winBottom = newWinPosition + this.winHeight, - navPosition = this.navBar.scrollTop(), - newNavPosition = navPosition + (newWinPosition - this.winPosition); - if (newWinPosition < 0 || winBottom > this.docHeight) { - return; - } - this.navBar.scrollTop(newNavPosition); - this.winPosition = newWinPosition; - }; - - nav.onResize = function () { - this.winResize = false; - this.winHeight = this.win.height(); - this.docHeight = $(document).height(); - }; - - nav.hashChange = function () { - this.linkScroll = true; - this.win.one('hashchange', function () { - this.linkScroll = false; - }); - }; - - nav.toggleCurrent = function (elem) { - var parent_li = elem.closest('li'); - parent_li.siblings('li.current').removeClass('current'); - parent_li.siblings().find('li.current').removeClass('current'); - parent_li.find('> ul li.current').removeClass('current'); - parent_li.toggleClass('current'); - } - - return nav; - }; - - module.exports.ThemeNav = ThemeNav(); - - if (typeof(window) != 'undefined') { - window.SphinxRtdTheme = { - Navigation: module.exports.ThemeNav, - // TODO remove this once static assets are split up between the theme - // and Read the Docs. For now, this patches 0.3.0 to be backwards - // compatible with a pre-0.3.0 layout.html - StickyNav: module.exports.ThemeNav, - }; - } - - - // requestAnimationFrame polyfill by Erik Möller. 
fixes from Paul Irish and Tino Zijdel - // https://gist.github.com/paulirish/1579671 - // MIT license - - (function() { - var lastTime = 0; - var vendors = ['ms', 'moz', 'webkit', 'o']; - for(var x = 0; x < vendors.length && !window.requestAnimationFrame; ++x) { - window.requestAnimationFrame = window[vendors[x]+'RequestAnimationFrame']; - window.cancelAnimationFrame = window[vendors[x]+'CancelAnimationFrame'] - || window[vendors[x]+'CancelRequestAnimationFrame']; - } - - if (!window.requestAnimationFrame) - window.requestAnimationFrame = function(callback, element) { - var currTime = new Date().getTime(); - var timeToCall = Math.max(0, 16 - (currTime - lastTime)); - var id = window.setTimeout(function() { callback(currTime + timeToCall); }, - timeToCall); - lastTime = currTime + timeToCall; - return id; - }; - - if (!window.cancelAnimationFrame) - window.cancelAnimationFrame = function(id) { - clearTimeout(id); - }; - }()); - - $(".sphx-glr-thumbcontainer").removeAttr("tooltip"); - $("table").removeAttr("border"); - - // This code replaces the default sphinx gallery download buttons - // with the 3 download buttons at the top of the page - - var downloadNote = $(".sphx-glr-download-link-note.admonition.note"); - if (downloadNote.length >= 1) { - var tutorialUrlArray = $("#tutorial-type").text().split('/'); - - var githubLink = "https://github.com/pytorch/rl/tree/tutorial_py_dup/sphinx-tutorials/" + tutorialUrlArray[tutorialUrlArray.length - 1] + ".py", - notebookLink = $(".reference.download")[1].href, - notebookDownloadPath = notebookLink.split('_downloads')[1], - colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/_downloads" + notebookDownloadPath; - - $("#google-colab-link").wrap("
"); - $("#download-notebook-link").wrap(""); - $("#github-view-link").wrap(""); - } else { - $(".pytorch-call-to-action-links").hide(); - } - - //This code handles the Expand/Hide toggle for the Docs/Tutorials left nav items - - $(document).ready(function() { - var caption = "#pytorch-left-menu p.caption"; - var collapseAdded = $(this).not("checked"); - $(caption).each(function () { - var menuName = this.innerText.replace(/[^\w\s]/gi, "").trim(); - $(this).find("span").addClass("checked"); - if (collapsedSections.includes(menuName) == true && collapseAdded && sessionStorage.getItem(menuName) !== "expand" || sessionStorage.getItem(menuName) == "collapse") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - $(this).next("ul").hide(); - } else if (collapsedSections.includes(menuName) == false && collapseAdded || sessionStorage.getItem(menuName) == "expand") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - } - }); - - $(".expand-menu").on("click", function () { - $(this).prev(".hide-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "collapse") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "expand"); - toggleList(this); - }); - - $(".hide-menu").on("click", function () { - $(this).next(".expand-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "expand") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "collapse"); - toggleList(this); - }); - - function toggleList(menuCommand) { - $(menuCommand).toggle(); - } - }); - - // Build an array from each tag that's present - - var tagList = $(".tutorials-card-container").map(function() { - return $(this).data("tags").split(",").map(function(item) { - return item.trim(); - }); - }).get(); - - function unique(value, index, self) { - return self.indexOf(value) == index && value != "" - } - - // Only return unique tags - - var tags = tagList.sort().filter(unique); - - // Add filter buttons to the top of the page for each tag - - function createTagMenu() { - tags.forEach(function(item){ - $(".tutorial-filter-menu").append("
" + item + "
") - }) - }; - - createTagMenu(); - - // Remove hyphens if they are present in the filter buttons - - $(".tags").each(function(){ - var tags = $(this).text().split(","); - tags.forEach(function(tag, i ) { - tags[i] = tags[i].replace(/-/, ' ') - }) - $(this).html(tags.join(", ")); - }); - - // Remove hyphens if they are present in the card body - - $(".tutorial-filter").each(function(){ - var tag = $(this).text(); - $(this).html(tag.replace(/-/, ' ')) - }) - - // Remove any empty p tags that Sphinx adds - - $("#tutorial-cards p").each(function(index, item) { - if(!$(item).text().trim()) { - $(item).remove(); - } - }); - - // Jump back to top on pagination click - - $(document).on("click", ".page", function() { - $('html, body').animate( - {scrollTop: $("#dropdown-filter-tags").position().top}, - 'slow' - ); - }); - - var link = $("a[href='intermediate/speech_command_recognition_with_torchaudio.html']"); - - if (link.text() == "SyntaxError") { - console.log("There is an issue with the intermediate/speech_command_recognition_with_torchaudio.html menu item."); - link.text("Speech Command Recognition with torchaudio"); - } - - $(".stars-outer > i").hover(function() { - $(this).prevAll().addBack().toggleClass("fas star-fill"); - }); - - $(".stars-outer > i").on("click", function() { - $(this).prevAll().each(function() { - $(this).addBack().addClass("fas star-fill"); - }); - - $(".stars-outer > i").each(function() { - $(this).unbind("mouseenter mouseleave").css({ - "pointer-events": "none" - }); - }); - }) - - $("#pytorch-side-scroll-right li a").on("click", function (e) { - var href = $(this).attr("href"); - $('html, body').stop().animate({ - scrollTop: $(href).offset().top - 100 - }, 850); - e.preventDefault; - }); - - var lastId, - topMenu = $("#pytorch-side-scroll-right"), - topMenuHeight = topMenu.outerHeight() + 1, - // All sidenav items - menuItems = topMenu.find("a"), - // Anchors for menu items - scrollItems = menuItems.map(function () { - var item = $(this).attr("href"); - if (item.length) { - return item; - } - }); - - $(window).scroll(function () { - var fromTop = $(this).scrollTop() + topMenuHeight; - var article = ".section"; - - $(article).each(function (i) { - var offsetScroll = $(this).offset().top - $(window).scrollTop(); - if ( - offsetScroll <= topMenuHeight + 200 && - offsetScroll >= topMenuHeight - 200 && - scrollItems[i] == "#" + $(this).attr("id") && - $(".hidden:visible") - ) { - $(menuItems).removeClass("side-scroll-highlight"); - $(menuItems[i]).addClass("side-scroll-highlight"); - } - }); - }); - - - },{"jquery":"jquery"}],"pytorch-sphinx-theme":[function(require,module,exports){ - var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); - - // Sphinx theme nav state - function ThemeNav () { - - var nav = { - navBar: null, - win: null, - winScroll: false, - winResize: false, - linkScroll: false, - winPosition: 0, - winHeight: null, - docHeight: null, - isRunning: false - }; - - nav.enable = function (withStickyNav) { - var self = this; - - // TODO this can likely be removed once the theme javascript is broken - // out from the RTD assets. This just ensures old projects that are - // calling `enable()` get the sticky menu on by default. All other cals - // to `enable` should include an argument for enabling the sticky menu. 
- if (typeof(withStickyNav) == 'undefined') { - withStickyNav = true; - } - - if (self.isRunning) { - // Only allow enabling nav logic once - return; - } - - self.isRunning = true; - jQuery(function ($) { - self.init($); - - self.reset(); - self.win.on('hashchange', self.reset); - - if (withStickyNav) { - // Set scroll monitor - self.win.on('scroll', function () { - if (!self.linkScroll) { - if (!self.winScroll) { - self.winScroll = true; - requestAnimationFrame(function() { self.onScroll(); }); - } - } - }); - } - - // Set resize monitor - self.win.on('resize', function () { - if (!self.winResize) { - self.winResize = true; - requestAnimationFrame(function() { self.onResize(); }); - } - }); - - self.onResize(); - }); - - }; - - // TODO remove this with a split in theme and Read the Docs JS logic as - // well, it's only here to support 0.3.0 installs of our theme. - nav.enableSticky = function() { - this.enable(true); - }; - - nav.init = function ($) { - var doc = $(document), - self = this; - - this.navBar = $('div.pytorch-side-scroll:first'); - this.win = $(window); - - // Set up javascript UX bits - $(document) - // Shift nav in mobile when clicking the menu. - .on('click', "[data-toggle='pytorch-left-menu-nav-top']", function() { - $("[data-toggle='wy-nav-shift']").toggleClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - }) - - // Nav menu link click operations - .on('click', ".pytorch-menu-vertical .current ul li a", function() { - var target = $(this); - // Close menu when you click a link. - $("[data-toggle='wy-nav-shift']").removeClass("shift"); - $("[data-toggle='rst-versions']").toggleClass("shift"); - // Handle dynamic display of l3 and l4 nav lists - self.toggleCurrent(target); - self.hashChange(); - }) - .on('click', "[data-toggle='rst-current-version']", function() { - $("[data-toggle='rst-versions']").toggleClass("shift-up"); - }) - - // Make tables responsive - $("table.docutils:not(.field-list,.footnote,.citation)") - .wrap("
"); - - // Add extra class to responsive tables that contain - // footnotes or citations so that we can target them for styling - $("table.docutils.footnote") - .wrap("
"); - $("table.docutils.citation") - .wrap("
"); - - // Add expand links to all parents of nested ul - $('.pytorch-menu-vertical ul').not('.simple').siblings('a').each(function () { - var link = $(this); - expand = $(''); - expand.on('click', function (ev) { - self.toggleCurrent(link); - ev.stopPropagation(); - return false; - }); - link.prepend(expand); - }); - }; - - nav.reset = function () { - // Get anchor from URL and open up nested nav - var anchor = encodeURI(window.location.hash) || '#'; - - try { - var vmenu = $('.pytorch-menu-vertical'); - var link = vmenu.find('[href="' + anchor + '"]'); - if (link.length === 0) { - // this link was not found in the sidebar. - // Find associated id element, then its closest section - // in the document and try with that one. - var id_elt = $('.document [id="' + anchor.substring(1) + '"]'); - var closest_section = id_elt.closest('div.section'); - link = vmenu.find('[href="#' + closest_section.attr("id") + '"]'); - if (link.length === 0) { - // still not found in the sidebar. fall back to main section - link = vmenu.find('[href="#"]'); - } - } - // If we found a matching link then reset current and re-apply - // otherwise retain the existing match - if (link.length > 0) { - $('.pytorch-menu-vertical .current').removeClass('current'); - link.addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l1').parent().addClass('current'); - link.closest('li.toctree-l1').addClass('current'); - link.closest('li.toctree-l2').addClass('current'); - link.closest('li.toctree-l3').addClass('current'); - link.closest('li.toctree-l4').addClass('current'); - } - } - catch (err) { - console.log("Error expanding nav for anchor", err); - } - - }; - - nav.onScroll = function () { - this.winScroll = false; - var newWinPosition = this.win.scrollTop(), - winBottom = newWinPosition + this.winHeight, - navPosition = this.navBar.scrollTop(), - newNavPosition = navPosition + (newWinPosition - this.winPosition); - if (newWinPosition < 0 || winBottom > this.docHeight) { - return; - } - this.navBar.scrollTop(newNavPosition); - this.winPosition = newWinPosition; - }; - - nav.onResize = function () { - this.winResize = false; - this.winHeight = this.win.height(); - this.docHeight = $(document).height(); - }; - - nav.hashChange = function () { - this.linkScroll = true; - this.win.one('hashchange', function () { - this.linkScroll = false; - }); - }; - - nav.toggleCurrent = function (elem) { - var parent_li = elem.closest('li'); - parent_li.siblings('li.current').removeClass('current'); - parent_li.siblings().find('li.current').removeClass('current'); - parent_li.find('> ul li.current').removeClass('current'); - parent_li.toggleClass('current'); - } - - return nav; - }; - - module.exports.ThemeNav = ThemeNav(); - - if (typeof(window) != 'undefined') { - window.SphinxRtdTheme = { - Navigation: module.exports.ThemeNav, - // TODO remove this once static assets are split up between the theme - // and Read the Docs. For now, this patches 0.3.0 to be backwards - // compatible with a pre-0.3.0 layout.html - StickyNav: module.exports.ThemeNav, - }; - } - - - // requestAnimationFrame polyfill by Erik Möller. 
fixes from Paul Irish and Tino Zijdel - // https://gist.github.com/paulirish/1579671 - // MIT license - - (function() { - var lastTime = 0; - var vendors = ['ms', 'moz', 'webkit', 'o']; - for(var x = 0; x < vendors.length && !window.requestAnimationFrame; ++x) { - window.requestAnimationFrame = window[vendors[x]+'RequestAnimationFrame']; - window.cancelAnimationFrame = window[vendors[x]+'CancelAnimationFrame'] - || window[vendors[x]+'CancelRequestAnimationFrame']; - } - - if (!window.requestAnimationFrame) - window.requestAnimationFrame = function(callback, element) { - var currTime = new Date().getTime(); - var timeToCall = Math.max(0, 16 - (currTime - lastTime)); - var id = window.setTimeout(function() { callback(currTime + timeToCall); }, - timeToCall); - lastTime = currTime + timeToCall; - return id; - }; - - if (!window.cancelAnimationFrame) - window.cancelAnimationFrame = function(id) { - clearTimeout(id); - }; - }()); - - $(".sphx-glr-thumbcontainer").removeAttr("tooltip"); - $("table").removeAttr("border"); - - // This code replaces the default sphinx gallery download buttons - // with the 3 download buttons at the top of the page - - var downloadNote = $(".sphx-glr-download-link-note.admonition.note"); - if (downloadNote.length >= 1) { - var tutorialUrlArray = $("#tutorial-type").text().split('/'); - - var githubLink = "https://github.com/pytorch/rl/tree/tutorial_py_dup/tutorials/" + tutorialUrlArray.join("/") + ".py", - notebookLink = $(".reference.download")[1].href, - notebookDownloadPath = notebookLink.split('_downloads')[1], - colabLink = "https://colab.research.google.com/github/pytorch/rl/blob/gh-pages/_downloads" + notebookDownloadPath; - - $("#google-colab-link").wrap("
"); - $("#download-notebook-link").wrap(""); - $("#github-view-link").wrap(""); - } else { - $(".pytorch-call-to-action-links").hide(); - } - - //This code handles the Expand/Hide toggle for the Docs/Tutorials left nav items - - $(document).ready(function() { - var caption = "#pytorch-left-menu p.caption"; - var collapseAdded = $(this).not("checked"); - $(caption).each(function () { - var menuName = this.innerText.replace(/[^\w\s]/gi, "").trim(); - $(this).find("span").addClass("checked"); - if (collapsedSections.includes(menuName) == true && collapseAdded && sessionStorage.getItem(menuName) !== "expand" || sessionStorage.getItem(menuName) == "collapse") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - $(this).next("ul").hide(); - } else if (collapsedSections.includes(menuName) == false && collapseAdded || sessionStorage.getItem(menuName) == "expand") { - $(this.firstChild).after("[ + ]"); - $(this.firstChild).after("[ - ]"); - } - }); - - $(".expand-menu").on("click", function () { - $(this).prev(".hide-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "collapse") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "expand"); - toggleList(this); - }); - - $(".hide-menu").on("click", function () { - $(this).next(".expand-menu").toggle(); - $(this).parent().next("ul").toggle(); - var menuName = $(this).parent().text().replace(/[^\w\s]/gi, "").trim(); - if (sessionStorage.getItem(menuName) == "expand") { - sessionStorage.removeItem(menuName); - } - sessionStorage.setItem(menuName, "collapse"); - toggleList(this); - }); - - function toggleList(menuCommand) { - $(menuCommand).toggle(); - } - }); - - // Build an array from each tag that's present - - var tagList = $(".tutorials-card-container").map(function() { - return $(this).data("tags").split(",").map(function(item) { - return item.trim(); - }); - }).get(); - - function unique(value, index, self) { - return self.indexOf(value) == index && value != "" - } - - // Only return unique tags - - var tags = tagList.sort().filter(unique); - - // Add filter buttons to the top of the page for each tag - - function createTagMenu() { - tags.forEach(function(item){ - $(".tutorial-filter-menu").append("
" + item + "
") - }) - }; - - createTagMenu(); - - // Remove hyphens if they are present in the filter buttons - - $(".tags").each(function(){ - var tags = $(this).text().split(","); - tags.forEach(function(tag, i ) { - tags[i] = tags[i].replace(/-/, ' ') - }) - $(this).html(tags.join(", ")); - }); - - // Remove hyphens if they are present in the card body - - $(".tutorial-filter").each(function(){ - var tag = $(this).text(); - $(this).html(tag.replace(/-/, ' ')) - }) - - // Remove any empty p tags that Sphinx adds - - $("#tutorial-cards p").each(function(index, item) { - if(!$(item).text().trim()) { - $(item).remove(); - } - }); - - // Jump back to top on pagination click - - $(document).on("click", ".page", function() { - $('html, body').animate( - {scrollTop: $("#dropdown-filter-tags").position().top}, - 'slow' - ); - }); - - var link = $("a[href='intermediate/speech_command_recognition_with_torchaudio.html']"); - - if (link.text() == "SyntaxError") { - console.log("There is an issue with the intermediate/speech_command_recognition_with_torchaudio.html menu item."); - link.text("Speech Command Recognition with torchaudio"); - } - - $(".stars-outer > i").hover(function() { - $(this).prevAll().addBack().toggleClass("fas star-fill"); - }); - - $(".stars-outer > i").on("click", function() { - $(this).prevAll().each(function() { - $(this).addBack().addClass("fas star-fill"); - }); - - $(".stars-outer > i").each(function() { - $(this).unbind("mouseenter mouseleave").css({ - "pointer-events": "none" - }); - }); - }) - - $("#pytorch-side-scroll-right li a").on("click", function (e) { - var href = $(this).attr("href"); - $('html, body').stop().animate({ - scrollTop: $(href).offset().top - 100 - }, 850); - e.preventDefault; - }); - - var lastId, - topMenu = $("#pytorch-side-scroll-right"), - topMenuHeight = topMenu.outerHeight() + 1, - // All sidenav items - menuItems = topMenu.find("a"), - // Anchors for menu items - scrollItems = menuItems.map(function () { - var item = $(this).attr("href"); - if (item.length) { - return item; - } - }); - - $(window).scroll(function () { - var fromTop = $(this).scrollTop() + topMenuHeight; - var article = ".section"; - - $(article).each(function (i) { - var offsetScroll = $(this).offset().top - $(window).scrollTop(); - if ( - offsetScroll <= topMenuHeight + 200 && - offsetScroll >= topMenuHeight - 200 && - scrollItems[i] == "#" + $(this).attr("id") && - $(".hidden:visible") - ) { - $(menuItems).removeClass("side-scroll-highlight"); - $(menuItems[i]).addClass("side-scroll-highlight"); - } - }); - }); - - },{"jquery":"jquery"}]},{},[1,2,3,4,5,6,7,8,9,10,11]); - -},{"jquery":"jquery"}]},{},[1,2,3,4,5,6,7,8,9,10,11]); - -},{"jquery":"jquery"}]},{},[1,2,3,4,5,6,7,8,9,10,11]); - -},{"jquery":"jquery"}]},{},[1,2,3,4,5,6,7,8,9,10,11]); +},{"jquery":"jquery"}]},{},[1,2,3,4,5,6,7,8,9,10,"pytorch-sphinx-theme"]); diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index e395300ef19..079e5877654 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -218,7 +218,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. Utils ----- -.. currentmodule:: torchrl.data.datasets +.. currentmodule:: torchrl.data .. 
autosummary:: :toctree: generated/ diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 8b661bfa391..430dea36996 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -114,7 +114,6 @@ provides more information on how to design a custom environment from scratch. EnvBase GymLikeEnv EnvMetaData - Specs Vectorized envs --------------- diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 7a52329e02f..fb1eebf6b89 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -32,7 +32,7 @@ TensorDict modules Hooks ----- -.. currentmodule:: torchrl.modules.tensordict_module.actors +.. currentmodule:: torchrl.modules .. autosummary:: :toctree: generated/ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ba91adc2f5e..384117de4c9 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -16,13 +16,15 @@ The main characteristics of TorchRL losses are: method will receive a tensordict as input that contains all the necessary information to return a loss value. - They output a :class:`tensordict.TensorDict` instance with the loss values - written under a ``"loss_`` where ``smth`` is a string describing the + written under a ``"loss_"`` where ``smth`` is a string describing the loss. Additional keys in the tensordict may be useful metrics to log during training time. .. note:: The reason we return independent losses is to let the user use a different optimizer for different sets of parameters for instance. Summing the losses - can be simply done via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")``. + can be simply done via + + >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")) Training value functions ------------------------ @@ -216,5 +218,5 @@ Utils next_state_value SoftUpdate HardUpdate - ValueFunctions + ValueEstimators default_value_kwargs diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index a0c0056f2f7..d14cfae12ee 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -73,7 +73,7 @@ Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process" - **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept a :obj:`TensorDict` object as input and update it given some strategy. Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization - constants update), data subsampling (:doc:`BatchSubSampler`) and such. + constants update), data subsampling (:class:`torchrl.trainers.BatchSubSampler`) and such. - **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger some information retrieved from that data. 
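# --- Editorial sketch (not part of the patch above) ---
# Illustrates the "loss_<smth>" key convention described in the objectives.rst hunk
# earlier in this diff: a loss module returns a TensorDict whose "loss_*" entries are
# the trainable terms, while any other entry (here "entropy") is just a metric.
# The toy TensorDict below is a stand-in for the output of a real TorchRL loss module.
import torch
from tensordict import TensorDict

loss_vals = TensorDict(
    {
        "loss_actor": torch.tensor(2.0, requires_grad=True),
        "loss_critic": torch.tensor(3.0, requires_grad=True),
        "entropy": torch.tensor(0.5),  # logged metric, not a loss term
    },
    batch_size=[],
)
# Summing only the "loss_*" keys, as the note suggests:
loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
loss_val.backward()
# Because the terms are returned separately, one could instead back-propagate each
# "loss_*" entry through its own optimizer rather than summing them.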
Examples include the :obj:`Recorder` hook, the reward diff --git a/test/test_trainer.py b/test/test_trainer.py index 533fd4f0b0d..9520a30e246 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -89,11 +89,10 @@ class MockingLossModule(nn.Module): def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer: trainer = Trainer( - MockingCollector(), - *[ - None, - ] - * 2, + collector=MockingCollector(), + total_frames=None, + frame_skip=None, + optim_steps_per_batch=None, loss_module=MockingLossModule(), optimizer=optimizer, save_trainer_file=file, @@ -862,7 +861,7 @@ def test_recorder(self, N=8): with tempfile.TemporaryDirectory() as folder: logger = TensorboardLogger(exp_name=folder) - recorder = transformed_env_constructor( + environment = transformed_env_constructor( args, video_tag="tmp", norm_obs_only=True, @@ -874,7 +873,7 @@ def test_recorder(self, N=8): record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, - recorder=recorder, + environment=environment, record_interval=args.record_interval, ) trainer = mocking_trainer() @@ -936,7 +935,7 @@ def _make_recorder_and_trainer(tmpdirname): raise NotImplementedError trainer = mocking_trainer(file) - recorder = transformed_env_constructor( + environment = transformed_env_constructor( args, video_tag="tmp", norm_obs_only=True, @@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname): record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, - recorder=recorder, + environment=environment, record_interval=args.record_interval, ) recorder.register(trainer) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 788a2cce27d..fa26ce0c6a9 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from . import datasets from .postprocs import MultiStep from .replay_buffers import ( LazyMemmapStorage, diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 6fcc35a0d46..81a668648d0 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1 +1,2 @@ from .d4rl import D4RLExperienceReplay +from .openml import OpenMLExperienceReplay diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index 78b90793682..76ccb66f601 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -8,8 +8,13 @@ import numpy as np from tensordict.tensordict import TensorDict -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers import Sampler, SamplerWithoutReplacement, Writer +from torchrl.data.replay_buffers import ( + LazyMemmapStorage, + Sampler, + SamplerWithoutReplacement, + TensorDictReplayBuffer, + Writer, +) class OpenMLExperienceReplay(TensorDictReplayBuffer): diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 2ec0bfb4d97..21f51115d6c 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -82,9 +82,9 @@ def _get_reward( class MultiStep(nn.Module): """Multistep reward transform. - Presented in 'Sutton, R. S. 1988. Learning to - predict by the methods of temporal differences. Machine learning 3( - 1):9–44.' + Presented in + + | Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44. 
This module maps the "next" observation to the t + n "next" observation. It is an identity transform whenever :attr:`n_steps` is 0. @@ -153,6 +153,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """ tensordict = tensordict.clone(False) done = tensordict.get(("next", "done")) + truncated = tensordict.get( + ("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device) + ) + done = done | truncated # we'll be using the done states to index the tensordict. # if the shapes don't match we're in trouble. @@ -175,10 +179,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "(trailing singleton dimension excluded)." ) from err - truncated = tensordict.get( - ("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device) - ) - done = done | truncated mask = tensordict.get(("collector", "mask"), None) reward = tensordict.get(("next", "reward")) *batch, T = tensordict.batch_size diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index fb86b0cec06..0c774014f40 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -11,7 +11,7 @@ import torch from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict.utils import expand_right +from tensordict.utils import expand_as_right from torchrl.data.utils import DEVICE_TYPING @@ -708,6 +708,8 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor: return index def update_tensordict_priority(self, data: TensorDictBase) -> None: + if not isinstance(self._sampler, PrioritizedSampler): + return priority = torch.tensor( [self._get_priority(td) for td in data], dtype=torch.float, @@ -753,19 +755,7 @@ def sample( data, info = super().sample(batch_size, return_info=True) if include_info in (True, None): for k, v in info.items(): - data.set(k, torch.tensor(v, device=data.device)) - if "_batch_size" in data.keys(): - # we need to reset the batch-size - shape = data.pop("_batch_size") - shape = shape[0] - shape = torch.Size([data.shape[0], *shape]) - # we may need to update some values in the data - for key, value in data.items(): - if value.ndim >= len(shape): - continue - value = expand_right(value, shape) - data.set(key, value) - data.batch_size = shape + data.set(k, expand_as_right(torch.tensor(v, device=data.device), data)) if return_info: return data, info return data diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 7a789260e48..d96e2498f6b 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -14,6 +14,7 @@ from tensordict.memmap import MemmapTensor from tensordict.prototype import is_tensorclass from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase +from tensordict.utils import expand_right from torchrl._utils import _CKPT_BACKEND, VERBOSE from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -423,10 +424,42 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: return mem_map_tensor._tensor +def _reset_batch_size(x): + """Resets the batch size of a tensordict. + + In some cases we save the original shape of the tensordict as a tensor (or memmap tensor). + + This function will read that tensor, extract its items and reset the shape + of the tensordict to it. If items have an incompatible shape (e.g. "index") + they will be expanded to the right to match it. 
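As an illustration of the behaviour described above (key names and shapes are purely illustrative): a tensordict sampled with batch size ``[4]`` whose original per-item batch size ``[3]`` was saved under ``"_batch_size"`` is restored to batch size ``[4, 3]``, and the flat ``"index"`` entry is expanded to the right to match:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers.storages import _reset_batch_size
>>> td = TensorDict({
...     "obs": torch.zeros(4, 3, 8),
...     "index": torch.arange(4),
...     "_batch_size": torch.tensor([[3]]).expand(4, 1),
... }, batch_size=[4])
>>> td = _reset_batch_size(td)
>>> td.batch_size
torch.Size([4, 3])
>>> td["index"].shape
torch.Size([4, 3])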
+ + """ + shape = x.pop("_batch_size", None) + if shape is not None: + # we need to reset the batch-size + if isinstance(shape, MemmapTensor): + shape = shape.as_tensor() + locked = x.is_locked + if locked: + x.unlock_() + shape = [s.item() for s in shape[0]] + shape = torch.Size([x.shape[0], *shape]) + # we may need to update some values in the data + for key, value in x.items(): + if value.ndim >= len(shape): + continue + value = expand_right(value, shape) + x.set(key, value) + x.batch_size = shape + if locked: + x.lock_() + return x + + def _collate_list_tensordict(x): out = torch.stack(x, 0) if isinstance(out, TensorDictBase): - return out.to_tensordict() + return _reset_batch_size(out.to_tensordict()) return out @@ -436,7 +469,7 @@ def _collate_list_tensors(*x): def _collate_contiguous(x): if isinstance(x, TensorDictBase): - return x.to_tensordict() + return _reset_batch_size(x).to_tensordict() return x.clone() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6a0dd6be2b8..08f9dfe5c46 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2602,6 +2602,13 @@ class VecNorm(Transform): default: 0.99 eps (number, optional): lower bound of the running standard deviation (for numerical underflow). Default is 1e-4. + shapes (List[torch.Size], optional): if provided, represents the shape + of each in_keys. Its length must match the one of ``in_keys``. + Each shape must match the trailing dimension of the corresponding + entry. + If not, the feature dimensions of the entry (ie all dims that do + not belong to the tensordict batch-size) will be considered as + feature dimension. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -2629,6 +2636,7 @@ def __init__( lock: mp.Lock = None, decay: float = 0.9999, eps: float = 1e-4, + shapes: List[torch.Size] = None, ) -> None: if lock is None: lock = mp.Lock() @@ -2656,8 +2664,14 @@ def __init__( self.lock = lock self.decay = decay + self.shapes = shapes self.eps = eps + def _key_str(self, key): + if not isinstance(key, str): + key = "_".join(key) + return key + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.lock is not None: self.lock.acquire() @@ -2681,17 +2695,44 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: forward = _call def _init(self, tensordict: TensorDictBase, key: str) -> None: - if self._td is None or key + "_sum" not in self._td.keys(): - td_view = tensordict.view(-1) - td_select = td_view[0] - d = {key + "_sum": torch.zeros_like(td_select.get(key))} - d.update({key + "_ssq": torch.zeros_like(td_select.get(key))}) + key_str = self._key_str(key) + if self._td is None or key_str + "_sum" not in self._td.keys(): + if key is not key_str and key_str in tensordict.keys(): + raise RuntimeError( + f"Conflicting key names: {key_str} from VecNorm and input tensordict keys." 
+ ) + if self.shapes is None: + td_view = tensordict.view(-1) + td_select = td_view[0] + item = td_select.get(key) + d = {key_str + "_sum": torch.zeros_like(item)} + d.update({key_str + "_ssq": torch.zeros_like(item)}) + else: + idx = 0 + for in_key in self.in_keys: + if in_key != key: + idx += 1 + else: + break + shape = self.shapes[idx] + item = tensordict.get(key) + d = { + key_str + + "_sum": torch.zeros(shape, device=item.device, dtype=item.dtype) + } + d.update( + { + key_str + + "_ssq": torch.zeros( + shape, device=item.device, dtype=item.dtype + ) + } + ) + d.update( { - key - + "_count": torch.zeros( - 1, device=td_select.get(key).device, dtype=torch.float - ) + key_str + + "_count": torch.zeros(1, device=item.device, dtype=torch.float) } ) if self._td is None: @@ -2702,6 +2743,7 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None: pass def _update(self, key, value, N) -> torch.Tensor: + key = self._key_str(key) _sum = self._td.get(key + "_sum") _ssq = self._td.get(key + "_ssq") _count = self._td.get(key + "_count") diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 5a3f4fdbb2b..7c26b7b1b8f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -41,10 +41,12 @@ ActorValueOperator, AdditiveGaussianWrapper, DistributionalQValueActor, + DistributionalQValueHook, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, QValueActor, + QValueHook, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 6686eb6b602..d74634c153a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -9,8 +9,10 @@ ActorCriticWrapper, ActorValueOperator, DistributionalQValueActor, + DistributionalQValueHook, ProbabilisticActor, QValueActor, + QValueHook, ValueOperator, ) from .common import SafeModule diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 635fc90ca21..7b9b8ef53a1 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -715,7 +715,8 @@ def __init__( class ActorValueOperator(SafeSequential): """Actor-value operator. - This class wraps together an actor and a value model that share a common observation embedding network: + This class wraps together an actor and a value model that share a common + observation embedding network: .. aafig:: :aspect: 60 @@ -723,22 +724,30 @@ class ActorValueOperator(SafeSequential): :proportional: :textual: - +-------------+ - |"Observation"| - +-------------+ - | - v - +--------------+ - |"hidden state"| - +--------------+ - | | | - v | v - actor | critic - | | | - v | v - +--------+|+-------+ - |"action"|||"value"| - +--------+|+-------+ + +---------------+ + |Observation (s)| + +---------------+ + | + v + common + | + v + +------------------+ + | Hidden state | + +------------------+ + | | + v v + actor critic + | | + v v + +-------------+ +------------+ + |Action (a(s))| |Value (V(s))| + +-------------+ +------------+ + + .. note:: + For a similar class that returns an action and a Quality value :math:`Q(s, a)` + see :class:`~.ActorCriticOperator`. For a version without common embeddig + refet to :class:`~.ActorCriticWrapper`. 
To facilitate the workflow, this class comes with a get_policy_operator() and get_value_operator() methods, which will both return a stand-alone TDModule with the dedicated functionality. @@ -755,17 +764,13 @@ class ActorValueOperator(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper - >>> spec_hidden = UnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, - ... spec=spec_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) - >>> spec_action = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = TensorDictModule( ... NormalParamWrapper(torch.nn.Linear(4, 8)), ... in_keys=["hidden"], @@ -773,7 +778,6 @@ class ActorValueOperator(SafeSequential): ... ) >>> td_module_action = ProbabilisticActor( ... module=module_action, - ... spec=spec_action, ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=TanhNormal, @@ -854,7 +858,8 @@ def get_value_operator(self) -> SafeSequential: class ActorCriticOperator(ActorValueOperator): """Actor-critic operator. - This class wraps together an actor and a value model that share a common observation embedding network: + This class wraps together an actor and a value model that share a common + observation embedding network: .. aafig:: :aspect: 60 @@ -862,51 +867,58 @@ class ActorCriticOperator(ActorValueOperator): :proportional: :textual: - +-----------+ - |Observation| - +-----------+ - | - v - actor - | - v - +------+ - |action| --> critic - +------+ | - v - +-----+ - |value| - +-----+ + +---------------+ + |Observation (s)| + +---------------+ + | + v + common + | + v + +------------------+ + | Hidden state | + +------------------+ + | | + v v + actor ------> critic + | | + v v + +-------------+ +----------------+ + |Action (a(s))| |Quality (Q(s,a))| + +-------------+ +----------------+ + + .. note:: + For a similar class that returns an action and a state-value :math:`V(s)` + see :class:`~.ActorValueOperator`. + To facilitate the workflow, this class comes with a get_policy_operator() method, which will both return a stand-alone TDModule with the dedicated functionality. The get_critic_operator will return the parent object, as the value is computed based on the policy output. 
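A brief sketch of that workflow, assuming ``td_module`` is an ActorCriticOperator such as the one assembled in the Examples below:

>>> policy_op = td_module.get_policy_operator()   # standalone policy: reads the observation, writes the action
>>> qvalue_op = td_module.get_critic_operator()   # returns the parent object: the action is computed before Q(s, a)

Both operators can be called on the same tensordict; the critic operator runs the full graph, so the quality value it writes is conditioned on the action produced by the policy.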
Args: - common_operator (TensorDictModule): a common operator that reads observations and produces a hidden variable - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + common_operator (TensorDictModule): a common operator that reads + observations and produces a hidden variable + policy_operator (TensorDictModule): a policy operator that reads the + hidden variable and returns an action + value_operator (TensorDictModule): a value operator, that reads the + hidden variable and returns a value Examples: >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor - >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP - >>> spec_hidden = UnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, - ... spec=spec_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) - >>> spec_action = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( ... module=module_action, - ... spec=spec_action, ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=TanhNormal, @@ -964,8 +976,17 @@ class ActorCriticOperator(ActorValueOperator): """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + common_operator: TensorDictModule, + policy_operator: TensorDictModule, + value_operator: TensorDictModule, + ): + super().__init__( + common_operator, + policy_operator, + value_operator, + ) if self[2].out_keys[0] == "state_value": raise RuntimeError( "Value out_key is state_value, which may lead to errors in downstream usages" @@ -998,17 +1019,18 @@ class ActorCriticWrapper(SafeSequential): :proportional: :textual: - +-----------+ - |Observation| - +-----------+ - | | | - v | v - actor | critic - | | | - v | v - +------+|+-------+ - |action||| value | - +------+|+-------+ + +---------------+ + |Observation (s)| + +---------------+ + | | | + v | v + actor | critic + | | | + v | v + +-------------+|+------------+ + |Action (a(s))|||Value (V(s))| + +-------------+|+------------+ + To facilitate the workflow, this class comes with a get_policy_operator() and get_value_operator() methods, which will both return a stand-alone TDModule with the dedicated functionality. @@ -1021,7 +1043,6 @@ class ActorCriticWrapper(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule - >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ( ... ActorCriticWrapper, ... ProbabilisticActor, @@ -1029,7 +1050,6 @@ class ActorCriticWrapper(SafeSequential): ... TanhNormal, ... ValueOperator, ... ) - >>> action_spec = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> action_module = TensorDictModule( ... NormalParamWrapper(torch.nn.Linear(4, 8)), ... in_keys=["observation"], @@ -1037,7 +1057,6 @@ class ActorCriticWrapper(SafeSequential): ... ) >>> td_module_action = ProbabilisticActor( ... module=action_module, - ... spec=action_spec, ... 
in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... return_log_prob=True, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 770d3f3e406..5c4bc835e5c 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -102,6 +102,10 @@ def convert_to_functional( params = make_functional(module, funs_to_decorate=funs_to_decorate) functional_module = deepcopy(module) repopulate_module(module, params) + # params = make_functional( + # module, funs_to_decorate=funs_to_decorate, keep_params=True + # ) + # functional_module = module params_and_buffers = params # we transform the buffers in params to make sure they follow the device @@ -280,7 +284,8 @@ def _target_param_getter(self, network_name): value_to_set = getattr( self, "_sep_".join(["_target_" + network_name, *key]) ) - target_params.set(key, value_to_set) + # _set is faster bc is bypasses the checks + target_params._set(key, value_to_set) return target_params else: params = getattr(self, param_name) @@ -392,7 +397,7 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): this method. Args: - value_type (ValueEstimators): A :class:`torchrl.objectives.utils.ValueFunctions` + value_type (ValueEstimators): A :class:`torchrl.objectives.utils.ValueEstimators` enum type indicating the value function to use. **hyperparams: hyperparameters to use for the value function. If not provided, the value indicated by diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index c1cacd7349e..917f5df44c6 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -99,12 +99,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: a tuple of 2 tensors containing the DDPG loss. """ - if not input_tensordict.device == self.device: - raise RuntimeError( - f"Got device={input_tensordict.device} but " - f"actor_network.device={self.device} (self.device={self.device})" - ) - loss_value, td_error, pred_val, target_value = self._loss_value( input_tensordict, ) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index e584b894ed7..70957785fa7 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -189,10 +189,12 @@ class DistributionalDQNLoss(LossModule): value_network (DistributionalQValueActor or nn.Module): the distributional Q value operator. gamma (scalar): a discount factor for return computation. + .. note:: Unlike :class:`DQNLoss`, this class does not currently support custom value functions. The next value estimation is always bootstrapped. + delay_value (bool): whether to duplicate the value network into a new target value network to create double DQN """ diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 0503ecffb25..2d8498286a0 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -49,9 +49,11 @@ class SACLoss(LossModule): This module typically outputs a ``"state_action_value"`` entry. value_network (TensorDictModule, optional): V(s) parametric model. This module typically outputs a ``"state_value"`` entry. + .. note:: If not provided, the second version of SAC is assumed, where only the Q-Value network is needed. + num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. 
priority_key (str, optional): tensordict key where to write the diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 3daf5e70876..3af554935a9 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -18,7 +18,7 @@ _GAMMA_LMBDA_DEPREC_WARNING = ( "Passing gamma / lambda parameters through the loss constructor " "is deprecated and will be removed soon. To customize your value function, " - "run `loss_module.make_value_estimator(ValueFunctions., gamma=val)`." + "run `loss_module.make_value_estimator(ValueEstimators., gamma=val)`." ) @@ -45,7 +45,7 @@ def default_value_kwargs(value_type: ValueEstimators): Args: value_type (Enum.value): the value function type, from the - :class:`torchrl.objectives.utils.ValueFunctions` class. + :class:`torchrl.objectives.utils.ValueEstimators` class. Examples: >>> kwargs = default_value_kwargs(ValueEstimators.TDLambda) @@ -242,15 +242,18 @@ def __repr__(self) -> str: class SoftUpdate(TargetNetUpdater): - """A soft-update class for target network update in Double DQN/DDPG. + r"""A soft-update class for target network update in Double DQN/DDPG. This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf Args: loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated. eps (scalar): epsilon in the update equation: - param = prev_param * eps + new_param * (1-eps) - default: 0.999 + .. math:: + + \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon) + + Defaults to 0.999 """ def __init__( @@ -264,7 +267,7 @@ def __init__( ], eps: float = 0.999, ): - if not (eps < 1.0 and eps > 0.0): + if not (eps <= 1.0 and eps >= 0.0): raise ValueError( f"Got eps = {eps} when it was supposed to be between 0 and 1." ) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 14799118990..e6e42fef55f 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -132,10 +132,12 @@ class TD0Estimator(ValueEstimatorBase): before the TD is computed. differentiable (bool, optional): if ``True``, gradients are propagated through the computation of the value function. Default is ``False``. + .. note:: The proper way to make the function call non-differentiable is to decorate it in a `torch.no_grad()` context manager/decorator or pass detached parameters for functional modules. + advantage_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. @@ -319,10 +321,12 @@ class TD1Estimator(ValueEstimatorBase): before the TD is computed. differentiable (bool, optional): if ``True``, gradients are propagated through the computation of the value function. Default is ``False``. + .. note:: The proper way to make the function call non-differentiable is to decorate it in a `torch.no_grad()` context manager/decorator or pass detached parameters for functional modules. + advantage_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. @@ -506,10 +510,12 @@ class TDLambdaEstimator(ValueEstimatorBase): before the TD is computed. differentiable (bool, optional): if ``True``, gradients are propagated through the computation of the value function. Default is ``False``. + .. 
note:: The proper way to make the function call non-differentiable is to decorate it in a `torch.no_grad()` context manager/decorator or pass detached parameters for functional modules. + vectorized (bool, optional): whether to use the vectorized version of the lambda return. Default is `True`. advantage_key (str or tuple of str, optional): the key of the advantage entry. @@ -710,10 +716,12 @@ class GAE(ValueEstimatorBase): Default is ``False``. differentiable (bool, optional): if ``True``, gradients are propagated through the computation of the value function. Default is ``False``. + .. note:: The proper way to make the function call non-differentiable is to decorate it in a `torch.no_grad()` context manager/decorator or pass detached parameters for functional modules. + advantage_key (str or tuple of str, optional): the key of the advantage entry. Defaults to "advantage". value_target_key (str or tuple of str, optional): the key of the advantage entry. diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 69120bf1110..90aa41a742e 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -74,6 +74,7 @@ def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None: super().__init__(exp_name=exp_name, log_dir=log_dir) self._has_imported_moviepy = False + print(f"self.log_dir: {self.experiment.log_dir}") def _create_experiment(self) -> "CSVExperiment": """Creates a CSV experiment.""" diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index cbd1a66cb77..69f33b796de 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -22,6 +22,7 @@ from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict, VERBOSE from torchrl.collectors.collectors import DataCollectorBase +from torchrl.collectors.utils import split_trajectories from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase @@ -70,6 +71,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @abc.abstractmethod def register(self, trainer: Trainer, name: str): + """Registers the hook in the trainer at a default location. + + Args: + trainer (Trainer): the trainer where the hook must be registered. + name (str): the name of the hook. + + .. note:: + To register the hook at another location than the default, use + :meth:`torchrl.trainers.Trainer.register_op`. + + """ raise NotImplementedError @@ -95,24 +107,25 @@ class Trainer: optimizer (optim.Optimizer): An optimizer that trains the parameters of the model. logger (Logger, optional): a Logger that will handle the logging. - optim_steps_per_batch (int, optional): number of optimization steps + optim_steps_per_batch (int): number of optimization steps per collection of data. An trainer works as follows: a main loop collects batches of data (epoch loop), and a sub-loop (training loop) performs model updates in between two collections of data. - Default is 500 clip_grad_norm (bool, optional): If True, the gradients will be clipped based on the total norm of the model parameters. If False, all the partial derivatives will be clamped to (-clip_norm, clip_norm). Default is :obj:`True`. clip_norm (Number, optional): value to be used for clipping gradients. - Default is 100.0. + Default is None (no clip norm). progress_bar (bool, optional): If True, a progress bar will be displayed using tqdm. If tqdm is not installed, this option won't have any effect. 
Default is :obj:`True` seed (int, optional): Seed to be used for the collector, pytorch and - numpy. Default is 42. + numpy. Default is ``None``. save_trainer_interval (int, optional): How often the trainer should be - saved to disk. Default is 10000. + saved to disk, in frame count. Default is 10000. + log_interval (int, optional): How often the values should be logged, + in frame count. Default is 10000. save_trainer_file (path, optional): path where to save the trainer. Default is None (no saving) """ @@ -124,25 +137,26 @@ def __new__(cls, *args, **kwargs): cls._collected_frames: int = 0 cls._last_log: Dict[str, Any] = {} cls._last_save: int = 0 - cls._log_interval: int = 10000 cls.collected_frames = 0 cls._app_state = None return super().__new__(cls) def __init__( self, + *, collector: DataCollectorBase, total_frames: int, frame_skip: int, + optim_steps_per_batch: int, loss_module: Union[LossModule, Callable[[TensorDictBase], TensorDictBase]], optimizer: Optional[optim.Optimizer] = None, logger: Optional[Logger] = None, - optim_steps_per_batch: int = 500, clip_grad_norm: bool = True, - clip_norm: float = 100.0, + clip_norm: float = None, progress_bar: bool = True, - seed: int = 42, + seed: int = None, save_trainer_interval: int = 10000, + log_interval: int = 10000, save_trainer_file: Optional[Union[str, pathlib.Path]] = None, ) -> None: @@ -153,9 +167,12 @@ def __init__( self.optimizer = optimizer self.logger = logger + self._log_interval = log_interval + # seeding self.seed = seed - self.set_seed() + if seed is not None: + self.set_seed() # constants self.optim_steps_per_batch = optim_steps_per_batch @@ -421,7 +438,6 @@ def train(self): for batch in self.collector: batch = self._process_batch_hook(batch) - self._pre_steps_log_hook(batch) current_frames = ( batch.get(("collector", "mask"), torch.tensor(batch.numel())) .sum() @@ -429,6 +445,7 @@ def train(self): * self.frame_skip ) self.collected_frames += current_frames + self._pre_steps_log_hook(batch) if self.collected_frames > self.collector.init_random_frames: self.optim_steps(batch) @@ -489,7 +506,6 @@ def _log(self, log_pbar=False, **kwargs) -> None: collected_frames = self.collected_frames for key, item in kwargs.items(): self._log_dict[key].append(item) - if (collected_frames - self._last_log.get(key, 0)) > self._log_interval: self._last_log[key] = collected_frames _log = True @@ -601,8 +617,10 @@ class ReplayBufferTrainer(TrainerHookBase): Args: replay_buffer (TensorDictReplayBuffer): replay buffer to be used. - batch_size (int): batch size when sampling data from the - latest collection or from the replay buffer. + batch_size (int, optional): batch size when sampling data from the + latest collection or from the replay buffer. If none is provided, + the replay buffer batch-size will be used (preferred option for + unchanged batch-sizes). memmap (bool, optional): if ``True``, a memmap tensordict is created. Default is False. device (device, optional): device where the samples must be placed. @@ -630,7 +648,7 @@ class ReplayBufferTrainer(TrainerHookBase): def __init__( self, replay_buffer: TensorDictReplayBuffer, - batch_size: int, + batch_size: Optional[int] = None, memmap: bool = False, device: DEVICE_TYPING = "cpu", flatten_tensordicts: bool = True, @@ -640,6 +658,12 @@ def __init__( self.batch_size = batch_size self.memmap = memmap self.device = device + if flatten_tensordicts: + warnings.warn( + "flatten_tensordicts default value will soon be changed " + "to False for a faster execution. 
Make sure your " + "code is robust to this change." + ) self.flatten_tensordicts = flatten_tensordicts self.max_dims = max_dims @@ -668,7 +692,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: self.replay_buffer.extend(batch) def sample(self, batch: TensorDictBase) -> TensorDictBase: - sample = self.replay_buffer.sample(self.batch_size) + sample = self.replay_buffer.sample(batch_size=self.batch_size) return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: @@ -726,11 +750,12 @@ def _grad_clip(self, clip_grad_norm: bool, clip_norm: float) -> float: for param_group in self.optimizer.param_groups: params += param_group["params"] - if clip_grad_norm: + if clip_grad_norm and clip_norm is not None: gn = nn.utils.clip_grad_norm_(params, clip_norm) else: gn = sum([p.grad.pow(2).sum() for p in params if p.grad is not None]).sqrt() - nn.utils.clip_grad_value_(params, clip_norm) + if clip_norm is not None: + nn.utils.clip_grad_value_(params, clip_norm) return float(gn) @@ -1093,7 +1118,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"): class Recorder(TrainerHookBase): - """Recorder hook for Trainer. + """Recorder hook for :class:`torchrl.trainers.Trainer`. Args: record_interval (int): total number of optimisation steps @@ -1105,7 +1130,7 @@ class Recorder(TrainerHookBase): each iteration, otherwise the frame count can be underestimated. For logging, this parameter is important to normalize the reward. Finally, to compare different runs with different frame_skip, - one must normalize the frame count and rewards. Default is 1. + one must normalize the frame count and rewards. Defaults to ``1``. policy_exploration (ProbabilisticTDModule): a policy instance used for @@ -1117,35 +1142,48 @@ class Recorder(TrainerHookBase): the performance of the policy, it should be possible to turn off the explorative behaviour by calling the `set_exploration_mode('mode')` context manager. - recorder (EnvBase): An environment instance to be used + environment (EnvBase): An environment instance to be used for testing. exploration_mode (str, optional): exploration mode to use for the policy. By default, no exploration is used and the value used is "mode". Set to "random" to enable exploration - out_key (str, optional): reward key to set to the logger. Default is - `"reward_evaluation"`. + log_keys (sequence of str or tuples or str, optional): keys to read in the tensordict + for logging. Defaults to ``[("next", "reward")]``. + out_keys (Dict[str, str], optional): a dictionary mapping the ``log_keys`` + to their name in the logs. Defaults to ``{("next", "reward"): "r_evaluation"}``. suffix (str, optional): suffix of the video to be recorded. log_pbar (bool, optional): if ``True``, the reward value will be logged on the progression bar. Default is `False`. """ + ENV_DEPREC = ( + "the environment should be passed under the 'environment' key" + " and not the 'recorder' key." 
+ ) + def __init__( self, + *, record_interval: int, record_frames: int, - frame_skip: int, + frame_skip: int = 1, policy_exploration: TensorDictModule, - recorder: EnvBase, + environment: EnvBase = None, exploration_mode: str = "random", - log_keys: Optional[List[str]] = None, - out_keys: Optional[Dict[str, str]] = None, + log_keys: Optional[List[Union[str, Tuple[str]]]] = None, + out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, suffix: Optional[str] = None, log_pbar: bool = False, + recorder: EnvBase = None, ) -> None: - + if environment is None and recorder is not None: + warnings.warn(self.ENV_DEPREC) + environment = recorder + elif environment is not None and recorder is not None: + raise ValueError("environment and recorder conflict.") self.policy_exploration = policy_exploration - self.recorder = recorder + self.environment = environment self.record_frames = record_frames self.frame_skip = frame_skip self._count = 0 @@ -1168,43 +1206,45 @@ def __call__(self, batch: TensorDictBase) -> Dict: with set_exploration_mode(self.exploration_mode): if isinstance(self.policy_exploration, torch.nn.Module): self.policy_exploration.eval() - self.recorder.eval() - td_record = self.recorder.rollout( + self.environment.eval() + td_record = self.environment.rollout( policy=self.policy_exploration, max_steps=self.record_frames, auto_reset=True, auto_cast_to_device=True, break_when_any_done=False, ).clone() + td_record = split_trajectories(td_record) if isinstance(self.policy_exploration, torch.nn.Module): self.policy_exploration.train() - self.recorder.train() - self.recorder.transform.dump(suffix=self.suffix) + self.environment.train() + self.environment.transform.dump(suffix=self.suffix) out = {} for key in self.log_keys: value = td_record.get(key).float() if key == ("next", "reward"): - mean_value = value.mean() / self.frame_skip - total_value = value.sum() + mask = td_record["mask"] + mean_value = value[mask].mean() / self.frame_skip + total_value = value.sum(dim=td_record.ndim - 1).mean() out[self.out_keys[key]] = mean_value out["total_" + self.out_keys[key]] = total_value continue out[self.out_keys[key]] = value out["log_pbar"] = self.log_pbar self._count += 1 - self.recorder.close() + self.environment.close() return out def state_dict(self) -> Dict: return { "_count": self._count, - "recorder_state_dict": self.recorder.state_dict(), + "recorder_state_dict": self.environment.state_dict(), } def load_state_dict(self, state_dict: Dict) -> None: self._count = state_dict["_count"] - self.recorder.load_state_dict(state_dict["recorder_state_dict"]) + self.environment.load_state_dict(state_dict["recorder_state_dict"]) def register(self, trainer: Trainer, name: str = "recorder"): trainer.register_module(name, self) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 6b2c87b66c7..53a6ae10e47 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -1,28 +1,33 @@ # -*- coding: utf-8 -*- """ -Coding DDPG using TorchRL -========================= +TorchRL objectives: Coding a DDPG loss +====================================== **Author**: `Vincent Moens `_ """ + ############################################################################## -# This tutorial will guide you through the steps to code DDPG from scratch. 
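A sketch of how the hook is now constructed, mirroring the updated usage in ``test/test_trainer.py`` above (``test_env``, ``policy`` and ``trainer`` are assumed to exist; frame counts are illustrative):

>>> recorder_hook = Recorder(
...     record_interval=10,
...     record_frames=1000,
...     policy_exploration=policy,
...     environment=test_env,   # formerly passed as ``recorder=test_env`` (now deprecated)
... )
>>> recorder_hook.register(trainer)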
+# TorchRL separates the training of RL algorithms in various pieces that will be +# assembled in your training script: the environment, the data collection and +# storage, the model and finally the loss function. +# +# TorchRL losses (or "objectives") are stateful objects that contain the +# trainable parameters (policy and value models). +# This tutorial will guide you through the steps to code a loss from the ground up +# using torchrl. # +# To this aim, we will be focusing on DDPG, which is a relatively straightforward +# algorithm to code. # DDPG (`Deep Deterministic Policy Gradient _`_) # is a simple continuous control algorithm. It consists in learning a # parametric value function for an action-observation pair, and # then learning a policy that outputs actions that maximise this value # function given a certain observation. # -# This tutorial is more than the PPO tutorial: it covers -# multiple topics that were left aside. We strongly advise the reader to go -# through the PPO tutorial first before trying out this one. The goal is to -# show how flexible torchrl is when it comes to writing scripts that can cover -# multiple use cases. -# # Key learnings: # -# - how to build an environment in TorchRL, including transforms +# - how to write a loss module and customize its value estimator; +# - how to build an environment in torchrl, including transforms # (e.g. data normalization) and parallel execution; # - how to design a policy and value network; # - how to collect data from your environment efficiently and store them @@ -30,67 +35,355 @@ # - how to store trajectories (and not transitions) in your replay buffer); # - and finally how to evaluate your model. # -# This tutorial assumes the reader is familiar with some of TorchRL primitives, -# such as :class:`tensordict.TensorDict` and -# :class:`tensordict.nn.TensorDictModules`, although it should be +# This tutorial assumes that you have completed the PPO tutorial which gives +# an overview of the torchrl components and dependencies, such as +# :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`, +# although it should be # sufficiently transparent to be understood without a deep understanding of # these classes. # -# We do not aim at giving a SOTA implementation of the algorithm, but rather -# to provide a high-level illustration of TorchRL features in the context of -# this algorithm. +# .. note:: +# We do not aim at giving a SOTA implementation of the algorithm, but rather +# to provide a high-level illustration of torchrl's loss implementations +# and the library features that are to be used in the context of +# this algorithm. 
# -# Imports -# ------- +# Imports and setup +# ----------------- # # sphinx_gallery_start_ignore import warnings +from typing import Tuple warnings.filterwarnings("ignore") # sphinx_gallery_end_ignore -from copy import deepcopy - -import numpy as np -import torch import torch.cuda import tqdm -from matplotlib import pyplot as plt -from tensordict.nn import TensorDictModule -from torch import nn, optim -from torchrl.collectors import MultiaSyncDataCollector -from torchrl.data import CompositeSpec, TensorDictReplayBuffer -from torchrl.data.postprocs import MultiStep -from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import ( - CatTensors, - DoubleToFloat, - EnvCreator, - ObservationNorm, - ParallelEnv, -) -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv -from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import ( - MLP, - OrnsteinUhlenbeckProcessWrapper, - ProbabilisticActor, - ValueOperator, + + +############################################################################### +# We will execute the policy on cuda if available +device = ( + torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") ) -from torchrl.modules.distributions.continuous import TanhDelta -from torchrl.objectives.utils import hold_out_net -from torchrl.trainers import Recorder ############################################################################### +# torchrl :class:`torchrl.objectives.LossModule` +# ---------------------------------------------- +# +# TorchRL provides a series of losses to use in your training scripts. +# The aim is to have losses that are easily reusable/swappable and that have +# a simple signature. +# +# The main characteristics of TorchRL losses are: +# +# - they are stateful objects: they contain a copy of the trainable parameters +# such that ``loss_module.parameters()`` gives whatever is needed to train the +# algorithm. +# - They follow the ``tensordict`` convention: the :meth:`torch.nn.Module.forward` +# method will receive a tensordict as input that contains all the necessary +# information to return a loss value. +# +# >>> data = replay_buffer.sample() +# >>> loss_dict = loss_module(data) +# +# - They output a :class:`tensordict.TensorDict` instance with the loss values +# written under a ``"loss_"`` where ``smth`` is a string describing the +# loss. Additional keys in the tensordict may be useful metrics to log during +# training time. +# .. note:: +# The reason we return independent losses is to let the user use a different +# optimizer for different sets of parameters for instance. Summing the losses +# can be simply done via +# +# >>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss_")) +# +# The ``__init__`` method +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# The parent class of all losses is :class:`torchrl.objectives.LossModule`. +# As many other components of the library, its :meth:`torchrl.objectives.LossModule.forward` method expects +# as input a :class:`tensordict.TensorDict` instance sampled from an experience +# replay buffer, or any similar data structure. Using this format makes it +# possible to re-use the module across +# modalities, or in complex settings where the model needs to read multiple +# entries for instance. 
In other words, it allows us to code a loss module that +# is oblivious to the data type that is being given to is and that focuses on +# running the elementary steps of the loss function and only those. +# +# To keep the tutorial as didactic as we can, we'll be displaying each method +# of the class independently and we'll be populating the class at a later +# stage. +# +# Let us start with the :meth:`torchrl.objectives.LossModule.__init__` +# method. DDPG aims at solving a control task with a simple strategy: +# training a policy to output actions that maximise the value predicted by +# a value network. Hence, our loss module needs to receive two networks in its +# constructor: an actor and a value networks. We expect both of these to be +# tensordict-compatible objects, such as +# :class:`tensordict.nn.TensorDictModule`. +# Our loss function will need to compute a target value and fit the value +# network to this, and generate an action and fit the policy such that its +# value estimate is maximised. +# +# The crucial step of the :meth:`LossModule.__init__` method is the call to +# :meth:`torchrl.LossModule.convert_to_functional`. This method will extract +# the parameters from the module and convert it to a functional module. +# Strictly speaking, this is not necessary and one may perfectly code all +# the losses without it. However, we encourage its usage for the following +# reason. +# +# The reason TorchRL does this is that RL algorithms often execute the same +# model with different sets of parameters, called "trainable" and "target" +# parameters. +# The "trainable" parameters are those that the optimizer needs to fit. The +# "target" parameters are usually a copy of the formers with some time lag +# (absolute or diluted through a moving average). +# These target parameters are used to compute the value associated with the +# next observation. One the advantages of using a set of target parameters +# for the value model that do not match exactly the current configuration is +# that they provide a pessimistic bound on the value function being computed. +# Pay attention to the ``create_target_params`` keyword argument below: this +# argument tells the :meth:`torchrl.objectives.LossModule.convert_to_functional` +# method to create a set of target parameters in the loss module to be used +# for target value computation. If this is set to ``False`` (see the actor network +# for instance) the ``target_actor_network_params`` attribute will still be +# accessible but this will just return a **detached** version of the +# actor parameters. +# +# Later, we will see how the target parameters should be updated in torchrl. +# + +from tensordict.nn import TensorDictModule + + +def _init( + self, + actor_network: TensorDictModule, + value_network: TensorDictModule, +) -> None: + super(type(self), self).__init__() + + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=True, + ) + self.convert_to_functional( + value_network, + "value_network", + create_target_params=True, + compare_against=list(actor_network.parameters()), + ) + + self.actor_in_keys = actor_network.in_keys + + # Since the value we'll be using is based on the actor and value network, + # we put them together in a single actor-critic container. 
+ actor_critic = ActorCriticWrapper(actor_network, value_network) + self.actor_critic = actor_critic + self.loss_funtion = "l2" + + +############################################################################### +# The value estimator loss method +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In many RL algorithm, the value network (or Q-value network) is trained based +# on an empirical value estimate. This can be bootstrapped (TD(0), low +# variance, high bias), meaning +# that the target value is obtained using the next reward and nothing else, or +# a Monte-Carlo estimate can be obtained (TD(1)) in which case the whole +# sequence of upcoming rewards will be used (high variance, low bias). An +# intermediate estimator (TD(:math:`\lambda`)) can also be used to compromise +# bias and variance. +# TorchRL makes it easy to use one or the other estimator via the +# :class:`torchrl.objectives.utils.ValueEstimators` Enum class, which contains +# pointers to all the value estimators implemented. Let us define the default +# value function here. We will take the simplest version (TD(0)), and show later +# on how this can be changed. + +from torchrl.objectives.utils import ValueEstimators + +default_value_estimator = ValueEstimators.TD0 + +############################################################################### +# We also need to give some instructions to DDPG on how to build the value +# estimator, depending on the user query. Depending on the estimator provided, +# we will build the corresponding module to be used at train time: + +from torchrl.objectives.utils import default_value_kwargs +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + value_key = "state_action_value" + if value_type == ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + value_network=self.actor_critic, value_key=value_key, **hp + ) + elif value_type == ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + value_network=self.actor_critic, value_key=value_key, **hp + ) + elif value_type == ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type == ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + value_network=self.actor_critic, value_key=value_key, **hp + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + +############################################################################### +# The ``make_value_estimator`` method can but does not need to be called: if +# not, the :class:`torchrl.objectives.LossModule` will query this method with +# its default estimator. +# +# The actor loss method +# ~~~~~~~~~~~~~~~~~~~~~ +# +# The central piece of an RL algorithm is the training loss for the actor. +# In the case of DDPG, this function is quite simple: we just need to compute +# the value associated with an action computed using the policy and optimize +# the actor weights to maximise this value. +# +# When computing this value, we must make sure to take the value parameters out +# of the graph, otherwise the actor and value loss will be mixed up. +# For this, the :func:`torchrl.objectives.utils.hold_out_params` function +# can be used. 
+ + +def _loss_actor( + self, + tensordict, +) -> torch.Tensor: + td_copy = tensordict.select(*self.actor_in_keys) + # Get an action from the actor network + td_copy = self.actor_network( + td_copy, + ) + # get the value associated with that action + td_copy = self.value_network( + td_copy, + params=self.value_network_params.detach(), + ) + return -td_copy.get("state_action_value") + + +############################################################################### +# The value loss method +# ~~~~~~~~~~~~~~~~~~~~~ +# +# We now need to optimize our value network parameters. +# To do this, we will rely on the value estimator of our class: +# + +from torchrl.objectives.utils import distance_loss + + +def _loss_value( + self, + tensordict, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + td_copy = tensordict.clone() + + # V(s, a) + self.value_network(td_copy, params=self.value_network_params) + pred_val = td_copy.get("state_action_value").squeeze(-1) + + # we manually reconstruct the parameters of the actor-critic, where the first + # set of parameters belongs to the actor and the second to the value function. + target_params = TensorDict( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + batch_size=self.target_actor_network_params.batch_size, + device=self.target_actor_network_params.device, + ) + target_value = self.value_estimator.value_estimate( + tensordict, target_params=target_params + ).squeeze(-1) + + # Computes the value loss: L2, L1 or smooth L1 depending on self.loss_funtion + loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_funtion) + td_error = (pred_val - target_value).pow(2) + + return loss_value, td_error, pred_val, target_value + + +############################################################################### +# Putting things together in a forward call +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The only missing piece is the forward method, which will glue together the +# value and actor loss, collect the cost values and write them in a tensordict +# delivered to the user. + +from tensordict.tensordict import TensorDict, TensorDictBase + + +def _forward(self, input_tensordict: TensorDictBase) -> TensorDict: + loss_value, td_error, pred_val, target_value = self.loss_value( + input_tensordict, + ) + td_error = td_error.detach() + td_error = td_error.unsqueeze(input_tensordict.ndimension()) + if input_tensordict.device is not None: + td_error = td_error.to(input_tensordict.device) + input_tensordict.set( + "td_error", + td_error, + inplace=True, + ) + loss_actor = self.loss_actor(input_tensordict) + return TensorDict( + source={ + "loss_actor": loss_actor.mean(), + "loss_value": loss_value.mean(), + "pred_value": pred_val.mean().detach(), + "target_value": target_value.mean().detach(), + "pred_value_max": pred_val.max().detach(), + "target_value_max": target_value.max().detach(), + }, + batch_size=[], + ) + + +from torchrl.objectives import LossModule + + +class DDPGLoss(LossModule): + default_value_estimator = default_value_estimator + make_value_estimator = make_value_estimator + + __init__ = _init + forward = _forward + loss_value = _loss_value + loss_actor = _loss_actor + + +############################################################################### +# Now that we have our loss, we can use it to train a policy to solve a +# control task. 
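A sketch of how this loss could be exercised once the networks are available (``actor`` and ``qnet`` are the modules built below, ``replay_buffer`` is assumed to exist, and the hyperparameters are illustrative):

>>> loss_module = DDPGLoss(actor, qnet)
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)
>>> loss_vals = loss_module(replay_buffer.sample())
>>> loss = sum(v for k, v in loss_vals.items() if k.startswith("loss_"))
>>> loss.backward()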
+# # Environment # ----------- # # In most algorithms, the first thing that needs to be taken care of is the -# construction of the environmet as it conditions the remainder of the +# construction of the environment as it conditions the remainder of the # training script. # # For this example, we will be using the ``"cheetah"`` task. The goal is to make @@ -118,15 +411,18 @@ # # env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True) # -# We write a :func:`make_env` helper funciton that will create an environment +# We write a :func:`make_env` helper function that will create an environment # with either one of the two backends considered above (dm-control or gym). # +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import GymEnv + env_library = None env_name = None -def make_env(): +def make_env(from_pixels=False): """Create a base env.""" global env_library global env_name @@ -145,9 +441,9 @@ def make_env(): env_kwargs = { "device": device, - "frame_skip": frame_skip, "from_pixels": from_pixels, "pixels_only": from_pixels, + "frame_skip": 2, } env = env_library(*env_args, **env_kwargs) return env @@ -155,7 +451,7 @@ def make_env(): ############################################################################### # Transforms -# ^^^^^^^^^^ +# ~~~~~~~~~~ # # Now that we have a base environment, we may want to modify its representation # to make it more policy-friendly. In TorchRL, transforms are appended to the @@ -182,6 +478,17 @@ def make_env(): # take care of computing the normalizing constants later on. # +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, + RewardScaling, + StepCounter, + TransformedEnv, +) + def make_transformed_env( env, @@ -227,36 +534,14 @@ def make_transformed_env( ) ) - return env - - -############################################################################### -# Normalization of the observations -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# To compute the normalizing statistics, we run an arbitrary number of random -# steps in the environment and compute the mean and standard deviation of the -# collected observations. The :func:`ObservationNorm.init_stats()` method can -# be used for this purpose. To get the summary statistics, we create a dummy -# environment and run it for a given number of steps, collect data over a given -# number of steps and compute its summary statistics. -# - + env.append_transform(StepCounter(max_frames_per_traj)) -def get_env_stats(): - """Gets the stats of an environment.""" - proof_env = make_transformed_env(make_env()) - proof_env.set_seed(seed) - t = proof_env.transform[2] - t.init_stats(init_env_steps) - transform_state_dict = t.state_dict() - proof_env.close() - return transform_state_dict + return env ############################################################################### # Parallel execution -# ^^^^^^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~~~~~~ # # The following helper function allows us to run environments in parallel. # Running environments in parallel can significantly speed up the collection @@ -282,6 +567,7 @@ def get_env_stats(): def parallel_env_constructor( + env_per_collector, transform_state_dict, ): if env_per_collector == 1: @@ -310,36 +596,108 @@ def make_t_env(): return env +# The backend can be gym or dm_control +backend = "gym" + +############################################################################### +# .. 
note:: +# ``frame_skip`` batches multiple step together with a single action +# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to +# be adjusted to have a consistent total number of frames collected across +# experiments. This is important as raising the frame-skip but keeping the +# total number of frames unchanged may seem like cheating: all things compared, +# a dataset of 10M elements collected with a frame-skip of 2 and another with +# a frame-skip of 1 actually have a ratio of interactions with the environment +# of 2:1! In a nutshell, one should be cautious about the frame-count of a +# training script when dealing with frame skipping as this may lead to +# biased comparisons between training strategies. +# + +############################################################################### +# Scaling the reward helps us control the signal magnitude for a more +# efficient learning. +reward_scaling = 5.0 + +############################################################################### +# We also define when a trajectory will be truncated. A thousand steps (500 if +# frame-skip = 2) is a good number to use for cheetah: + +max_frames_per_traj = 500 + +############################################################################### +# Normalization of the observations +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To compute the normalizing statistics, we run an arbitrary number of random +# steps in the environment and compute the mean and standard deviation of the +# collected observations. The :func:`ObservationNorm.init_stats()` method can +# be used for this purpose. To get the summary statistics, we create a dummy +# environment and run it for a given number of steps, collect data over a given +# number of steps and compute its summary statistics. +# + + +def get_env_stats(): + """Gets the stats of an environment.""" + proof_env = make_transformed_env(make_env()) + t = proof_env.transform[2] + t.init_stats(init_env_steps) + transform_state_dict = t.state_dict() + proof_env.close() + return transform_state_dict + + +############################################################################### +# Normalization stats +# ~~~~~~~~~~~~~~~~~~~ +# Number of random steps used as for stats computation using ObservationNorm + +init_env_steps = 5000 + +transform_state_dict = get_env_stats() + +############################################################################### +# Number of environments in each data collector +env_per_collector = 4 + +############################################################################### +# We pass the stats computed earlier to normalize the output of our +# environment: + +parallel_env = parallel_env_constructor( + env_per_collector=env_per_collector, + transform_state_dict=transform_state_dict, +) + + +from torchrl.data import CompositeSpec + ############################################################################### # Building the model # ------------------ # -# We now turn to the setup of the model and loss function. DDPG requires a +# We now turn to the setup of the model. As we have seen, DDPG requires a # value network, trained to estimate the value of a state-action pair, and a # parametric actor that learns how to select actions that maximize this value. -# In this tutorial, we will be using two independent networks for these -# components. 
# -# Recall that building a torchrl module requires two steps: +# Recall that building a TorchRL module requires two steps: # -# - writing the :class:`torch.nn.Module` that will be used as network +# - writing the :class:`torch.nn.Module` that will be used as network, # - wrapping the network in a :class:`tensordict.nn.TensorDictModule` where the # data flow is handled by specifying the input and output keys. # # In more complex scenarios, :class:`tensordict.nn.TensorDictSequential` can # also be used. # -# In :func:`make_ddpg_actor`, we use a :class:`torchrl.modules.ProbabilisticActor` -# object to wrap our policy network. Since DDPG is a deterministic algorithm, -# this is not strictly necessary. We rely on this class to map the output -# action to the appropriate domain. Alternatively, one could perfectly use a -# non-linearity such as :class:`torch.tanh` to map the output to the right -# domain. # # The Q-Value network is wrapped in a :class:`torchrl.modules.ValueOperator` # that automatically sets the ``out_keys`` to ``"state_action_value`` for q-value # networks and ``state_value`` for other value networks. # +# TorchRL provides a built-in version of the DDPG networks as presented in the +# original paper. These can be found under :class:`torchrl.modules.DdpgMlpActor` +# and :class:`torchrl.modules.DdpgMlpQNet`. +# # Since we use lazy modules, it is necessary to materialize the lazy modules # before being able to move the policy from device to device and achieve other # operations. Hence, it is good practice to run the modules with a small @@ -347,6 +705,16 @@ def make_t_env(): # environment specs. # +from torchrl.modules import ( + ActorCriticWrapper, + DdpgMlpActor, + DdpgMlpQNet, + OrnsteinUhlenbeckProcessWrapper, + ProbabilisticActor, + TanhDelta, + ValueOperator, +) + def make_ddpg_actor( transform_state_dict, @@ -357,37 +725,29 @@ def make_ddpg_actor( proof_environment.transform[2].load_state_dict(transform_state_dict) env_specs = proof_environment.specs - out_features = env_specs["input_spec"]["action"].shape[0] + out_features = env_specs["input_spec"]["action"].shape[-1] - actor_net = MLP( - num_cells=[num_cells] * num_layers, - activation_class=nn.Tanh, - out_features=out_features, + actor_net = DdpgMlpActor( + action_dim=out_features, ) + in_keys = ["observation_vector"] out_keys = ["param"] - actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + actor = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=out_keys, + ) - # We use a ProbabilisticActor to make sure that we map the network output - # to the right space using a TanhDelta distribution. 
actor = ProbabilisticActor( - module=actor_module, + actor, + distribution_class=TanhDelta, in_keys=["param"], spec=CompositeSpec(action=env_specs["input_spec"]["action"]), - safe=True, - distribution_class=TanhDelta, - distribution_kwargs={ - "min": env_specs["input_spec"]["action"].space.minimum, - "max": env_specs["input_spec"]["action"].space.maximum, - }, ).to(device) - q_net = MLP( - num_cells=[num_cells] * num_layers, - activation_class=nn.Tanh, - out_features=1, - ) + q_net = DdpgMlpQNet() in_keys = in_keys + ["action"] qnet = ValueOperator( @@ -395,18 +755,113 @@ def make_ddpg_actor( module=q_net, ).to(device) - # init: since we have lazy layers, we should run the network - # once to initialize them - with torch.no_grad(), set_exploration_mode("random"): - td = proof_environment.fake_tensordict() - td = td.expand((*td.shape, 2)) - td = td.to(device) - actor(td) - qnet(td) - + # init lazy moduless + qnet(actor(proof_environment.reset())) return actor, qnet +actor, qnet = make_ddpg_actor( + transform_state_dict=transform_state_dict, + device=device, +) + +############################################################################### +# Exploration +# ~~~~~~~~~~~ +# +# The policy is wrapped in a :class:`torchrl.modules.OrnsteinUhlenbeckProcessWrapper` +# exploration module, as suggesed in the original paper. +# Let's define the number of frames before OU noise reaches its minimum value +annealing_frames = 1_000_000 + +actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor, + annealing_num_steps=annealing_frames, +).to(device) +if device == torch.device("cpu"): + actor_model_explore.share_memory() + + +############################################################################### +# Data collector +# -------------- +# +# TorchRL provides specialized classes to help you collect data by executing +# the policy in the environment. These "data collectors" iteratively compute +# the action to be executed at a given time, then execute a step in the +# environment and reset it when required. +# Data collectors are designed to help developers have a tight control +# on the number of frames per batch of data, on the (a)sync nature of this +# collection and on the resources allocated to the data collection (e.g. GPU, +# number of workers etc). +# +# Here we will use +# :class:`torchrl.collectors.MultiaSyncDataCollector`, a data collector that +# will be executed in an async manner (i.e. data will be collected while +# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, +# multiple workers are running rollouts separately. When a batch is asked, it +# is gathered from the first worker that can provide it. +# +# The parameters to specify are: +# +# - the list of environment creation functions, +# - the policy, +# - the total number of frames before the collector is considered empty, +# - the maximum number of frames per trajectory (useful for non-terminating +# environments, like dm_control ones). +# .. note:: +# The ``max_frames_per_traj`` passed to the collector will have the effect +# of registering a new :class:`torchrl.envs.StepCounter` transform +# with the environment used for inference. We can achieve the same result +# manually, as we do in this script. +# +# One should also pass: +# +# - the number of frames in each batch collected, +# - the number of random steps executed independently from the policy, +# - the devices used for policy execution +# - the devices used to store data before the data is passed to the main +# process. 
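###############################################################################
# Once instantiated, a collector is consumed as a plain iterable; each
# iteration yields a TensorDict of experience. The pattern used further down
# in this script boils down to the following sketch (``collector`` and
# ``replay_buffer`` are the objects defined below):
#
# .. code-block:: python
#
#    for batch in collector:
#        collector.update_policy_weights_()  # sync inference policy weights
#        replay_buffer.extend(batch.cpu())
#        ...  # optimization steps
#    collector.shutdown()
#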
+# +# The total frames we will use during training should be around 1M. +total_frames = 10_000 # 1_000_000 + +############################################################################### +# The number of frames returned by the collector at each iteration of the outer +# loop is equal to the length of each sub-trajectories times the number of envs +# run in parallel in each collector. +# +# In other words, we expect batches from the collector to have a shape +# ``[env_per_collector, traj_len]`` where +# ``traj_len=frames_per_batch/env_per_collector``: +# +traj_len = 200 +frames_per_batch = env_per_collector * traj_len +init_random_frames = 5000 +num_collectors = 2 + +from torchrl.collectors import MultiaSyncDataCollector + +collector = MultiaSyncDataCollector( + create_env_fn=[ + parallel_env, + ] + * num_collectors, + policy=actor_model_explore, + total_frames=total_frames, + # max_frames_per_traj=max_frames_per_traj, # this is achieved by the env constructor + frames_per_batch=frames_per_batch, + init_random_frames=init_random_frames, + reset_at_each_iter=False, + split_trajs=False, + device=device, + # device for execution + storing_device=device, + # device where data will be stored and passed + update_at_each_batch=False, + exploration_mode="random", +) + ############################################################################### # Evaluator: building your recorder object # ---------------------------------------- @@ -418,25 +873,42 @@ def make_ddpg_actor( # from these simulations. # # The following helper function builds this object: +from torchrl.trainers import Recorder -def make_recorder(actor_model_explore, transform_state_dict): +def make_recorder(actor_model_explore, transform_state_dict, record_interval): base_env = make_env() - recorder = make_transformed_env(base_env) - recorder.transform[2].init_stats(3) - recorder.transform[2].load_state_dict(transform_state_dict) + environment = make_transformed_env(base_env) + environment.transform[2].init_stats( + 3 + ) # must be instantiated to load the state dict + environment.transform[2].load_state_dict(transform_state_dict) recorder_obj = Recorder( record_frames=1000, - frame_skip=frame_skip, policy_exploration=actor_model_explore, - recorder=recorder, - exploration_mode="mean", + environment=environment, + exploration_mode="mode", record_interval=record_interval, ) return recorder_obj +############################################################################### +# We will be recording the performance every 10 batch collected +record_interval = 10 + +recorder = make_recorder( + actor_model_explore, transform_state_dict, record_interval=record_interval +) + +from torchrl.data.replay_buffers import ( + LazyMemmapStorage, + PrioritizedSampler, + RandomSampler, + TensorDictReplayBuffer, +) + ############################################################################### # Replay buffer # ------------- @@ -452,8 +924,10 @@ def make_recorder(actor_model_explore, transform_state_dict): # hyperparameters: # +from torchrl.envs import RandomCropTensorDict + -def make_replay_buffer(buffer_size, prefetch=3): +def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False): if prb: sampler = PrioritizedSampler( max_capacity=buffer_size, @@ -466,320 +940,157 @@ def make_replay_buffer(buffer_size, prefetch=3): storage=LazyMemmapStorage( buffer_size, scratch_dir=buffer_scratch_dir, - device=device, ), + batch_size=batch_size, sampler=sampler, pin_memory=False, prefetch=prefetch, + 
transform=RandomCropTensorDict(random_crop_len, sample_dim=1), ) return replay_buffer ############################################################################### -# Hyperparameters -# --------------- -# -# After having written our helper functions, it is time to set the -# experiment hyperparameters: - -############################################################################### -# Environment -# ^^^^^^^^^^^ - -# The backend can be gym or dm_control -backend = "gym" - -exp_name = "cheetah" - -# frame_skip batches multiple step together with a single action -# If > 1, the other frame counts (e.g. frames_per_batch, total_frames) need to -# be adjusted to have a consistent total number of frames collected across -# experiments. -frame_skip = 2 -from_pixels = False -# Scaling the reward helps us control the signal magnitude for a more -# efficient learning. -reward_scaling = 5.0 - -# Number of random steps used as for stats computation using ObservationNorm -init_env_steps = 1000 - -# Exploration: Number of frames before OU noise becomes null -annealing_frames = 1000000 // frame_skip - -############################################################################### -# Collection -# ^^^^^^^^^^ - -# We will execute the policy on cuda if available -device = ( - torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") -) - -# Number of environments in each data collector -env_per_collector = 2 - -# Total frames we will use during training. Scale up to 500K - 1M for a more -# meaningful training -total_frames = 5000 // frame_skip -# Number of frames returned by the collector at each iteration of the outer loop -frames_per_batch = env_per_collector * 1000 // frame_skip -max_frames_per_traj = 1000 // frame_skip -init_random_frames = 0 -# We'll be using the MultiStep class to have a less myopic representation of -# upcoming states -n_steps_forward = 3 - -# record every 10 batch collected -record_interval = 10 - -############################################################################### -# Optimizer and optimization -# ^^^^^^^^^^^^^^^^^^^^^^^^^^ - -lr = 5e-4 -weight_decay = 0.0 -# UTD: Number of iterations of the inner loop -update_to_data = 32 -batch_size = 128 - -############################################################################### -# Model -# ^^^^^ - -gamma = 0.99 -tau = 0.005 # Decay factor for the target network - -# Network specs -num_cells = 64 -num_layers = 2 - -############################################################################### -# Replay buffer -# ^^^^^^^^^^^^^ +# We'll store the replay buffer in a temporary dirrectory on disk -# If True, a Prioritized replay buffer will be used -prb = True -# Number of frames stored in the buffer -buffer_size = min(total_frames, 1000000 // frame_skip) -buffer_scratch_dir = "/tmp/" +import tempfile -seed = 0 +tmpdir = tempfile.TemporaryDirectory() +buffer_scratch_dir = tmpdir.name ############################################################################### -# Initialization -# -------------- +# Replay buffer storage and batch size +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# TorchRL replay buffer counts the number of elements along the first dimension. +# Since we'll be feeding trajectories to our buffer, we need to adapt the buffer +# size by dividing it by the length of the sub-trajectories yielded by our +# data collector. 
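###############################################################################
# For instance, with the nominal values used below (a capacity of 1M frames
# and trajectories of ``traj_len=200`` steps), this amounts to
# ``1_000_000 // 200 = 5_000`` stored trajectories.
#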
+# Regarding the batch-size, our sampling strategy will consist in sampling +# trajectories of length ``traj_len=200`` before selecting sub-trajecotries +# or length ``random_crop_len=25`` on which the loss will be computed. +# This strategy balances the choice of storing whole trajectories of a certain +# length with the need for providing sampels with a sufficient heterogeneity +# to our loss. The following figure shows the dataflow from a collector +# that gets 8 frames in each batch with 2 environments run in parallel, +# feeds them to a replay buffer that contains 1000 trajectories and +# samples sub-trajectories of 2 time steps each. # -# To initialize the experiment, we first acquire the observation statistics, -# then build the networks, wrap them in an exploration wrapper (following the -# seminal DDPG paper, we used an Ornstein-Uhlenbeck process to add noise to the -# sampled actions). +# .. figure:: /_static/img/replaybuffer_traj.png +# :alt: Storign trajectories in the replay buffer +# +# Let's start with the number of frames stored in the buffer -# Seeding -torch.manual_seed(seed) -np.random.seed(seed) +def ceil_div(x, y): + return -x // (-y) -############################################################################### -# Normalization stats -# ^^^^^^^^^^^^^^^^^^^ -transform_state_dict = get_env_stats() +buffer_size = 1_000_000 +buffer_size = ceil_div(buffer_size, traj_len) ############################################################################### -# Models: policy and q-value network -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -actor, qnet = make_ddpg_actor( - transform_state_dict=transform_state_dict, - device=device, -) -if device == torch.device("cpu"): - actor.share_memory() +# Prioritized replay buffer is disabled by default +prb = False ############################################################################### -# We create a copy of the q-value network to be used as target network - -qnet_target = deepcopy(qnet).requires_grad_(False) +# We also need to define how many updates we'll be doing per batch of data +# collected. This is known as the update-to-data or UTD ratio: +update_to_data = 64 ############################################################################### -# The policy is wrapped in a :class:`torchrl.modules.OrnsteinUhlenbeckProcessWrapper` -# exploration module: - -actor_model_explore = OrnsteinUhlenbeckProcessWrapper( - actor, - annealing_num_steps=annealing_frames, -).to(device) -if device == torch.device("cpu"): - actor_model_explore.share_memory() +# We'll be feeding the loss with trajectories of length 25: +random_crop_len = 25 ############################################################################### -# Parallel environment creation -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# We pass the stats computed earlier to normalize the output of our -# environment: - -create_env_fn = parallel_env_constructor( - transform_state_dict=transform_state_dict, +# In the original paper, the authors perform one update with a batch of 64 +# elements for each frame collected. Here, we reproduce the same ratio +# but while realizing several updates at each batch collection. 
We +# adapt our batch-size to achieve the same number of update-per-frame ratio: + +batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len) + +replay_buffer = make_replay_buffer( + buffer_size=buffer_size, + batch_size=batch_size, + random_crop_len=random_crop_len, + prefetch=3, + prb=prb, ) ############################################################################### -# Data collector -# ^^^^^^^^^^^^^^ -# -# TorchRL provides specialized classes to help you collect data by executing -# the policy in the environment. These "data collectors" iteratively compute -# the action to be executed at a given time, then execute a step in the -# environment and reset it when required. -# Data collectors are designed to help developers have a tight control -# on the number of frames per batch of data, on the (a)sync nature of this -# collection and on the resources allocated to the data collection (e.g. GPU, -# number of workers etc). -# -# Here we will use -# :class:`torchrl.collectors.MultiaSyncDataCollector`, a data collector that -# will be executed in an async manner (i.e. data will be collected while -# the policy is being optimized). With the :class:`MultiaSyncDataCollector`, -# multiple workers are running rollouts separately. When a batch is asked, it -# is gathered from the first worker that can provide it. -# -# The parameters to specify are: -# -# - the list of environment creation functions, -# - the policy, -# - the total number of frames before the collector is considered empty, -# - the maximum number of frames per trajectory (useful for non-terminating -# environments, like dm_control ones). -# -# One should also pass: +# Loss module construction +# ------------------------ # -# - the number of frames in each batch collected, -# - the number of random steps executed independently from the policy, -# - the devices used for policy execution -# - the devices used to store data before the data is passed to the main -# process. +# We build our loss module with the actor and qnet we've just created. +# Because we have target parameters to update, we _must_ create a target network +# updater. # -# Collectors also accept post-processing hooks. -# For instance, the :class:`torchrl.data.postprocs.MultiStep` class passed as -# ``postproc`` makes it so that the rewards of the ``n`` upcoming steps are -# summed (with some discount factor) and the next observation is changed to -# be the n-step forward observation. One could pass other transforms too: -# using :class:`tensordict.nn.TensorDictModule` and -# :class:`tensordict.nn.TensorDictSequential` we can seamlessly append a -# wide range of transforms to our collector. 
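###############################################################################
# Plugging in the numbers used above for this batch-size computation:
# ``frames_per_batch = 4 * 200 = 800``, so
# ``batch_size = 64 * 800 / (64 * 25) = 32`` sub-trajectories of 25 steps.
# Each batch of 800 collected frames therefore triggers
# ``64 * 32 * 25 = 51_200`` training frames, i.e. the same 64
# updates-per-frame ratio as in the original paper.
#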
- -if n_steps_forward > 0: - multistep = MultiStep(n_steps=n_steps_forward, gamma=gamma) -else: - multistep = None -collector = MultiaSyncDataCollector( - create_env_fn=[create_env_fn, create_env_fn], - policy=actor_model_explore, - total_frames=total_frames, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - init_random_frames=init_random_frames, - reset_at_each_iter=False, - postproc=multistep, - split_trajs=True, - devices=[device, device], # device for execution - storing_devices=[device, device], # device where data will be stored and passed - pin_memory=False, - update_at_each_batch=False, - exploration_mode="random", -) +gamma = 0.99 +lmbda = 0.9 +tau = 0.001 # Decay factor for the target network -collector.set_seed(seed) +loss_module = DDPGLoss(actor, qnet) ############################################################################### -# Replay buffer -# ^^^^^^^^^^^^^ -# - -replay_buffer = make_replay_buffer(buffer_size, prefetch=3) +# let's use the TD(lambda) estimator! +loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda) ############################################################################### -# Recorder -# ^^^^^^^^ - -recorder = make_recorder(actor_model_explore, transform_state_dict) +# .. note:: +# Off-policy usually dictates a TD(0) estimator. Here, we use a TD(:math:`\lambda`) +# estimator, which will introduce some bias as the trajectory that follows +# a certain state has been collected with an outdated policy. +# This trick, as the multi-step trick that can be used during data collection, +# are alternative versions of "hacks" that we usually find to work well in +# practice despite the fact that they introduce some bias in the return +# estimates. +# +# Target network updater +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# Target networks are a crucial part of off-policy RL algorithms. +# Updating the target network parameters is made easy thanks to the +# :class:`torchrl.objectives.HardUpdate` and :class:`torchrl.objectives.SoftUpdate` +# classes. They're built with the loss module as argument, and the update is +# achieved via a call to `updater.step()` at the appropriate location in the +# training loop. + +from torchrl.objectives.utils import SoftUpdate + +target_net_updater = SoftUpdate(loss_module, eps=1 - tau) +# This class will raise an error if `init_` is not called first. +target_net_updater.init_() ############################################################################### # Optimizer -# ^^^^^^^^^ +# ~~~~~~~~~ # -# Finally, we will use the Adam optimizer for the policy and value network, -# with the same learning rate for both. 
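###############################################################################
# For reference, the soft ("Polyak") update performed by ``SoftUpdate`` above
# corresponds to the parameter-wise rule (with ``eps = 1 - tau``):
#
#   p_target <- eps * p_target + (1 - eps) * p_source
#            =  (1 - tau) * p_target + tau * p_source
#
# which matches the update that the earlier version of this script wrote out
# by hand with ``p_dest.data.copy_(tau * p_in.data + (1 - tau) * p_dest.data)``.
#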
+# Finally, we will use the Adam optimizer for the policy and value network: -optimizer_actor = optim.Adam(actor.parameters(), lr=lr, weight_decay=weight_decay) -optimizer_qnet = optim.Adam(qnet.parameters(), lr=lr, weight_decay=weight_decay) -total_collection_steps = total_frames // frames_per_batch +from torch import optim -scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer_actor, T_max=total_collection_steps +optimizer_actor = optim.Adam( + loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0 ) -scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer_qnet, T_max=total_collection_steps +optimizer_value = optim.Adam( + loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2 ) +total_collection_steps = total_frames // frames_per_batch ############################################################################### # Time to train the policy # ------------------------ # -# Some notes about the following training loop: -# -# - :func:`torchrl.objectives.utils.hold_out_net` is a TorchRL context manager -# that temporarily sets :func:`torch.Tensor.requires_grad_()` to False for -# a designated set of network parameters. This is used to -# prevent :func:`torch.Tensor.backward()`` from writing gradients on -# parameters that need not to be differentiated given the loss at hand. -# - The value network is designed using the -# :class:`torchrl.modules.ValueOperator` subclass from -# :class:`tensordict.nn.TensorDictModule` class. As explained earlier, -# this class will write a ``"state_action_value"`` entry if one of its -# ``in_keys`` is named ``"action"``, otherwise it will assume that only the -# state-value is returned and the output key will simply be ``"state_value"``. -# In the case of DDPG, the value if of the state-action pair, -# hence the ``"state_action_value"`` will be used. -# - The :func:`torchrl.envs.utils.step_mdp(tensordict)` helper function is the -# equivalent of the ``obs = next_obs`` command found in multiple RL -# algorithms. It will return a new :class:`tensordict.TensorDict` instance -# that contains all the data that will need to be used in the next iteration. -# This makes it possible to pass this new tensordict to the policy or -# value network. -# - When using prioritized replay buffer, a priority key is added to the -# sampled tensordict (named ``"td_error"`` by default). Then, this -# TensorDict will be fed back to the replay buffer using the -# :func:`torchrl.data.replay_buffers.TensorDictReplayBuffer.update_tensordict_priority` -# method. Under the hood, this method will read the index present in the -# TensorDict as well as the priority value, and update its list of priorities -# at these indices. -# - TorchRL provides optimized versions of the loss functions (such as this one) -# where one only needs to pass a sampled tensordict and obtains a dictionary -# of losses and metadata in return (see :mod:`torchrl.objectives` for more -# context). Here we write the full loss function in the optimization loop -# for transparency. -# Similarly, the target network updates are written explicitly but -# TorchRL provides a couple of dedicated classes for this -# (see :class:`torchrl.objectives.SoftUpdate` and -# :class:`torchrl.objectives.HardUpdate`). -# - After each collection of data, we call :func:`collector.update_policy_weights_()`, -# which will update the policy network weights on the data collector. If the -# code is executed on cpu or with a single cuda device, this part can be -# omitted. 
If the collector is executed on another device, then its weights -# must be synced with those on the main, training process and this method -# should be incorporated in the training loop (ideally early in the loop in -# async settings, and at the end of it in sync settings). +# The training loop is pretty straightforward now that we have built all the +# modules we need. +# rewards = [] rewards_eval = [] # Main loop -norm_factor_training = ( - sum(gamma**i for i in range(n_steps_forward)) if n_steps_forward else 1 -) collected_frames = 0 pbar = tqdm.tqdm(total=total_frames) @@ -794,13 +1105,7 @@ def make_replay_buffer(buffer_size, prefetch=3): pbar.update(tensordict.numel()) # extend the replay buffer with the new data - if ("collector", "mask") in tensordict.keys(True): - # if multi-step, a mask is present to help filter padded values - current_frames = tensordict["collector", "mask"].sum() - tensordict = tensordict[tensordict.get(("collector", "mask"))] - else: - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() + current_frames = tensordict.numel() collected_frames += current_frames replay_buffer.extend(tensordict.cpu()) @@ -808,73 +1113,61 @@ def make_replay_buffer(buffer_size, prefetch=3): if collected_frames >= init_random_frames: for _ in range(update_to_data): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size).clone() - - # compute loss for qnet and backprop - with hold_out_net(actor): - # get next state value - next_tensordict = step_mdp(sampled_tensordict) - qnet_target(actor(next_tensordict)) - next_value = next_tensordict["state_action_value"] - assert not next_value.requires_grad - value_est = ( - sampled_tensordict["next", "reward"] - + gamma * (1 - sampled_tensordict["next", "done"].float()) * next_value - ) - value = qnet(sampled_tensordict)["state_action_value"] - value_loss = (value - value_est).pow(2).mean() - # we write the td_error in the sampled_tensordict for priority update - # because the indices of the samples is tracked in sampled_tensordict - # and the replay buffer will know which priorities to update. - sampled_tensordict["td_error"] = (value - value_est).pow(2).detach() - value_loss.backward() - - optimizer_qnet.step() - optimizer_qnet.zero_grad() - - # compute loss for actor and backprop: - # the actor must maximise the state-action value, hence the loss - # is the neg value of this. 
- sampled_tensordict_actor = sampled_tensordict.select(*actor.in_keys) - with hold_out_net(qnet): - qnet(actor(sampled_tensordict_actor)) - actor_loss = -sampled_tensordict_actor["state_action_value"] - actor_loss.mean().backward() + sampled_tensordict = replay_buffer.sample().to(device) + + # Compute loss + loss_dict = loss_module(sampled_tensordict) + # optimize + loss_dict["loss_actor"].backward() + gn1 = torch.nn.utils.clip_grad_norm_( + loss_module.actor_network_params.values(True, True), 10.0 + ) optimizer_actor.step() optimizer_actor.zero_grad() - # update qnet_target params - for (p_in, p_dest) in zip(qnet.parameters(), qnet_target.parameters()): - p_dest.data.copy_(tau * p_in.data + (1 - tau) * p_dest.data) - for (b_in, b_dest) in zip(qnet.buffers(), qnet_target.buffers()): - b_dest.data.copy_(tau * b_in.data + (1 - tau) * b_dest.data) + loss_dict["loss_value"].backward() + gn2 = torch.nn.utils.clip_grad_norm_( + loss_module.value_network_params.values(True, True), 10.0 + ) + optimizer_value.step() + optimizer_value.zero_grad() + + gn = (gn1**2 + gn2**2) ** 0.5 # update priority if prb: replay_buffer.update_tensordict_priority(sampled_tensordict) + # update target network + target_net_updater.step() rewards.append( ( i, - tensordict["next", "reward"].mean().item() - / norm_factor_training - / frame_skip, + tensordict["next", "reward"].mean().item(), ) ) td_record = recorder(None) if td_record is not None: rewards_eval.append((i, td_record["r_evaluation"].item())) - if len(rewards_eval): + if len(rewards_eval) and collected_frames >= init_random_frames: + target_value = loss_dict["target_value"].item() + loss_value = loss_dict["loss_value"].item() + loss_actor = loss_dict["loss_actor"].item() + rn = sampled_tensordict["next", "reward"].mean().item() + rs = sampled_tensordict["next", "reward"].std().item() pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), reward eval: reward: {rewards_eval[-1][1]: 4.4f}" + f"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), " + f"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, " + f"reward normalized={rn :4.2f}/{rs :4.2f}, " + f"grad norm={gn: 4.2f}, " + f"loss_value={loss_value: 4.2f}, " + f"loss_actor={loss_actor: 4.2f}, " + f"target value: {target_value: 4.2f}" ) # update the exploration strategy actor_model_explore.step(current_frames) - if collected_frames >= init_random_frames: - scheduler1.step() - scheduler2.step() collector.shutdown() del collector @@ -886,8 +1179,11 @@ def make_replay_buffer(buffer_size, prefetch=3): # We make a simple plot of the average rewards during training. We can observe # that our policy learned quite well to solve the task. # -# **Note**: As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1M. +# .. note:: +# As already mentioned above, to get a more reasonable performance, +# use a greater value for ``total_frames`` e.g. 1M. + +from matplotlib import pyplot as plt plt.figure() plt.plot(*zip(*rewards), label="training") @@ -898,265 +1194,16 @@ def make_replay_buffer(buffer_size, prefetch=3): plt.tight_layout() ############################################################################### -# Sampling trajectories and using TD(lambda) -# ------------------------------------------ +# Conclusion +# ---------- # -# TD(lambda) is known to be less biased than the regular TD-error we used in -# the previous example. To use it, however, we need to sample trajectories and -# not single transitions. 
+# In this tutorial, we have learnt how to code a loss module in TorchRL given +# the concrete example of DDPG. # -# We modify the previous example to make this possible. -# -# The first modification consists in building a replay buffer that stores -# trajectories (and not transitions). -# -# Specifically, we'll collect trajectories of (at most) -# 250 steps (note that the total trajectory length is actually 1000 frames, but -# we collect batches of 500 transitions obtained over 2 environments running in -# parallel, hence only 250 steps per trajectory are collected at any given -# time). Hence, we'll divide our replay buffer size by 250: - -buffer_size = 100000 // frame_skip // 250 -print("the new buffer size is", buffer_size) -batch_size_traj = max(4, batch_size // 250) -print("the new batch size for trajectories is", batch_size_traj) - -n_steps_forward = 0 # disable multi-step for simplicity - -############################################################################### -# The following code is identical to the initialization we made earlier: - -torch.manual_seed(seed) -np.random.seed(seed) - -# get stats for normalization -transform_state_dict = get_env_stats() - -# Actor and qnet instantiation -actor, qnet = make_ddpg_actor( - transform_state_dict=transform_state_dict, - device=device, -) -if device == torch.device("cpu"): - actor.share_memory() - -# Target network -qnet_target = deepcopy(qnet).requires_grad_(False) - -# Exploration wrappers: -actor_model_explore = OrnsteinUhlenbeckProcessWrapper( - actor, - annealing_num_steps=annealing_frames, -).to(device) -if device == torch.device("cpu"): - actor_model_explore.share_memory() - -# Environment setting: -create_env_fn = parallel_env_constructor( - transform_state_dict=transform_state_dict, -) -# Batch collector: -collector = MultiaSyncDataCollector( - create_env_fn=[create_env_fn, create_env_fn], - policy=actor_model_explore, - total_frames=total_frames, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - init_random_frames=init_random_frames, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - devices=[device, device], # device for execution - storing_devices=[device, device], # device where data will be stored and passed - seed=None, - pin_memory=False, - update_at_each_batch=False, - exploration_mode="random", -) -collector.set_seed(seed) - -# Replay buffer: -replay_buffer = make_replay_buffer(buffer_size, prefetch=0) - -# trajectory recorder -recorder = make_recorder(actor_model_explore, transform_state_dict) - -# Optimizers -optimizer_actor = optim.Adam(actor.parameters(), lr=lr, weight_decay=weight_decay) -optimizer_qnet = optim.Adam(qnet.parameters(), lr=lr, weight_decay=weight_decay) -total_collection_steps = total_frames // frames_per_batch - -scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer_actor, T_max=total_collection_steps -) -scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer_qnet, T_max=total_collection_steps -) - -############################################################################### -# The training loop needs to be slightly adapted. -# First, whereas before extending the replay buffer we used to flatten the -# collected data, this won't be the case anymore. 
To understand why, let's -# check the output shape of the data collector: - -for data in collector: - print(data.shape) - break - -############################################################################### -# We see that our data has shape ``[2, 250]`` as expected: 2 envs, each -# returning 250 frames. +# The key takeaways are: # -# Let's import the td_lambda function: +# - How to use the :class:`torchrl.objectives.LossModule` class to code up a new +# loss component; +# - How to use (or not) a target network, and how to update its parameters; +# - How to create an optimizer associated with a loss module. # - -from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate - -lmbda = 0.95 - -############################################################################### -# The training loop is roughly the same as before, with the exception that we -# don't flatten the collected data. Also, the sampling from the replay buffer -# is slightly different: We will collect at minimum four trajectories, compute -# the returns (TD(lambda)), then sample from these the values we'll be using -# to compute gradients. This ensures that do not have batches that are -# 'too big' but still compute an accurate return. -# - -rewards = [] -rewards_eval = [] - -# Main loop -norm_factor_training = ( - sum(gamma**i for i in range(n_steps_forward)) if n_steps_forward else 1 -) - -collected_frames = 0 -# # if tqdm is to be used -# pbar = tqdm.tqdm(total=total_frames) -r0 = None -for i, tensordict in enumerate(collector): - - # update weights of the inference policy - collector.update_policy_weights_() - - if r0 is None: - r0 = tensordict["next", "reward"].mean().item() - - # extend the replay buffer with the new data - current_frames = tensordict.numel() - collected_frames += current_frames - replay_buffer.extend(tensordict.cpu()) - - # optimization steps - if collected_frames >= init_random_frames: - for _ in range(update_to_data): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size_traj) - # reset the batch size temporarily, and exclude index - # whose shape is incompatible with the new size - index = sampled_tensordict.get("index") - sampled_tensordict.exclude("index", inplace=True) - - # compute loss for qnet and backprop - with hold_out_net(actor): - # get next state value - next_tensordict = step_mdp(sampled_tensordict) - qnet_target(actor(next_tensordict.view(-1))).view( - sampled_tensordict.shape - ) - next_value = next_tensordict["state_action_value"] - assert not next_value.requires_grad - - # This is the crucial part: we'll compute the TD(lambda) - # instead of a simple single step estimate - done = sampled_tensordict["next", "done"] - reward = sampled_tensordict["next", "reward"] - value = qnet(sampled_tensordict.view(-1)).view(sampled_tensordict.shape)[ - "state_action_value" - ] - advantage = vec_td_lambda_advantage_estimate( - gamma, - lmbda, - value, - next_value, - reward, - done, - time_dim=sampled_tensordict.ndim - 1, - ) - # we sample from the values we have computed - rand_idx = torch.randint(0, advantage.numel(), (batch_size,)) - value_loss = advantage.view(-1)[rand_idx].pow(2).mean() - - # we write the td_error in the sampled_tensordict for priority update - # because the indices of the samples is tracked in sampled_tensordict - # and the replay buffer will know which priorities to update. 
- value_loss.backward() - - optimizer_qnet.step() - optimizer_qnet.zero_grad() - - # compute loss for actor and backprop: the actor must maximise the state-action value, hence the loss is the neg value of this. - sampled_tensordict_actor = sampled_tensordict.select(*actor.in_keys) - with hold_out_net(qnet): - qnet(actor(sampled_tensordict_actor.view(-1))).view( - sampled_tensordict.shape - ) - actor_loss = -sampled_tensordict_actor["state_action_value"] - actor_loss.view(-1)[rand_idx].mean().backward() - - optimizer_actor.step() - optimizer_actor.zero_grad() - - # update qnet_target params - for (p_in, p_dest) in zip(qnet.parameters(), qnet_target.parameters()): - p_dest.data.copy_(tau * p_in.data + (1 - tau) * p_dest.data) - for (b_in, b_dest) in zip(qnet.buffers(), qnet_target.buffers()): - b_dest.data.copy_(tau * b_in.data + (1 - tau) * b_dest.data) - - # update priority - sampled_tensordict.batch_size = [batch_size_traj] - sampled_tensordict["td_error"] = advantage.detach().pow(2).mean(1) - sampled_tensordict["index"] = index - if prb: - replay_buffer.update_tensordict_priority(sampled_tensordict) - - rewards.append( - ( - i, - tensordict["next", "reward"].mean().item() - / norm_factor_training - / frame_skip, - ) - ) - td_record = recorder(None) - if td_record is not None: - rewards_eval.append((i, td_record["r_evaluation"].item())) - # if len(rewards_eval): - # pbar.set_description(f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), reward eval: reward: {rewards_eval[-1][1]: 4.4f}") - - # update the exploration strategy - actor_model_explore.step(current_frames) - if collected_frames >= init_random_frames: - scheduler1.step() - scheduler2.step() - -collector.shutdown() -del create_env_fn -del collector - -############################################################################### -# We can observe that using TD(lambda) made our results considerably more -# stable for a similar training speed: -# -# **Note**: As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 1000000. - -plt.figure() -plt.plot(*zip(*rewards), label="training") -plt.plot(*zip(*rewards_eval), label="eval") -plt.legend() -plt.xlabel("iter") -plt.ylabel("reward") -plt.tight_layout() -plt.title("TD-labmda DDPG results") diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 1b566ee09d7..4603cecf37f 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -1,20 +1,63 @@ # -*- coding: utf-8 -*- """ -Coding a pixel-based DQN using TorchRL -====================================== +TorchRL trainer: A DQN example +============================== **Author**: `Vincent Moens `_ """ ############################################################################## -# This tutorial will guide you through the steps to code DQN to solve the -# CartPole task from scratch. DQN -# (`Deep Q-Learning `_) was +# TorchRL provides a generic :class:`torchrl.trainers.Trainer` class to handle +# your training loop. The trainer executes a nested loop where the outer loop +# is the data collection and the inner loop consumes this data or some data +# retrieved from the replay buffer to train the model. +# At various points in this training loop, hooks can be attached and executed at +# given intervals. +# +# In this tutorial, we will be using the trainer class to train a DQN algorithm +# to solve the CartPole task from scratch. 
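###############################################################################
# Schematically, the logic orchestrated by the trainer can be summarized by
# the following simplified sketch (pseudocode, not the actual implementation;
# ``optim_steps_per_batch`` and the hook placements are indicative only):
#
# .. code-block:: python
#
#    for batch in collector:                     # outer loop: data collection
#        # hooks run here, e.g. extending the replay buffer with ``batch``
#        for _ in range(optim_steps_per_batch):  # inner loop: optimization
#            sample = replay_buffer.sample()
#            losses = loss_module(sample)
#            sum(losses.values()).backward()
#            optimizer.step()
#            optimizer.zero_grad()
#            # hooks run here, e.g. updating the target network parameters
#        # hooks run here, e.g. logging and periodic evaluation/recording
#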
+# +# Main takeaways: +# +# - Building a trainer with its essential components: data collector, loss +# module, replay buffer and optimizer. +# - Adding hooks to a trainer, such as loggers, target network updaters and such. +# +# The trainer is fully customisable and offers a large set of functionalities. +# The tutorial is organised around its construction. +# We will be detailing how to build each of the components of the library first, +# and then put the pieces together using the :class:`torchrl.trainers.Trainer` +# class. +# +# Along the road, we will also focus on some other aspects of the library: +# +# - how to build an environment in TorchRL, including transforms (e.g. data +# normalization, frame concatenation, resizing and turning to grayscale) +# and parallel execution. Unlike what we did in the +# `DDPG tutorial `_, we +# will normalize the pixels and not the state vector. +# - how to design a :class:`torchrl.modules.QValueActor` object, i.e. an actor +# that estimates the action values and picks up the action with the highest +# estimated return; +# - how to collect data from your environment efficiently and store them +# in a replay buffer; +# - how to use multi-step, a simple preprocessing step for off-policy algorithms; +# - and finally how to evaluate your model. +# +# **Prerequisites**: We encourage you to get familiar with torchrl through the +# `PPO tutorial `_ first. +# +# DQN +# --- +# +# DQN (`Deep Q-Learning `_) was # the founding work in deep reinforcement learning. -# On a high level, the algorithm is quite simple: Q-learning consists in learning a table of -# state-action values in such a way that, when encountering any particular state, -# we know which action to pick just by searching for the action with the -# highest value. This simple setting requires the actions and states to be +# +# On a high level, the algorithm is quite simple: Q-learning consists in +# learning a table of state-action values in such a way that, when +# encountering any particular state, we know which action to pick just by +# searching for the one with the highest value. This simple setting +# requires the actions and states to be # discrete, otherwise a lookup table cannot be built. # # DQN uses a neural network that encodes a map from the state-action space to @@ -35,57 +78,28 @@ # .. figure:: /_static/img/cartpole_demo.gif # :alt: Cart Pole # -# **Prerequisites**: We encourage you to get familiar with torchrl through the -# `PPO tutorial `_ first. -# This tutorial is more complex and full-fleshed, but it may be . -# -# In this tutorial, you will learn: -# -# - how to build an environment in TorchRL, including transforms (e.g. data -# normalization, frame concatenation, resizing and turning to grayscale) -# and parallel execution. Unlike what we did in the -# `DDPG tutorial `_, we -# will normalize the pixels and not the state vector. -# - how to design a QValue actor, i.e. an actor that estimates the action -# values and picks up the action with the highest estimated return; -# - how to collect data from your environment efficiently and store them -# in a replay buffer; -# - how to store trajectories (and not transitions) in your replay buffer), -# and how to estimate returns using TD(lambda); -# - how to make a module functional and use ; -# - and finally how to evaluate your model. 
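###############################################################################
# As a toy illustration of the tabular setting described above (a standalone
# sketch, unrelated to the CartPole pipeline built in this tutorial), greedy
# action selection from a Q-table is just an ``argmax`` over the action
# dimension:
#
# .. code-block:: python
#
#    import torch
#
#    q_table = torch.rand(16, 4)  # 16 discrete states, 4 discrete actions
#    state = 3
#    greedy_action = q_table[state].argmax()  # action with the highest value
#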
-# -# This tutorial assumes the reader is familiar with some of TorchRL -# primitives, such as :class:`tensordict.TensorDict` and -# :class:`tensordict.TensorDictModules`, although it -# should be sufficiently transparent to be understood without a deep -# understanding of these classes. -# # We do not aim at giving a SOTA implementation of the algorithm, but rather # to provide a high-level illustration of TorchRL features in the context # of this algorithm. # sphinx_gallery_start_ignore +import tempfile import warnings -from collections import defaultdict warnings.filterwarnings("ignore") # sphinx_gallery_end_ignore +import os +import uuid + import torch -import tqdm -from functorch import vmap -from matplotlib import pyplot as plt -from tensordict import TensorDict -from tensordict.nn import get_functional from torch import nn from torchrl.collectors import MultiaSyncDataCollector -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer from torchrl.envs import EnvCreator, ParallelEnv, RewardScaling, StepCounter from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms import ( CatFrames, - CatTensors, Compose, GrayScale, ObservationNorm, @@ -93,9 +107,18 @@ ToTensorImage, TransformedEnv, ) -from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import DuelingCnnDQNet, EGreedyWrapper, QValueActor +from torchrl.objectives import DQNLoss, SoftUpdate +from torchrl.record.loggers.csv import CSVLogger +from torchrl.trainers import ( + LogReward, + Recorder, + ReplayBufferTrainer, + Trainer, + UpdateWeights, +) + def is_notebook() -> bool: try: @@ -111,150 +134,84 @@ def is_notebook() -> bool: ############################################################################### -# Hyperparameters -# --------------- +# Let's get started with the various pieces we need for our algorithm: # -# Let's start with our hyperparameters. The following setting should work well -# in practice, and the performance of the algorithm should hopefully not be -# too sensitive to slight variations of these. - -device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu" - -############################################################################### -# Optimizer -# ^^^^^^^^^ - -# the learning rate of the optimizer -lr = 2e-3 -# the beta parameters of Adam -betas = (0.9, 0.999) -# Optimization steps per batch collected (aka UPD or updates per data) -n_optim = 8 - -############################################################################### -# DQN parameters -# ^^^^^^^^^^^^^^ - -############################################################################### -# gamma decay factor -gamma = 0.99 - -############################################################################### -# lambda decay factor (see second the part with TD(:math:`\lambda`) -lmbda = 0.95 - -############################################################################### -# Smooth target network update decay parameter. -# This loosely corresponds to a 1/(1-tau) interval with hard target network -# update -tau = 0.005 - -############################################################################### -# Data collection and replay buffer -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Values to be used for proper training have been commented. -# -# Total frames collected in the environment. In other implementations, the -# user defines a maximum number of episodes. 
-# This is harder to do with our data collectors since they return batches -# of N collected frames, where N is a constant. -# However, one can easily get the same restriction on number of episodes by -# breaking the training loop when a certain number -# episodes has been collected. -total_frames = 5000 # 500000 - -############################################################################### -# Random frames used to initialize the replay buffer. -init_random_frames = 100 # 1000 - -############################################################################### -# Frames in each batch collected. -frames_per_batch = 32 # 128 - -############################################################################### -# Frames sampled from the replay buffer at each optimization step -batch_size = 32 # 256 - -############################################################################### -# Size of the replay buffer in terms of frames -buffer_size = min(total_frames, 100000) - -############################################################################### -# Number of environments run in parallel in each data collector -num_workers = 2 # 8 -num_collectors = 2 # 4 - - -############################################################################### -# Environment and exploration -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# - An environment; +# - A policy (and related modules that we group under the "model" umbrella); +# - A data collector, which makes the policy play in the environment and +# delivers training data; +# - A replay buffer to store the training data; +# - A loss module, which computes the objective function to train our policy +# to maximise the return; +# - An optimizer, which performs parameter updates based on our loss. # -# We set the initial and final value of the epsilon factor in Epsilon-greedy -# exploration. -# Since our policy is deterministic, exploration is crucial: without it, the -# only source of randomness would be the environment reset. - -eps_greedy_val = 0.1 -eps_greedy_val_env = 0.005 - -############################################################################### -# To speed up learning, we set the bias of the last layer of our value network -# to a predefined value (this is not mandatory) -init_bias = 2.0 - -############################################################################### -# **Note**: for fast rendering of the tutorial ``total_frames`` hyperparameter -# was set to a very low number. To get a reasonable performance, use a greater -# value e.g. 500000 +# Additional modules include a logger, a recorder (executes the policy in +# "eval" mode) and a target network updater. With all these components into +# place, it is easy to see how one could misplace or misuse one component in +# the training script. The trainer is there to orchestrate everything for you! # # Building the environment # ------------------------ # -# Our environment builder has two arguments: -# -# - ``parallel``: determines whether multiple environments have to be run in -# parallel. We stack the transforms after the -# :class:`torchrl.envs.ParallelEnv` to take advantage -# of vectorization of the operations on device, although this would -# technically work with every single environment attached to its own set of -# transforms. -# - ``observation_norm_state_dict`` will contain the normalizing constants for -# the :class:`torchrl.envs.ObservationNorm` tranform. +# First let's write a helper function that will output an environment. 
As usual, +# the "raw" environment may be too simple to be used in practice and we'll need +# some data transformation to expose its output to the policy. # # We will be using five transforms: # -# - :class:`torchrl.envs.ToTensorImage` will convert a ``[W, H, C]`` uint8 +# - :class:`torchrl.envs.StepCounter` to count the number of steps in each trajectory; +# - :class:`torchrl.envs.transforms.ToTensorImage` will convert a ``[W, H, C]`` uint8 # tensor in a floating point tensor in the ``[0, 1]`` space with shape # ``[C, W, H]``; -# - :class:`torchrl.envs.RewardScaling` to reduce the scale of the return; -# - :class:`torchrl.envs.GrayScale` will turn our image into grayscale; -# - :class:`torchrl.envs.Resize` will resize the image in a 64x64 format; -# - :class:`torchrl.envs.CatFrames` will concatenate an arbitrary number of +# - :class:`torchrl.envs.transforms.RewardScaling` to reduce the scale of the return; +# - :class:`torchrl.envs.transforms.GrayScale` will turn our image into grayscale; +# - :class:`torchrl.envs.transforms.Resize` will resize the image in a 64x64 format; +# - :class:`torchrl.envs.transforms.CatFrames` will concatenate an arbitrary number of # successive frames (``N=4``) in a single tensor along the channel dimension. # This is useful as a single image does not carry information about the # motion of the cartpole. Some memory about past observations and actions # is needed, either via a recurrent neural network or using a stack of # frames. -# - :class:`torchrl.envs.ObservationNorm` which will normalize our observations +# - :class:`torchrl.envs.transforms.ObservationNorm` which will normalize our observations # given some custom summary statistics. # +# In practice, our environment builder has two arguments: +# +# - ``parallel``: determines whether multiple environments have to be run in +# parallel. We stack the transforms after the +# :class:`torchrl.envs.ParallelEnv` to take advantage +# of vectorization of the operations on device, although this would +# technically work with every single environment attached to its own set of +# transforms. +# - ``obs_norm_sd`` will contain the normalizing constants for +# the :class:`torchrl.envs.ObservationNorm` transform. 
+# -def make_env(parallel=False, observation_norm_state_dict=None): - if observation_norm_state_dict is None: - observation_norm_state_dict = {"standard_normal": True} +def make_env( + parallel=False, + obs_norm_sd=None, +): + if obs_norm_sd is None: + obs_norm_sd = {"standard_normal": True} if parallel: base_env = ParallelEnv( num_workers, EnvCreator( lambda: GymEnv( - "CartPole-v1", from_pixels=True, pixels_only=True, device=device + "CartPole-v1", + from_pixels=True, + pixels_only=True, + device=device, ) ), ) else: base_env = GymEnv( - "CartPole-v1", from_pixels=True, pixels_only=True, device=device + "CartPole-v1", + from_pixels=True, + pixels_only=True, + device=device, ) env = TransformedEnv( @@ -266,7 +223,7 @@ def make_env(parallel=False, observation_norm_state_dict=None): GrayScale(), Resize(64, 64), CatFrames(4, in_keys=["pixels"], dim=-3), - ObservationNorm(in_keys=["pixels"], **observation_norm_state_dict), + ObservationNorm(in_keys=["pixels"], **obs_norm_sd), ), ) return env @@ -274,25 +231,29 @@ def make_env(parallel=False, observation_norm_state_dict=None): ############################################################################### # Compute normalizing constants -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # To normalize images, we don't want to normalize each pixel independently # with a full ``[C, W, H]`` normalizing mask, but with simpler ``[C, 1, 1]`` -# shaped loc and scale parameters. We will be using the ``reduce_dim`` argument -# of :func:`torchrl.envs.ObservationNorm.init_stats` to instruct which +# shaped set of normalizing constants (loc and scale parameters). +# We will be using the ``reduce_dim`` argument +# of :meth:`torchrl.envs.ObservationNorm.init_stats` to instruct which # dimensions must be reduced, and the ``keep_dims`` parameter to ensure that # not all dimensions disappear in the process: +# -test_env = make_env() -test_env.transform[-1].init_stats( - num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2) -) -observation_norm_state_dict = test_env.transform[-1].state_dict() -############################################################################### -# let's check that normalizing constants have a size of ``[C, 1, 1]`` where -# ``C=4`` (because of :class:`torchrl.envs.CatFrames`). -print(observation_norm_state_dict) +def get_norm_stats(): + test_env = make_env() + test_env.transform[-1].init_stats( + num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2) + ) + obs_norm_sd = test_env.transform[-1].state_dict() + # let's check that normalizing constants have a size of ``[C, 1, 1]`` where + # ``C=4`` (because of :class:`torchrl.envs.CatFrames`). + print("state dict of the observation norm:", obs_norm_sd) + return obs_norm_sd + ############################################################################### # Building the model (Deep Q-network) @@ -305,37 +266,18 @@ def make_env(parallel=False, observation_norm_state_dict=None): # # .. math:: # -# val = b(obs) + v(obs) - \mathbb{E}[v(obs)] +# \mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)] # -# where :math:`b` is a :math:`\# obs \rightarrow 1` function and :math:`v` is a -# :math:`\# obs \rightarrow num_actions` function. +# where :math:`\mathbb{v}` is our vector of action values, +# :math:`b` is a :math:`\mathbb{R}^n \rightarrow 1` function and :math:`v` is a +# :math:`\mathbb{R}^n \rightarrow \mathbb{R}^m` function, for +# :math:`n = \# obs` and :math:`m = \# actions`. 
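###############################################################################
# In code, the aggregation above boils down to the following (a standalone
# sketch with arbitrary shapes, not the actual
# :class:`torchrl.modules.DuelingCnnDQNet` implementation):

dueling_value = torch.randn(32, 1)  # b(obs): one scalar per observation
dueling_advantage = torch.randn(32, 6)  # v(obs): one value per action
dueling_q = dueling_value + dueling_advantage - dueling_advantage.mean(
    dim=-1, keepdim=True
)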
# -# Our network is wrapped in a :class:`torchrl.modules.QValueActor`, which will read the state-action +# Our network is wrapped in a :class:`torchrl.modules.QValueActor`, +# which will read the state-action # values, pick up the one with the maximum value and write all those results # in the input :class:`tensordict.TensorDict`. # -# Target parameters -# ^^^^^^^^^^^^^^^^^ -# -# Many off-policy RL algorithms use the concept of "target parameters" when it -# comes to estimate the value of the ``t+1`` state or state-action pair. -# The target parameters are lagged copies of the model parameters. Because -# their predictions mismatch those of the current model configuration, they -# help learning by putting a pessimistic bound on the value being estimated. -# This is a powerful trick (known as "Double Q-Learning") that is ubiquitous -# in similar algorithms. -# -# Functionalizing modules -# ^^^^^^^^^^^^^^^^^^^^^^^ -# -# One of the features of torchrl is its usage of functional modules: as the -# same architecture is often used with multiple sets of parameters (e.g. -# trainable and target parameters), we functionalize the modules and isolate -# the various sets of parameters in separate tensordicts. -# -# To this aim, we use :func:`tensordict.nn.get_functional`, which augments -# our modules with some extra feature that make them compatible with parameters -# passed in the ``TensorDict`` format. def make_model(dummy_env): @@ -368,19 +310,6 @@ def make_model(dummy_env): tensordict = dummy_env.fake_tensordict() actor(tensordict) - # Make functional: - # here's an explicit way of creating the parameters and buffer tensordict. - # Alternatively, we could have used `params = make_functional(actor)` from - # tensordict.nn - params = TensorDict({k: v for k, v in actor.named_parameters()}, []) - buffers = TensorDict({k: v for k, v in actor.named_buffers()}, []) - params = params.update(buffers) - params = params.unflatten_keys(".") # creates a nested TensorDict - factor = get_functional(actor) - - # creating the target parameters is fairly easy with tensordict: - params_target = params.clone().detach() - # we wrap our actor in an EGreedyWrapper for data collection actor_explore = EGreedyWrapper( actor, @@ -389,43 +318,15 @@ def make_model(dummy_env): eps_end=eps_greedy_val_env, ) - return factor, actor, actor_explore, params, params_target + return actor, actor_explore -( - factor, - actor, - actor_explore, - params, - params_target, -) = make_model(test_env) - -############################################################################### -# We represent the parameters and targets as flat structures, but unflattening -# them is quite easy: - -params_flat = params.flatten_keys(".") - -############################################################################### -# We will be using the adam optimizer: - -optim = torch.optim.Adam(list(params_flat.values()), lr, betas=betas) - -############################################################################### -# We create a test environment for evaluation of the policy: - -test_env = make_env( - parallel=False, observation_norm_state_dict=observation_norm_state_dict -) -# sanity check: -print(actor_explore(test_env.reset())) - ############################################################################### # Collecting and storing data # --------------------------- # # Replay buffers -# ^^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~~ # # Replay buffers play a central role in off-policy RL algorithms such as DQN. 
# They constitute the dataset we will be sampling from during training. @@ -441,17 +342,22 @@ def make_model(dummy_env): # The only requirement of this storage is that the data passed to it at write # time must always have the same shape. -replay_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(buffer_size), - prefetch=n_optim, -) + +def get_replay_buffer(buffer_size, n_optim, batch_size): + replay_buffer = TensorDictReplayBuffer( + batch_size=batch_size, + storage=LazyMemmapStorage(buffer_size), + prefetch=n_optim, + ) + return replay_buffer + ############################################################################### # Data collector -# ^^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~~ # -# As in `PPO ` and -# `DDPG `, we will be using +# As in `PPO `_ and +# `DDPG `_, we will be using # a data collector as a dataloader in the outer loop. # # We choose the following configuration: we will be running a series of @@ -476,564 +382,328 @@ def make_model(dummy_env): # out training loop must account for. For simplicity, we set the devices to # the same value for all sub-collectors. -data_collector = MultiaSyncDataCollector( - # ``num_collectors`` collectors, each with an set of `num_workers` environments being run in parallel - [ - make_env( - parallel=True, observation_norm_state_dict=observation_norm_state_dict - ), - ] - * num_collectors, - policy=actor_explore, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode - exploration_mode="random", - # We set the all the devices to be identical. Below is an example of - # heterogeneous devices - devices=[device] * num_collectors, - storing_devices=[device] * num_collectors, - # devices=[f"cuda:{i}" for i in range(1, 1 + num_collectors)], - # storing_devices=[f"cuda:{i}" for i in range(1, 1 + num_collectors)], - split_trajs=False, -) + +def get_collector( + obs_norm_sd, + num_collectors, + actor_explore, + frames_per_batch, + total_frames, + device, +): + data_collector = MultiaSyncDataCollector( + [ + make_env(parallel=True, obs_norm_sd=obs_norm_sd), + ] + * num_collectors, + policy=actor_explore, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode + exploration_mode="random", + # We set the all the devices to be identical. Below is an example of + # heterogeneous devices + device=device, + storing_device=device, + split_trajs=False, + postproc=MultiStep(gamma=gamma, n_steps=5), + ) + return data_collector + ############################################################################### -# Training loop of a regular DQN -# ------------------------------ -# -# We'll start with a simple implementation of DQN where the returns are -# computed without bootstrapping, i.e. +# Loss function +# ------------- # -# .. math:: +# Building our loss function is straightforward: we only need to provide +# the model and a bunch of hyperparameters to the DQNLoss class. # -# Q_{t}(s, a) = R(s, a) + \gamma * V_{t+1}(s) +# Target parameters +# ~~~~~~~~~~~~~~~~~ # -# where :math:`Q(s, a)` is the Q-value of the current state-action pair, -# :math:`R(s, a)` is the result of the reward function, and :math:`V(s)` is a -# value function that returns 0 for terminating states. +# Many off-policy RL algorithms use the concept of "target parameters" when it +# comes to estimate the value of the next state or state-action pair. 
+# The target parameters are lagged copies of the model parameters. Because +# their predictions mismatch those of the current model configuration, they +# help learning by putting a pessimistic bound on the value being estimated. +# This is a powerful trick (known as "Double Q-Learning") that is ubiquitous +# in similar algorithms. # -# We store the logs in a defaultdict: -logs_exp1 = defaultdict(list) -prev_traj_count = 0 -pbar = tqdm.tqdm(total=total_frames) -for j, data in enumerate(data_collector): - current_frames = data.numel() - pbar.update(current_frames) - data = data.view(-1) +def get_loss_module(actor, gamma): + loss_module = DQNLoss(actor, gamma=gamma, delay_value=True) + target_updater = SoftUpdate(loss_module) + return loss_module, target_updater - # We store the values on the replay buffer, after placing them on CPU. - # When called for the first time, this will instantiate our storage - # object which will print its content. - replay_buffer.extend(data.cpu()) - # some logging - if len(logs_exp1["frames"]): - logs_exp1["frames"].append(current_frames + logs_exp1["frames"][-1]) - else: - logs_exp1["frames"].append(current_frames) +############################################################################### +# Hyperparameters +# --------------- +# +# Let's start with our hyperparameters. The following setting should work well +# in practice, and the performance of the algorithm should hopefully not be +# too sensitive to slight variations of these. - if data["next", "done"].any(): - done = data["next", "done"].squeeze(-1) - logs_exp1["traj_lengths"].append( - data["next", "step_count"][done].float().mean().item() - ) +device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu" - # check that we have enough data to start training - if sum(logs_exp1["frames"]) > init_random_frames: - for _ in range(n_optim): - # sample from the RB and send to device - sampled_data = replay_buffer.sample(batch_size) - sampled_data = sampled_data.to(device, non_blocking=True) - - # collect data from RB - reward = sampled_data["next", "reward"].squeeze(-1) - done = sampled_data["next", "done"].squeeze(-1).to(reward.dtype) - action = sampled_data["action"].clone() - - # Compute action value (of the action actually taken) at time t - # By default, TorchRL uses one-hot encodings for discrete actions - sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = factor(sampled_data_out, params=params) - action_value = sampled_data_out["action_value"] - action_value = (action_value * action.to(action_value.dtype)).sum(-1) - with torch.no_grad(): - # compute best action value for the next step, using target parameters - tdstep = step_mdp(sampled_data) - next_value = factor( - tdstep.select(*actor.in_keys), - params=params_target, - )["chosen_action_value"].squeeze(-1) - exp_value = reward + gamma * next_value * (1 - done) - assert exp_value.shape == action_value.shape - # we use MSE loss but L1 or smooth L1 should also work - error = nn.functional.mse_loss(exp_value, action_value).mean() - error.backward() - - gv = nn.utils.clip_grad_norm_(list(params_flat.values()), 1) - - optim.step() - optim.zero_grad() - - # update of the target parameters - params_target.apply( - lambda p_target, p_orig: p_orig * tau + p_target * (1 - tau), - params.detach(), - inplace=True, - ) - - actor_explore.step(current_frames) - - # Logging - logs_exp1["grad_vals"].append(float(gv)) - logs_exp1["losses"].append(error.item()) - logs_exp1["values"].append(action_value.mean().item()) - 
logs_exp1["traj_count"].append( - prev_traj_count + data["next", "done"].sum().item() - ) - prev_traj_count = logs_exp1["traj_count"][-1] - - if j % 10 == 0: - with set_exploration_mode("mode"), torch.no_grad(): - # execute a rollout. The `set_exploration_mode("mode")` has no effect here since the policy is deterministic, but we add it for completeness - eval_rollout = test_env.rollout( - max_steps=10000, - policy=actor, - ).cpu() - logs_exp1["traj_lengths_eval"].append(eval_rollout.shape[-1]) - logs_exp1["evals"].append(eval_rollout["next", "reward"].sum().item()) - if len(logs_exp1["mavgs"]): - logs_exp1["mavgs"].append( - logs_exp1["evals"][-1] * 0.05 + logs_exp1["mavgs"][-1] * 0.95 - ) - else: - logs_exp1["mavgs"].append(logs_exp1["evals"][-1]) - logs_exp1["traj_count_eval"].append(logs_exp1["traj_count"][-1]) - pbar.set_description( - f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}, test return: {logs_exp1['evals'][-1]: 4.4f}" - ) +############################################################################### +# Optimizer +# ~~~~~~~~~ - # update policy weights - data_collector.update_policy_weights_() +# the learning rate of the optimizer +lr = 2e-3 +# weight decay +wd = 1e-5 +# the beta parameters of Adam +betas = (0.9, 0.999) +# Optimization steps per batch collected (aka UPD or updates per data) +n_optim = 8 + +############################################################################### +# DQN parameters +# ~~~~~~~~~~~~~~ +# gamma decay factor +gamma = 0.99 + +############################################################################### +# Smooth target network update decay parameter. +# This loosely corresponds to a 1/tau interval with hard target network +# update +tau = 0.02 ############################################################################### -# We write a custom plot function to display the performance of our algorithm +# Data collection and replay buffer +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. note:: +# Values to be used for proper training have been commented. # +# Total frames collected in the environment. In other implementations, the +# user defines a maximum number of episodes. +# This is harder to do with our data collectors since they return batches +# of N collected frames, where N is a constant. +# However, one can easily get the same restriction on number of episodes by +# breaking the training loop when a certain number +# episodes has been collected. +total_frames = 5_000 # 500000 +############################################################################### +# Random frames used to initialize the replay buffer. 
+init_random_frames = 100 # 1000 -def plot(logs, name): - plt.figure(figsize=(15, 10)) - plt.subplot(2, 3, 1) - plt.plot( - logs["frames"][-len(logs["evals"]) :], - logs["evals"], - label="return (eval)", - ) - plt.plot( - logs["frames"][-len(logs["mavgs"]) :], - logs["mavgs"], - label="mavg of returns (eval)", - ) - plt.xlabel("frames collected") - plt.ylabel("trajectory length (= return)") - plt.subplot(2, 3, 2) - plt.plot( - logs["traj_count"][-len(logs["evals"]) :], - logs["evals"], - label="return", - ) - plt.plot( - logs["traj_count"][-len(logs["mavgs"]) :], - logs["mavgs"], - label="mavg", - ) - plt.xlabel("trajectories collected") - plt.legend() - plt.subplot(2, 3, 3) - plt.plot(logs["frames"][-len(logs["losses"]) :], logs["losses"]) - plt.xlabel("frames collected") - plt.title("loss") - plt.subplot(2, 3, 4) - plt.plot(logs["frames"][-len(logs["values"]) :], logs["values"]) - plt.xlabel("frames collected") - plt.title("value") - plt.subplot(2, 3, 5) - plt.plot( - logs["frames"][-len(logs["grad_vals"]) :], - logs["grad_vals"], - ) - plt.xlabel("frames collected") - plt.title("grad norm") - if len(logs["traj_lengths"]): - plt.subplot(2, 3, 6) - plt.plot(logs["traj_lengths"]) - plt.xlabel("batches") - plt.title("traj length (training)") - plt.savefig(name) - if is_notebook(): - plt.show() +############################################################################### +# Frames in each batch collected. +frames_per_batch = 32 # 128 +############################################################################### +# Frames sampled from the replay buffer at each optimization step +batch_size = 32 # 256 ############################################################################### -# The performance of the policy can be measured as the length of trajectories. -# As we can see on the results of the :func:`plot` function, the performance -# of the policy increases, albeit slowly. -# -# .. code-block:: python -# -# plot(logs_exp1, "dqn_td0.png") -# -# .. figure:: /_static/img/dqn_td0.png -# :alt: Cart Pole results with TD(0) -# +# Size of the replay buffer in terms of frames +buffer_size = min(total_frames, 100000) -print("shutting down") -data_collector.shutdown() -del data_collector +############################################################################### +# Number of environments run in parallel in each data collector +num_workers = 2 # 8 +num_collectors = 2 # 4 ############################################################################### -# DQN with TD(:math:`\lambda`) -# ---------------------------- -# -# We can improve the above algorithm by getting a better estimate of the -# return, using not only the next state value but the whole sequence of rewards -# and values that follow a particular step. -# -# TorchRL provides a vectorized version of TD(lambda) named -# :func:`torchrl.objectives.value.functional.vec_td_lambda_advantage_estimate`. -# We'll use this to obtain a target value that the value network will be -# trained to match. +# Environment and exploration +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The big difference in this implementation is that we'll store entire -# trajectories and not single steps in the replay buffer. This will be done -# automatically as long as we're not "flattening" the tensordict collected: -# by keeping a shape ``[Batch x timesteps]`` and giving this -# to the RB, we'll be creating a replay buffer of size -# ``[Capacity x timesteps]``. +# We set the initial and final value of the epsilon factor in Epsilon-greedy +# exploration. 
+# Since our policy is deterministic, exploration is crucial: without it, the +# only source of randomness would be the environment reset. +eps_greedy_val = 0.1 +eps_greedy_val_env = 0.005 -from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate +############################################################################### +# To speed up learning, we set the bias of the last layer of our value network +# to a predefined value (this is not mandatory) +init_bias = 2.0 ############################################################################### -# We reset the actor parameters: +# .. note:: +# For fast rendering of the tutorial ``total_frames`` hyperparameter +# was set to a very low number. To get a reasonable performance, use a greater +# value e.g. 500000 # -( - factor, - actor, - actor_explore, - params, - params_target, -) = make_model(test_env) -params_flat = params.flatten_keys(".") - -optim = torch.optim.Adam(list(params_flat.values()), lr, betas=betas) -test_env = make_env( - parallel=False, observation_norm_state_dict=observation_norm_state_dict -) -print(actor_explore(test_env.reset())) - ############################################################################### -# Data: Replay buffer and collector -# --------------------------------- +# Building a Trainer +# ------------------ # -# We need to build a new replay buffer of the appropriate size: +# TorchRL's :class:`torchrl.trainers.Trainer` class constructor takes the +# following keyword-only arguments: # +# - ``collector`` +# - ``loss_module`` +# - ``optimizer`` +# - ``logger``: A logger can be +# - ``total_frames``: this parameter defines the lifespan of the trainer. +# - ``frame_skip``: when a frame-skip is used, the collector must be made +# aware of it in order to accurately count the number of frames +# collected etc. Making the trainer aware of this parameter is not +# mandatory but helps to have a fairer comparison between settings where +# the total number of frames (budget) is fixed but the frame-skip is +# variable. 
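+#
+# Schematically, the code below builds each of these components, registers
+# the training hooks and launches the training loop. As a preview (the
+# objects used here are the ones constructed in the next code blocks):
+#
+# .. code-block:: python
+#
+#    trainer = Trainer(
+#        collector=collector,
+#        total_frames=total_frames,
+#        frame_skip=1,
+#        loss_module=loss_module,
+#        optimizer=optimizer,
+#        logger=logger,
+#        optim_steps_per_batch=n_optim,
+#    )
+#    # ... register hooks (replay buffer, weight updates, recorder, ...) ...
+#    trainer.train()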
-max_size = frames_per_batch // num_workers +stats = get_norm_stats() +test_env = make_env(parallel=False, obs_norm_sd=stats) +# Get model +actor, actor_explore = make_model(test_env) +loss_module, target_net_updater = get_loss_module(actor, gamma) +target_net_updater.init_() -replay_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(-(-buffer_size // max_size)), - prefetch=n_optim, +collector = get_collector( + stats, num_collectors, actor_explore, frames_per_batch, total_frames, device ) - -data_collector = MultiaSyncDataCollector( - [ - make_env( - parallel=True, observation_norm_state_dict=observation_norm_state_dict - ), - ] - * num_collectors, - policy=actor_explore, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - exploration_mode="random", - devices=[device] * num_collectors, - storing_devices=[device] * num_collectors, - # devices=[f"cuda:{i}" for i in range(1, 1 + num_collectors)], - # storing_devices=[f"cuda:{i}" for i in range(1, 1 + num_collectors)], - split_trajs=False, +optimizer = torch.optim.Adam( + loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas ) - - -logs_exp2 = defaultdict(list) -prev_traj_count = 0 +exp_name = f"dqn_exp_{uuid.uuid1()}" +tmpdir = tempfile.TemporaryDirectory() +logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name) +warnings.warn(f"log dir: {logger.experiment.log_dir}") ############################################################################### -# Training loop -# ------------- -# -# There are very few differences with the training loop above: -# -# - The tensordict received by the collector is used as-is, without being -# flattened (recall the ``data.view(-1)`` above), to keep the temporal -# relation between consecutive steps. -# - We use :func:`vec_td_lambda_advantage_estimate` to compute the target -# value. 
- -pbar = tqdm.tqdm(total=total_frames) -for j, data in enumerate(data_collector): - current_frames = data.numel() - pbar.update(current_frames) - - replay_buffer.extend(data.cpu()) - if len(logs_exp2["frames"]): - logs_exp2["frames"].append(current_frames + logs_exp2["frames"][-1]) - else: - logs_exp2["frames"].append(current_frames) - - if data["next", "done"].any(): - done = data["next", "done"].squeeze(-1) - logs_exp2["traj_lengths"].append( - data["next", "step_count"][done].float().mean().item() - ) - - if sum(logs_exp2["frames"]) > init_random_frames: - for _ in range(n_optim): - sampled_data = replay_buffer.sample(batch_size // max_size) - sampled_data = sampled_data.clone().to(device, non_blocking=True) - - reward = sampled_data["next", "reward"] - done = sampled_data["next", "done"].to(reward.dtype) - action = sampled_data["action"].clone() - - sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = vmap(factor, (0, None))(sampled_data_out, params) - action_value = sampled_data_out["action_value"] - action_value = (action_value * action.to(action_value.dtype)).sum(-1, True) - with torch.no_grad(): - tdstep = step_mdp(sampled_data) - next_value = vmap(factor, (0, None))( - tdstep.select(*actor.in_keys), params - ) - next_value = next_value["chosen_action_value"] - error = vec_td_lambda_advantage_estimate( - gamma, - lmbda, - action_value, - next_value, - reward, - done, - time_dim=sampled_data_out.ndim - 1, - ).pow(2) - error = error.mean() - error.backward() - - gv = nn.utils.clip_grad_norm_(list(params_flat.values()), 1) - - optim.step() - optim.zero_grad() - - # update of the target parameters - params_target.apply( - lambda p_target, p_orig: p_orig * tau + p_target * (1 - tau), - params.detach(), - inplace=True, - ) - - actor_explore.step(current_frames) - - # Logging - logs_exp2["grad_vals"].append(float(gv)) - - logs_exp2["losses"].append(error.item()) - logs_exp2["values"].append(action_value.mean().item()) - logs_exp2["traj_count"].append( - prev_traj_count + data["next", "done"].sum().item() - ) - prev_traj_count = logs_exp2["traj_count"][-1] - if j % 10 == 0: - with set_exploration_mode("mode"), torch.no_grad(): - # execute a rollout. The `set_exploration_mode("mode")` has - # no effect here since the policy is deterministic, but we add - # it for completeness - eval_rollout = test_env.rollout( - max_steps=10000, - policy=actor, - ).cpu() - logs_exp2["traj_lengths_eval"].append(eval_rollout.shape[-1]) - logs_exp2["evals"].append(eval_rollout["next", "reward"].sum().item()) - if len(logs_exp2["mavgs"]): - logs_exp2["mavgs"].append( - logs_exp2["evals"][-1] * 0.05 + logs_exp2["mavgs"][-1] * 0.95 - ) - else: - logs_exp2["mavgs"].append(logs_exp2["evals"][-1]) - logs_exp2["traj_count_eval"].append(logs_exp2["traj_count"][-1]) - pbar.set_description( - f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}, test return: {logs_exp2['evals'][-1]: 4.4f}" - ) +# We can control how often the scalars should be logged. 
Here we set this +# to a low value as our training loop is short: - # update policy weights - data_collector.update_policy_weights_() +log_interval = 500 +trainer = Trainer( + collector=collector, + total_frames=total_frames, + frame_skip=1, + loss_module=loss_module, + optimizer=optimizer, + logger=logger, + optim_steps_per_batch=n_optim, + log_interval=log_interval, +) ############################################################################### -# TD(:math:`\lambda`) performs significantly better than TD(0) because it -# retrieves a much less biased estimate of the state-action value. -# -# .. code-block:: python +# Registering hooks +# ~~~~~~~~~~~~~~~~~ # -# plot(logs_exp2, "dqn_tdlambda.png") +# Registering hooks can be achieved in two separate ways: # -# .. figure:: /_static/img/dqn_tdlambda.png -# :alt: Cart Pole results with TD(lambda) -# - - -print("shutting down") -data_collector.shutdown() -del data_collector +# - If the hook has it, the :meth:`torchrl.trainers.TrainerHookBase.register` +# method is the first choice. One just needs to provide the trainer as input +# and the hook will be registered with a default name at a default location. +# For some hooks, the registration can be quite complex: :class:`torchrl.trainers.ReplayBufferTrainer` +# requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which +# can be cumbersome to implement. +buffer_hook = ReplayBufferTrainer( + get_replay_buffer(buffer_size, n_optim, batch_size=batch_size), + flatten_tensordicts=True, +) +buffer_hook.register(trainer) +weight_updater = UpdateWeights(collector, update_weights_interval=1) +weight_updater.register(trainer) +recorder = Recorder( + record_interval=100, # log every 100 optimization steps + record_frames=1000, # maximum number of frames in the record + frame_skip=1, + policy_exploration=actor_explore, + environment=test_env, + exploration_mode="mode", + log_keys=[("next", "reward")], + out_keys={("next", "reward"): "rewards"}, + log_pbar=True, +) +recorder.register(trainer) ############################################################################### -# Let's compare the results on a single plot. Because the TD(lambda) version -# works better, we'll have fewer episodes collected for a given number of -# frames (as there are more frames per episode). +# - Any callable (including :class:`torchrl.trainers.TrainerHookBase` +# subclasses) can be registered using :meth:`torchrl.trainers.Trainer.register_op`. +# In this case, a location must be explicitly passed (). This method gives +# more control over the location of the hook but it also requires more +# understanding of the Trainer mechanism. +# Check the `trainer documentation `_ +# for a detailed description of the trainer hooks. # -# **Note**: As already mentioned above, to get a more reasonable performance, -# use a greater value for ``total_frames`` e.g. 500000. 
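+# As an illustration, here is a hypothetical stand-alone hook (not required
+# for training) registered at the ``"post_optim"`` location, where hooks are
+# called without arguments; it simply counts optimization passes:
+#
+# .. code-block:: python
+#
+#    class OptimCounter:
+#        """Tiny callable counting how many times "post_optim" fires."""
+#
+#        def __init__(self):
+#            self.count = 0
+#
+#        def __call__(self):
+#            self.count += 1
+#
+#    counter = OptimCounter()
+#    trainer.register_op("post_optim", counter)
+#
+# We use the same mechanism to update the target network parameters after
+# each optimization step: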
- - -def plot_both(): - frames_td0 = logs_exp1["frames"] - frames_tdlambda = logs_exp2["frames"] - evals_td0 = logs_exp1["evals"] - evals_tdlambda = logs_exp2["evals"] - mavgs_td0 = logs_exp1["mavgs"] - mavgs_tdlambda = logs_exp2["mavgs"] - traj_count_td0 = logs_exp1["traj_count_eval"] - traj_count_tdlambda = logs_exp2["traj_count_eval"] - - plt.figure(figsize=(15, 10)) - plt.subplot(1, 2, 1) - plt.plot(frames_td0[-len(evals_td0) :], evals_td0, label="return (td0)", alpha=0.5) - plt.plot( - frames_tdlambda[-len(evals_tdlambda) :], - evals_tdlambda, - label="return (td(lambda))", - alpha=0.5, - ) - plt.plot(frames_td0[-len(mavgs_td0) :], mavgs_td0, label="mavg (td0)") - plt.plot( - frames_tdlambda[-len(mavgs_tdlambda) :], - mavgs_tdlambda, - label="mavg (td(lambda))", - ) - plt.xlabel("frames collected") - plt.ylabel("trajectory length (= return)") - - plt.subplot(1, 2, 2) - plt.plot( - traj_count_td0[-len(evals_td0) :], - evals_td0, - label="return (td0)", - alpha=0.5, - ) - plt.plot( - traj_count_tdlambda[-len(evals_tdlambda) :], - evals_tdlambda, - label="return (td(lambda))", - alpha=0.5, - ) - plt.plot(traj_count_td0[-len(mavgs_td0) :], mavgs_td0, label="mavg (td0)") - plt.plot( - traj_count_tdlambda[-len(mavgs_tdlambda) :], - mavgs_tdlambda, - label="mavg (td(lambda))", - ) - plt.xlabel("trajectories collected") - plt.legend() - - plt.savefig("dqn.png") - +trainer.register_op("post_optim", target_net_updater.step) ############################################################################### -# .. code-block:: python -# -# plot_both() -# -# .. figure:: /_static/img/dqn.png -# :alt: Cart Pole results from the TD(:math:`lambda`) trained policy. +# We can log the training rewards too. Note that this is of limited interest +# with CartPole, as rewards are always 1. The discounted sum of rewards is +# maximised not by getting higher rewards but by keeping the cart-pole alive +# for longer. +# This will be reflected by the `total_rewards` value displayed in the +# progress bar. # -# Finally, we generate a new video to check what the algorithm has learnt. -# If all goes well, the duration should be significantly longer than with a -# random rollout. +log_reward = LogReward(log_pbar=True) +log_reward.register(trainer) + +############################################################################### +# .. note:: +# It is possible to link multiple optimizers to the trainer if needed. +# In this case, each optimizer will be tied to a field in the loss +# dictionary. +# Check the :class:`torchrl.trainers.OptimizerHook` to learn more. # -# To get the raw pixels of the rollout, we insert a -# :class:`torchrl.envs.CatTensors` transform that precedes all others and copies -# the ``"pixels"`` key onto a ``"pixels_save"`` key. This is necessary because -# the other transforms that modify this key will update its value in-place in -# the output tensordict. +# Here we are, ready to train our algorithm! A simple call to +# ``trainer.train()`` and we'll be getting our results logged in. # +trainer.train() -test_env.transform.insert(0, CatTensors(["pixels"], "pixels_save", del_keys=False)) -eval_rollout = test_env.rollout(max_steps=10000, policy=actor, auto_reset=True).cpu() +############################################################################### +# We can now quickly check the CSVs with the results. 
-# sphinx_gallery_start_ignore -import imageio -imageio.mimwrite("cartpole.gif", eval_rollout["pixels_save"].numpy(), fps=30) -# sphinx_gallery_end_ignore +def print_csv_files_in_folder(folder_path): + """ + Find all CSV files in a folder and prints the first 10 lines of each file. -del test_env + Args: + folder_path (str): The relative path to the folder. + + """ + csv_files = [] + output_str = "" + for dirpath, _, filenames in os.walk(folder_path): + for file in filenames: + if file.endswith(".csv"): + csv_files.append(os.path.join(dirpath, file)) + for csv_file in csv_files: + output_str += f"File: {csv_file}\n" + with open(csv_file, "r") as f: + for i, line in enumerate(f): + if i == 10: + break + output_str += line.strip() + "\n" + output_str += "\n" + print(output_str) -############################################################################### -# The video of the rollout can be saved using the imageio package: -# -# .. code-block:: -# -# import imageio -# imageio.mimwrite('cartpole.mp4', eval_rollout["pixels_save"].numpy(), fps=30); -# -# .. figure:: /_static/img/cartpole.gif -# :alt: Cart Pole results from the TD(:math:`\lambda`) trained policy. + +print_csv_files_in_folder(logger.experiment.log_dir) ############################################################################### # Conclusion and possible improvements # ------------------------------------ # -# In this tutorial we have learnt: +# In this tutorial we have learned: # -# - How to train a policy that read pixel-based states, what transforms to -# include and how to normalize the data; -# - How to create a policy that picks up the action with the highest value -# with :class:`torchrl.modules.QValueNetwork`; +# - How to write a Trainer, including building its components and registering +# them in the trainer; +# - How to code a DQN algorithm, including how to create a policy that picks +# up the action with the highest value with +# :class:`torchrl.modules.QValueNetwork`; # - How to build a multiprocessed data collector; -# - How to train a DQN with TD(:math:`\lambda`) returns. -# -# We have seen that using TD(:math:`\lambda`) greatly improved the performance -# of DQN. Other possible improvements could include: -# -# - Using the Multi-Step post-processing. Multi-step will project an action -# to the nth following step, and create a discounted sum of the rewards in -# between. This trick can make the algorithm noticebly less myopic. To use -# this, simply create the collector with -# -# from torchrl.data.postprocs.postprocs import MultiStep -# collector = CollectorClass(..., postproc=MultiStep(gamma, n)) # -# where ``n`` is the number of looking-forward steps. Pay attention to the -# fact that the ``gamma`` factor has to be corrected by the number of -# steps till the next observation when being passed to -# ``vec_td_lambda_advantage_estimate``: +# Possible improvements to this tutorial could include: # -# gamma = gamma ** tensordict["steps_to_next_obs"] # - A prioritized replay buffer could also be used. This will give a # higher priority to samples that have the worst value accuracy. -# - A distributional loss (see ``torchrl.objectives.DistributionalDQNLoss`` +# Learn more on the +# `replay buffer section `_ +# of the documentation. +# - A distributional loss (see :class:`torchrl.objectives.DistributionalDQNLoss` # for more information). 
-# - More fancy exploration techniques, such as NoisyLinear layers and such -# (check ``torchrl.modules.NoisyLinear``, which is fully compatible with the -# ``MLP`` class used in our Dueling DQN). +# - More fancy exploration techniques, such as :class:`torchrl.modules.NoisyLinear` layers and such. diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 77ed207837f..274269a3dac 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -602,7 +602,8 @@ # We'll need an "advantage" signal to make PPO work. # We re-compute it at each epoch as its value depends on the value # network which is updated in the inner loop. - advantage_module(tensordict_data) + with torch.no_grad(): + advantage_module(tensordict_data) data_view = tensordict_data.reshape(-1) replay_buffer.extend(data_view.cpu()) for _ in range(frames_per_batch // sub_batch_size):