diff --git a/src/cdk/tree/toggle.ts b/src/cdk/tree/toggle.ts index b7004e60f1cb..5853bf64f937 100644 --- a/src/cdk/tree/toggle.ts +++ b/src/cdk/tree/toggle.ts @@ -39,11 +39,17 @@ export class CdkTreeNodeToggle { constructor(protected _tree: CdkTree, protected _treeNode: CdkTreeNode) {} + // Toggle the expanded or collapsed state of this node. + // + // Focus this node with expanding or collapsing it. This ensures that the active node will always + // be visible when expanding and collapsing. _toggle(event: Event): void { this.recursive ? this._tree.toggleDescendants(this._treeNode.data) : this._tree.toggle(this._treeNode.data); + this._tree._keyManager.focusItem(this._treeNode); + event.stopPropagation(); } diff --git a/src/cdk/tree/tree-redesign.spec.ts b/src/cdk/tree/tree-redesign.spec.ts index 13f0098b745c..a43a1dc4e999 100644 --- a/src/cdk/tree/tree-redesign.spec.ts +++ b/src/cdk/tree/tree-redesign.spec.ts @@ -270,6 +270,28 @@ describe('CdkTree redesign', () => { .toBe(0); }); + it('should focus a node when collapsing it', () => { + // Create a tree with two nodes. A parent node and its child. + dataSource.clear(); + const parent = dataSource.addData(); + dataSource.addChild(parent); + + component.tree.expandAll(); + fixture.detectChanges(); + + // focus the child node + getNodes(treeElement)[1].click(); + fixture.detectChanges(); + + // collapse the parent node + getNodes(treeElement)[0].click(); + fixture.detectChanges(); + + expect(getNodes(treeElement).map(x => x.getAttribute('tabindex'))) + .withContext(`Expecting parent node to be focused since it was collapsed.`) + .toEqual(['0', '-1']); + }); + it('should expand/collapse the node recursively', () => { expect(dataSource.data.length).toBe(3); @@ -1312,15 +1334,21 @@ class FakeDataSource extends DataSource { return child; } - addData(level: number = 1) { + addData(level: number = 1): TestData { const nextIndex = ++this.dataIndex; let copiedData = this.data.slice(); - copiedData.push( - new TestData(`topping_${nextIndex}`, `cheese_${nextIndex}`, `base_${nextIndex}`, level), + const newData = new TestData( + `topping_${nextIndex}`, + `cheese_${nextIndex}`, + `base_${nextIndex}`, + level, ); + copiedData.push(newData); this.data = copiedData; + + return newData; } getRecursiveData(nodes: TestData[] = this._dataChange.getValue()): TestData[] { @@ -1328,6 +1356,11 @@ class FakeDataSource extends DataSource { ...new Set(nodes.flatMap(parent => [parent, ...this.getRecursiveData(parent.children)])), ]; } + + clear() { + this.data = []; + this.dataIndex = 0; + } } function getNodes(treeElement: Element): HTMLElement[] {