diff --git a/src/benchmark.ts b/src/benchmark.ts
new file mode 100644
index 00000000..dc1235fe
--- /dev/null
+++ b/src/benchmark.ts
@@ -0,0 +1,153 @@
+// This was restored from an old version of index.ts
+// This is not cleaned up at all and could use some love.
+
+
+import program from 'commander';
+import * as path from 'path';
+import * as Transform from './transform';
+import { toolDefaults, benchmarkDefaults, Browser} from './types';
+import { compileToStringSync } from 'node-elm-compiler';
+import * as fs from 'fs';
+import chalk from 'chalk';
+// tslint:disable-next-line no-require-imports no-var-requires
+const { version } = require('../package.json');
+import * as BenchInit from './benchmark/init'
+import * as Benchmark from './benchmark/benchmark';
+import * as Reporting from './benchmark/reporting';
+import { readFilesSync } from './fs_util';
+
+// Options that take a value must declare a `<placeholder>`; without one,
+// commander parses them as boolean flags and the value is lost.
+program
+  .version(version)
+  .description(
+    `${chalk.yellow('Elm Optimize Level 2!')}
+
+This applies a second level of optimization to the javascript that Elm creates.
+
+Make sure you're familiar with Elm's built-in optimization first: ${chalk.cyan(
+      'https://guide.elm-lang.org/optimization/asset_size.html'
+    )}
+
+Give me an Elm file, I'll compile it behind the scenes using Elm 0.19.1, and then I'll make some more optimizations!`
+  )
+  .usage('[options] <src/Main.elm>')
+  .option('--output <output>', 'the javascript file to create.', 'elm.js')
+  .option('-O3, --optimize-speed', 'Enable optimizations that likely increases asset size', false)
+  .option('--init-benchmark <dir>', 'Generate some files to help run benchmarks')
+  .option('--benchmark <dir>', 'Run the benchmark in the given directory.')
+  .option('--replacements <dir>', 'Replace stuff')
+  .parse(process.argv);
+
+/**
+ * Entry point of the benchmark CLI.
+ *
+ * - `--init-benchmark <dir>`: generates benchmark helper files, then exits.
+ * - `--benchmark <dir>`: runs `V8/Benchmark.elm` from that directory in headless
+ *   Chrome, prints the report to the terminal, then exits.
+ * - otherwise: optimizes `inputFilePath` (a `.js` file is read as-is, a `.elm`
+ *   file is first compiled with `optimize: true`) and writes the result to `--output`.
+ *
+ * @param inputFilePath first positional command-line argument, if any
+ */
+async function run(inputFilePath: string | undefined) {
+  const dirname = process.cwd();
+  const options = program.opts();
+  const replacements = options.replacements;
+  const o3Enabled = options.optimizeSpeed;
+
+  if (options.initBenchmark) {
+    console.log(`Initializing benchmark ${options.initBenchmark}`)
+    BenchInit.generate(options.initBenchmark)
+    process.exit(0)
+  }
+
+  if (options.benchmark) {
+    const benchmarkOptions = {
+      compile: true,
+      gzip: true,
+      minify: true,
+      verbose: true,
+      assetSizes: true,
+      runBenchmark: [
+        {
+          browser: Browser.Chrome,
+          headless: true,
+        }
+      ],
+      transforms: benchmarkDefaults(o3Enabled, replacements),
+    };
+    const report = await Benchmark.run(benchmarkOptions, [
+      {
+        name: 'Benchmark',
+        dir: options.benchmark,
+        elmFile: 'V8/Benchmark.elm',
+      }
+    ]);
+    console.log(Reporting.terminal(report));
+// fs.writeFileSync('./results.markdown', Reporting.markdownTable(result));
+    process.exit(0)
+  }
+
+  // Reject anything that is neither a compiled `.js` file nor an `.elm` source.
+  if (!inputFilePath || !(inputFilePath.endsWith('.js') || inputFilePath.endsWith('.elm'))) {
+    console.error('Please provide a path to an Elm file.');
+    program.outputHelp();
+    return;
+  }
+
+  let jsSource: string;
+  let elmFilePath: string | undefined;
+  if (inputFilePath.endsWith('.js')) {
+    jsSource = fs.readFileSync(inputFilePath, 'utf8');
+    console.log('Optimizing existing JS...');
+  } else {
+    elmFilePath = inputFilePath;
+    jsSource = compileToStringSync([inputFilePath], {
+      output: 'output/elm.opt.js',
+      cwd: dirname,
+      optimize: true,
+      processOpts:
+        // ignore stdout
+        {
+          stdio: ['inherit', 'ignore', 'inherit'],
+        },
+    });
+    if (jsSource !== '') {
+      console.log('Compiled, optimizing JS...');
+    } else {
+      // stderr of the compiler process is inherited, so its diagnostics went to our stderr.
+      process.exit(1)
+    }
+  }
+
+  if (jsSource !== '') {
+    const transformed = await Transform.transform(
+      dirname,
+      jsSource,
+      elmFilePath,
+      false,
+      toolDefaults(o3Enabled, replacements),
+    );
+
+    // Make sure all the folders up to the output file exist, if not create them.
+    // This mirrors elm make behavior.
+    const outputDirectory = path.dirname(options.output);
+    if (!fs.existsSync(outputDirectory)) {
+      fs.mkdirSync(outputDirectory, { recursive: true });
+    }
+    fs.writeFileSync(options.output, transformed);
+    const fileName = path.basename(inputFilePath);
+    console.log('Success!');
+    console.log('');
+    console.log(` ${fileName} ───> ${options.output}`);
+    console.log('');
+  }
+}
+
+
+// Report failures and make the process exit with a non-zero status.
+run(program.args[0]).catch((e) => {
+  console.error(e);
+  process.exitCode = 1;
+});
diff --git a/src/transform.ts b/src/transform.ts
index 36dcc222..8feef850 100644
--- a/src/transform.ts
+++ b/src/transform.ts
@@ -25,6 +25,7 @@ import { inlineNumberToString } from './transforms/inlineNumberToString';
 import { reportFunctionStatusInBenchmarks, v8Debug } from './transforms/analyze';
 import { recordUpdate } from './transforms/recordUpdate';
 import * as Replace from './transforms/replace';
+import { createTailCallRecursionTransformer } from './transforms/tailCallRecursion';
 
 export type Options = {
   compile: boolean;
@@ -87,6 +88,7 @@ export const transform = async (
   let inlineCtx: InlineContext | undefined;
   const transformations: any[] = removeDisabled([
     [transforms.replacements != null, replacementTransformer ],
+    [transforms.tailCallRecursion, createTailCallRecursionTransformer(false) ],
     [transforms.fastCurriedFns, Replace.from_file('/../replacements/faster-function-wrappers') ],
[transforms.replaceListFunctions, Replace.from_file('/../replacements/list') ], [transforms.replaceStringFunctions, Replace.from_file('/../replacements/string') ], diff --git a/src/transforms/tailCallRecursion.ts b/src/transforms/tailCallRecursion.ts new file mode 100644 index 00000000..430ed3cf --- /dev/null +++ b/src/transforms/tailCallRecursion.ts @@ -0,0 +1,2437 @@ +import ts from 'typescript'; +import { ast } from './utils/create'; +import { determineType, PossibleReturnType } from './utils/determineType'; + +/* + +Applies "Tail Recursion Modulo Cons" (TRMC) when possible, and simple TCR where the compiler didn't. + + +# TCR improvements over the compiler output + +## TCR failures because of pipelines + +This function gets tail-call optimized. + + tco : (a -> b) -> List a -> List b -> List b + tco mapper list acc = + case list of + [] -> + acc + + x :: xs -> + tco mapper xs (mapper x :: acc) + +but this version doesn't (because of the additional `<|`): + + nonTco : (a -> b) -> List a -> List b -> List b + nonTco mapper list acc = + case list of + [] -> + acc + + x :: xs -> + nonTco mapper xs <| (mapper x :: acc) + + +## Recursion inside a boolean expression + +This function gets tail-call optimized. + + naiveAll : (a -> Bool) -> List a -> Bool + naiveAll isOkay list = + case list of + [] -> + True + + x :: xs -> + isOkay x && naiveAll isOkay xs + +just like if it was written as + + naiveAll : (a -> Bool) -> List a -> Bool + naiveAll isOkay list = + case list of + [] -> + True + + x :: xs -> + if isOkay x then + naiveAll isOkay xs + + else + False + + +## Tail Recursion Modulo Cons + +"Tail Recursion Modulo Cons" (TRMC) is an extension to TCR/TCO that supports making some operations on the result of +recursive calls optimized into a while loop, which makes the recursive function stack-safe and more performant. + +TRMC made it into OCaml in 2022, and a paper describing the why and how +is available here: https://arxiv.org/pdf/2102.09823.pdf.
+ +This implementation and the one in OCaml were done independently (I hadn't read the paper when I did +most of the implementation), but I think they're sufficiently close that the paper makes for a deeper explanation +than what I will do here. There is a section further below detailing the differences between the two implementations. + +The Elm compiler does tail-recursive call optimization, as explained in detail +in https://jfmengels.net/tail-call-optimization/. This optimization is great because it both +prevents stack-overflow errors and improves performance. + +Unfortunately, the knowledge of how to write stack-safe recursive functions does not come +naturally to developers and requires first encountering runtime errors. + +TRMC helps with that by expanding the number of situations where functions benefit from tail-call +optimization, including almost textbook examples of recursion (that are generally not stack-safe), reducing the number of stack overflows that compiled Elm code will create. + +I have written an `elm-review` rule (https://package.elm-lang.org/packages/jfmengels/elm-review-performance/latest/NoUnoptimizedRecursion) +to detect recursive functions that are not tail-call optimized, and a lot of the issues would be resolved by TRMC (and this transformer, re: boolean and pipe recursions). + +If we take the errors reported by the rule in [`elm-community/list-extra`](https://github.com/elm-community/list-extra/blob/8.3.0/src/List/Extra.elm) +v8.3.0 (some functions were rewritten later based on the rule's reports), then we find that there are 13 functions +that are not stack-safe. Of those and with this transformer, 1 would become optimized because of the support for +pipeline operators, and 6 would become stack-safe without any rewrite. + +Functions that were written to be TCR-compliant often incur a code-complexity cost, and TRMC helps reduce that.
+Here are functions rewritten assuming TRMC: +- `elm/core`: https://github.com/jfmengels/core/compare/2fa34772a2575d036c0871b4390379741e6f5f91...new-tail-recursion?diff=split +- `elm-community/list-extra`: https://github.com/jfmengels/list-extra/compare/2e70e94278dd54e0d7f24bc6fdcbd5a2c6ff5c00..new-tail-recursion?diff=split + + +## Supported kinds of recursion + +This implementation supports the following recursion constructs. Examples for how the code looks before and after +transformation can be found in the tests. + +### Data construction + +```elm +type List a + = Nil + | Cons a (List a) + +map : (a -> b) -> List a -> List b +map fn list = + case list of + Nil -> + Nil + + Cons x rest -> + Cons (fn x) (map fn rest) +``` + + +### Nested data construction (NOTE: not implemented yet) + +Same as before, but for when data is wrapped in an arbitrary number of data wrappers (whose data structure is predictable). + +```elm +type List a + = Nil + | Cons { data : a, list : List a } + +map : (a -> b) -> List a -> List b +map fn list = + case list of + Nil -> + Nil + + Cons x rest -> + Cons { data = fn x, list = map fn rest } +``` + +### List constructions + +Recursions like `x :: rec y`, `x ++ rec y` and `rec y ++ x` (or using `List.append`) are optimized, as long as we +can infer somewhere that the function returns lists (for concatenation). + +### Arithmetic operators + +Recursions like `x + rec y`, `rec y + x`, `rec y * x` and `x * rec y` are optimized, as long as they're not mixed. + +### String constructions + +Recursions like `x ++ rec y` and `rec y ++ x` are optimized, as long as we can infer somewhere +that the function returns strings. + + +### Note on determining the recursion type + +Because `+` in the compiled code is used for both addition and string concatenation, and because `_Utils_ap` is the function +used for both concatenating strings and concatenating lists, we unfortunately need to figure out the type of the function +to choose the right one.
+ +When strings are appended without literals, `_Utils_ap` is used (which is also used for list concatenation) +- `foo ++ "bar"` => `foo + 'bar'` +- `foo ++ bar ++ "bar"` => `foo + (bar + 'bar')` +- `(foo ++ bar) ++ "bar"` => `_Utils_ap(foo, bar) + 'bar'` + +If we see a `+` operation, we can look at the operands. If there is a string, it's a string append, and otherwise it's a number sum. + +If we see a `_Utils_ap` + - If there was a `+` somewhere else (in a return statement), then we can determine it's a string concatenation + - Otherwise we don't know and we don't do anything + +We could try and do more inference but we don't want to reimplement a type checker here. Also, we could potentially +inspect the Elm code or the elmi/elmo files inside `elm-stuff`, but these would not work (as well) for let functions, +and would not always be available to this `elm-optimize-level-2`. + + +## How does this transformer work? + +On a high level, this transformer analyzes functions to find if they're recursive and optimizable. + +It does one visit of the AST to analyze the `return` expressions of functions to try and find recursive calls, +allowing it to determine what kind of optimization strategy to adopt. + +It then does a second pass to alter the function. It rewrites the body of the function to use iteration (using a loop) +and changes the `return` statements to `continue` statements with some additional variable manipulations. + +Let's go into the details. + +### Analyzing return statements + +To know whether a function is recursive, we need to look at the return statements to find "local recursion patterns". + +If we have a function named `rec`, then a call like `rec y` is a "plain" recursive call. 
+The analysis tries to find additional patterns that we know are potentially optimizable (the examples are Elm code): +- Boolean recursion: `fn x || rec y` or `fn x && rec y` +- Cons recursion: `x :: rec y` +- Addition recursion: `x + rec y` (can be used for both numbers and strings) +- Multiplication recursion: `x * rec y` +- Concatenation recursion: `x ++ rec y` or `rec y ++ x` using the JS `_Utils_ap` function (can be used for both strings and lists) +- Data construction recursion: `Cons x (rec y)` + +### Combining local recursion patterns to find the function recursion type + +The way we are going to update the `return` statements depends mostly on the local recursion patterns, but the way we change the "outer body" of the function +depends on the combination of those, which we are going to distill into a "function recursion type". + +For instance, if we have the following code: +```elm +sum : List number -> number +sum list = + case list of + [] -> + 0 + + x :: xs -> + x + sum xs +``` +```js +var sum = function (list) { + if (!list.b) { + return 0; + } else { + var x = list.a; + var xs = list.b; + return x + sum(xs); + } +}; +``` +then we will transform it to the following: +```js +var sum = function (list) { + var $result = 0; // Change dependent on the function recursion type + sum: while (true) { + if (!list.b) { + // Change dependent on the local recursion pattern + // (in practice we remove the + 0) + return $result + 0; + } else { + var x = list.a; + var xs = list.b; + // Change dependent on the local recursion pattern + $result += x; + list = xs; + continue sum; + } + } +}; +``` + +While plain recursion and boolean recursion calls are always optimizable, not all of the others are without some more information or intersecting information. + +For instance, when we find concatenation recursion calls (like `x ++ rec (n - 1)`), we don't have enough information to optimize this.
+For this optimization in particular, we need to create an initial value to which to append all the `return` expressions, but we don't +know if we're dealing with strings (initial value `""`) or lists (initial value `[]`). + +This is also the case for the addition recursion, where we don't know if the JS `+` is for adding numbers or strings. +Though in that case, we can rely on the fact that the compiler (at least until 0.19.1) won't use `+` unless there is a literal string somewhere, +so we can assume we're adding numbers unless we found such a literal string. + +So in essence, we need to combine the different local recursion patterns to find which kind of function recursion optimization we're going to apply. +For instance, if in one branch we see an addition recursion call, and in another one we see concatenation recursion, we can infer that we're dealing with strings, +that the function type should be string concatenation recursion and that the initial value for the accumulator should be `""` (instead of 0 or `[]`). + +We will also use trivial type inference on (recursive and non-recursive) return statements to figure out the missing bits. If we see a `return "";` somewhere, +we know by the fact that all return statements return the same type, that we're dealing with strings and not numbers or lists, which can help determine whether +we need to do list concatenation recursion or string concatenation recursion. + +Once we have figured out the exact function recursion type, we can stop the analysis and start transforming the body and the return statements. + +### Transforming the body + +Once we detect recursion, we know that we will need to have a while loop. Because the Elm compiler already introduces while loops for plain recursive calls +(modulo some issues with piping), we will just need to make sure we don't introduce a second one.
+ +Depending on the function recursion type, we also need to add accumulator variables to help us accumulate the results of the recursive calls. +For example and as shown in the previous `sum` example, the initial value will be an accumulator holding the value `0`, and for multiplication it will be `1`. + +For strings, dependent on whether we find `foo ++ rec (n - 1)` and/or `rec (n - 1) ++ foo` +(both can be found in the same function, and we can also find `foo ++ rec (n - 1) ++ bar`), +we will add a variable `$left` and/or `$right` containing `""`. + + +### Transforming the return statements + +Already present `continue` statements that are introduced by the Elm compiler aren't touched and don't need to be altered. + +The return statement updates will highly depend on the function recursion type. + +For plain recursion functions, we leave the non-recursive calls untouched and only touch the (plain) recursive calls. + +For boolean recursion functions/calls, we do the same, except that we transform `fn x || rec (n - 1)` into `if (fn(x)) { return true; } else { <...> ; continue; }`. + +For the other recursive function types, we will basically do the same operations for recursive calls, along with a few operations to mutate the accumulator. +For instance, when encountering `return x + rec (n - 1)`, we will transform that to the following: + +```js +$left += x; // Updates the accumulator +n = n - 1; // Changes the arguments to the function, like for plain recursion +continue rec; // Re-start the loop +``` + +If we have a non-recursive call, then we need to update the return value to include the accumulator. +For an addition recursion, `return x` would be transformed to `return $left + x`. + + +## Differences with the OCaml implementation + +This idea and implementation were not made based on the paper, and we came to most of the same conclusions, +but a few differences remain.
+ +1) The OCaml implementation chose to split TRMC functions into a base version and a regular "DPS" version, +to avoid the extra cost of an additional function call. Our implementation only updates the body to add the +necessary accumulators, without duplicating the function. From my benchmarking, it's not +clear which version is faster, but our version makes for a smaller bundle size. + +2) The OCaml implementors chose to make the optimization opt-in. While I do understand adding annotations +to get compiler errors when an optimization doesn't kick in, I don't understand the reasoning for not applying it +by default, except for the fact that in their case the optimization increases the output code size by duplicating +the function. + +3) The OCaml implementation chose to only support data construction, whereas we also support arithmetic operators, +string concatenation and list concatenation. + +4) The OCaml implementation chose to only support data constructors. Since we can analyze all the functions and determine +whether they simply store an argument in an object or do something else with it, we can support more functions that are +not necessarily data constructors. In the following example a recursive call in position `x` could be optimized, but +not one in position `y`: `fn x y = { x = x, y = y + 1 }`. + +5) Because OCaml is an effectful language, the order of operations matters and the resulting code reflects that. +Since Elm isn't an effectful language, I took the liberty of re-ordering the computations to make for shorter output code. + + +# Remaining work + +The current work is ready to be tested and benchmarked, but there is more work to be done that would be valuable. + + +## Support tail-recursion that is nested in tail-preserving contexts like `if` expressions + +The following code could be optimized but isn't.
+ + map fn list = + case list of + [] -> [] + x :: xs -> + fn x :: + (if condition then + map fn xs + else + map fn (y :: xs) + ) + + +## Support using both addition and multiplication and other operations, but only choose one (the most common?) + +Currently, our implementation doesn't apply any optimizations if it sees both addition and multiplication recursions. +We could do better by supporting at least one of them, leaving the other as a regular call. + +One issue with this is that we would need to rely on heuristics to choose which operation to optimize for (the one that +we found the most) but that could be the wrong one to optimize for. In the compiler implementation and with new syntax, +annotations could possibly be added by users to choose which operation to optimize. + +In any case, since it reduces the chance for a stack overflow and likely improves performance, I think it would still be +worth it even if it's imperfect. + + +## Special casing list literals + +When concatenating lists, we make a copy of the first list. When encountering list literals which are +wrapped in a `_List_fromArray`, we could instead use a copy of the `_List_fromArray` function that makes the list not +end with `_List_Nil` but with the list to append to, with a function like this: + + function _List_fromArray_andAppend(arr, end) { + var out = end; + for (var i = arr.length; i--; ) + { + out = _List_Cons(arr[i], out); + } + return out; + } + + +## Figure out whether it would be interesting to check whether $tail is empty + +When we find appends to the recursion result, we prepend it to `$tail`, which means creating a copy. +When this is the first append though (`$tail` is empty), then we are doing an unnecessary copy of the list, +which the unoptimized version would not do. + +A potential fix for that would be to check whether `$tail` is empty (or null, whatever is quicker) and if it is +skip the copy and just make `$tail` point to the list.
+The cost for this would be an extra check every time we wish
+to add to the tail, and that might hurt performance.
+
+
+## Support nested data constructors
+
+The OCaml implementation did this but this implementation hasn't yet.
+
+
+## Support let declarations
+
+At the moment, this implementation only targets top-level functions, and not let functions.
+It absolutely should. The only thing to be wary of is the potential naming conflicts for the variables that we introduce
+when optimizing a let function inside of a function that is also being optimized.
+
+
+## Support for recursions in lambda declarations
+
+When encountering code like this:
+
+    naiveMap : (a -> b) -> List a -> List b
+    naiveMap =
+        \fn list ->
+            case list of
+                [] ->
+                    []
+                x :: xs ->
+                    fn x :: naiveMap fn xs
+
+the compiler outputs something pretty complex with multiple declarations and re-assignments.
+I don't think this is high-priority or worth the effort, but I'll mention it anyway.
+
+*/
+
+type Context = ts.TransformationContext;
+
+// Names of functions found in the compiled Elm output.
+const LIST_CONS = "_List_Cons";
+const LIST_FROM_ARRAY = "_List_fromArray";
+const EMPTY_LIST = "_List_Nil";
+const UTILS_AP = "_Utils_ap";
+const COPY_LIST_AND_GET_END = "_Utils_copyListAndGetEnd";
+const LIST_APPEND = "$elm$core$List$append";
+
+// JS source of helper functions, keyed by function name.
+const newFunctionDefinitions: {[key: string]: string} = {
+  [COPY_LIST_AND_GET_END]:
+    `function _Utils_copyListAndGetEnd(root, xs) {
+      for (; xs.b; xs = xs.b) {
+        root = root.b = _List_Cons(xs.a, _List_Nil);
+      }
+      return root;
+    }`,
+};
+
+/**
+ * Creates the TypeScript transformer that rewrites recursive functions declared as
+ * `var x = function(...) { ... }` or `var x = FN(function(...) { ... })`.
+ *
+ * Each such declaration is analyzed with `determineRecursionType`; when an
+ * optimizable recursion type is found, the function body is replaced by the result
+ * of `updateFunctionBody`. Helper functions requested during the rewrite are
+ * collected in a set and handed to `introduceFunctions` together with the visited
+ * source file.
+ *
+ * @param forTests forwarded unchanged to `introduceFunctions`
+ */
+export const createTailCallRecursionTransformer = (forTests: boolean) => (context : Context) => {
+  return (sourceFile : ts.SourceFile) => {
+    const functionsToInsert: Set<string> = new Set();
+
+    const visitor = (node: ts.Node): ts.VisitResult<ts.Node> => {
+      // Is `var x = FX(function(...) { ... })` or `var x = function(...) { ... }`
+      if (ts.isVariableDeclaration(node)
+        && ts.isIdentifier(node.name)
+        && node.initializer
+      ) {
+        const foundFunction = findFunction(node.initializer);
+        if (!foundFunction) {
+          return ts.visitEachChild(node, visitor, context);
+        }
+
+        const functionName = node.name.text;
+        // The loop label is the last `$`-separated segment of the variable name.
+        const labelSplits = functionName.split("$");
+        const label = labelSplits[labelSplits.length - 1] || functionName;
+
+        const functionAnalysis : FunctionAnalysis = determineRecursionType(functionName, label, foundFunction.fn.body);
+        // Leave the declaration untouched when the function is not recursive, or when
+        // the concatenation kind stayed undetermined and there is no plain recursive call.
+        if (functionAnalysis.recursionType.kind === FunctionRecursionKind.F_NotRecursive
+          || (functionAnalysis.recursionType.kind === FunctionRecursionKind.F_ConcatRecursion && !functionAnalysis.recursionType.hasPlainRecursionCalls)
+        ) {
+          return node;
+        }
+
+        // Non-identifier parameters (e.g. destructuring patterns) are represented by ''.
+        const parameterNames : Array<string> = foundFunction.fn.parameters.map(param => {
+          return ts.isIdentifier(param.name) ? param.name.text : '';
+        });
+        const newBody : ts.Block = updateFunctionBody(
+          functionAnalysis.extractCache,
+          functionsToInsert,
+          functionAnalysis.recursionType,
+          functionAnalysis.shouldAddWhileLoop,
+          functionName,
+          label,
+          parameterNames,
+          foundFunction.fn.body,
+          context
+        );
+
+        return ts.updateVariableDeclaration(
+          node,
+          node.name,
+          node.type,
+          foundFunction.update(newBody)
+        );
+      }
+      return ts.visitEachChild(node, visitor, context);
+    };
+
+    return introduceFunctions(functionsToInsert, ts.visitNode(sourceFile, visitor), forTests, context);
+  };
+};
+
+/**
+ * Finds the function expression behind a variable initializer.
+ *
+ * @returns the function expression plus an `update` callback that rebuilds the
+ *   initializer around a new body, or `null` when `node` is neither a function
+ *   expression nor an `FN(function ...)` call.
+ */
+function findFunction(node : ts.Node) {
+  // Multiple-argument function wrapped in FX function
+  if (ts.isCallExpression(node)) {
+    const fn = extractFCall(node);
+    if (!fn) {
+      return null;
+    }
+    const {name, parameters, modifiers} = fn;
+
+    return {
+      fn: fn,
+      update: (body : ts.Block) => {
+        const newFn = ts.createFunctionExpression(
+          modifiers,
+          undefined,
+          name,
+          undefined,
+          parameters,
+          undefined,
+          body
+        );
+
+        // The wrapper call is rebuilt with the new function as its only argument.
+        return ts.updateCall(
+          node,
+          node.expression,
+          node.typeArguments,
+          ts.createNodeArray([newFn])
+        );
+      }
+    }
+  }
+
+  // Single-argument function not wrapped in FX
+  if (ts.isFunctionExpression(node)) {
+    return {
+      fn: node,
+      update: (body : ts.Block) => {
+        return ts.createFunctionExpression(
+          node.modifiers,
+          undefined,
+          node.name,
+          undefined,
+          node.parameters,
+          undefined,
+          body
+        );
+      }
+    }
+  }
+
+  return null;
+}
+
+/**
+ * Returns the function expression passed as first argument of an `F2`..`F9` call,
+ * or `null` for any other call.
+ *
+ * Only the arity wrappers are accepted (rather than any callee starting with `F`)
+ * because `findFunction` rebuilds the call with a single argument, which would
+ * silently drop the extra arguments of an unrelated call.
+ */
+function extractFCall(node: ts.CallExpression): ts.FunctionExpression | null {
+  if (!ts.isIdentifier(node.expression)
+    || !/^F[2-9]$/.test(node.expression.text)
+    || node.arguments.length === 0
+  ) {
+    return null;
+  }
+
+  const fn = node.arguments[0];
+  return ts.isFunctionExpression(fn) ? fn : null;
+}
+
+// Recursion pattern found in a single `return` expression.
+enum RecursionTypeKind {
+  NotRecursive,
+  PlainRecursion,
+  BooleanRecursion,
+  ListOperationsRecursion,
+  DataConstructionRecursion,
+  MultipleDataConstructionRecursion,
+  AddRecursion,
+  StringConcatRecursion,
+  ConcatRecursion,
+  MultiplyRecursion,
+}
+
+enum BooleanKind {
+  And,
+  Or,
+}
+
+// Recursion type of a whole function, combined from the patterns of its `return` expressions.
+enum FunctionRecursionKind {
+  F_NotRecursive,
+  F_PlainRecursion,
+  F_BooleanRecursion,
+  F_ListRecursion,
+  F_DataConstructionRecursion,
+  F_MultipleDataConstructionRecursion,
+  F_AddRecursion,
+  F_StringConcatRecursion,
+  F_ConcatRecursion,
+  F_MultiplyRecursion,
+}
+
+type FunctionRecursion
+  = { kind: FunctionRecursionKind.F_PlainRecursion }
+  | ListRecursion
+  | { kind: FunctionRecursionKind.F_BooleanRecursion }
+  | { kind: FunctionRecursionKind.F_DataConstructionRecursion, property: string }
+  | { kind: FunctionRecursionKind.F_MultipleDataConstructionRecursion }
+  | { kind: FunctionRecursionKind.F_AddRecursion, numbersConfirmed : boolean, left: boolean, right: boolean }
+  | StringConcatRecursion
+  | UndeterminedConcatRecursion
+  | { kind: FunctionRecursionKind.F_MultiplyRecursion }
+
+type StringConcatRecursion =
+  {
+    kind: FunctionRecursionKind.F_StringConcatRecursion,
+    left: boolean,
+    right: boolean
+  }
+
+type UndeterminedConcatRecursion =
+  {
+    kind: FunctionRecursionKind.F_ConcatRecursion,
+    left: boolean,
+    right: boolean,
+    hasPlainRecursionCalls: boolean
+  }
+
+type ListRecursion =
+  {
+    kind: FunctionRecursionKind.F_ListRecursion,
+    left: boolean,
+    right: boolean
+  }
+
+type Recursion
+  = PlainRecursion
+  | ListOperationsRecursion
+  | BooleanRecursion
+  | DataConstructionRecursion
+  | MultipleDataConstructionRecursion
+  | AddRecursion
+  | MultiplyRecursion
+  | ConcatRecursion
+
+type NotRecursiveFunction =
+  {
+    kind: FunctionRecursionKind.F_NotRecursive
+  }
+
+type NotRecursive =
+  {
+    kind: RecursionTypeKind.NotRecursive
+  }
+
+type PlainRecursion =
+  {
+    kind: RecursionTypeKind.PlainRecursion,
+    arguments : Array<ts.Expression>
+  }
+
+type BooleanRecursion =
+  {
+    kind: RecursionTypeKind.BooleanRecursion,
+    expression: ts.Expression,
+    booleanKind: BooleanKind,
+    arguments : Array<ts.Expression>
+  }
+
+type DataConstructionRecursion =
+  {
+    kind: RecursionTypeKind.DataConstructionRecursion,
+    property: string,
+    expression : ts.Expression,
+    arguments : Array<ts.Expression>
+  }
+
+type MultipleDataConstructionRecursion =
+  {
+    kind: RecursionTypeKind.MultipleDataConstructionRecursion,
+    property: string,
+    expression : ts.Expression,
+    arguments : Array<ts.Expression>
+  }
+
+type AddRecursion =
+  {
+    kind: RecursionTypeKind.AddRecursion,
+    left : ts.Expression | null,
+    right : ts.Expression | null,
+    arguments : Array<ts.Expression>,
+    adds: "numbers" | "strings" | null
+  }
+
+type ConcatRecursion =
+  {
+    kind: RecursionTypeKind.ConcatRecursion,
+    left : ts.Expression | null,
+    right : ts.Expression | null,
+    arguments : Array<ts.Expression>,
+    concatenates: "strings" | "lists" | null
+  }
+
+type ListOperationsRecursion =
+  {
+    kind: RecursionTypeKind.ListOperationsRecursion,
+    left : ListOperation[],
+    right : ts.Expression | null,
+    arguments : Array<ts.Expression>
+  }
+
+type ListOperation =
+  {
+    kind : "cons" | "append",
+    expression : ts.Expression
+  }
+
+type MultiplyRecursion =
+  {
+    kind: RecursionTypeKind.MultiplyRecursion,
+    expression : ts.Expression,
+    arguments : Array<ts.Expression>
+  }
+
+type FunctionAnalysis =
+  {
+    recursionType : FunctionRecursion | NotRecursiveFunction,
+    shouldAddWhileLoop : boolean,
+    extractCache : ExtractCache
+  }
+
+/**
+ * Analyzes the items yielded by `findReturnStatements` for a function body and
+ * combines them into a function recursion type.
+ *
+ * The iteration stops as soon as `hasRecursionTypeBeenDetermined` reports that no
+ * further refinement is needed, so the cache only contains the expressions that
+ * were visited until then.
+ *
+ * @param functionName name used to recognize recursive calls
+ * @param label label passed to `findReturnStatements`
+ * @param body body of the function to analyze
+ * @returns the recursion type, whether a `while` loop still has to be added (false
+ *   once `"has-while-loop"` was yielded), and the per-expression analysis cache
+ */
+function determineRecursionType(functionName : string, label : string, body : ts.Node) : FunctionAnalysis {
+  const extractCache : ExtractCache = new Map();
+  let recursionType : FunctionRecursion | NotRecursiveFunction = { kind: FunctionRecursionKind.F_NotRecursive };
+  let shouldAddWhileLoop : boolean = true;
+  let inferredType : PossibleReturnType = null;
+  const iter = findReturnStatements(label, body);
+
+  while (!hasRecursionTypeBeenDetermined(recursionType)) {
+    const next = iter.next();
+    if (next.done) { break; }
+    if (next.value === "has-while-loop") {
+      shouldAddWhileLoop = false;
+      continue;
+    }
+
+    const node : ts.Expression = next.value;
+    const recursionForReturn : Recursion | NotRecursive = extractRecursionKindFromExpression(extractCache, functionName, node);
+    addToCache(extractCache, node, recursionForReturn);
+
+    if (recursionForReturn.kind === RecursionTypeKind.NotRecursive) {
+      // Non-recursive returns can still refine the inferred return type.
+      const refinement = refineTypeForExpression(recursionType, node, inferredType);
+      recursionType = refinement.recursionType;
+      inferredType = refinement.inferredType;
+    }
+    else {
+      recursionType = refineRecursionType(recursionType, inferredType, recursionForReturn);
+    }
+  }
+
+  return { recursionType, shouldAddWhileLoop, extractCache };
+}
+
+/**
+ * Tells whether analyzing more `return` expressions could still change the
+ * function recursion type (`false`) or not (`true`).
+ */
+function hasRecursionTypeBeenDetermined(recursion : FunctionRecursion | NotRecursiveFunction) : boolean {
+  switch (recursion.kind) {
+    case FunctionRecursionKind.F_NotRecursive: return false;
+    case FunctionRecursionKind.F_PlainRecursion: return false;
+    case FunctionRecursionKind.F_BooleanRecursion: return true;
+    case FunctionRecursionKind.F_MultiplyRecursion: return true;
+    case FunctionRecursionKind.F_MultipleDataConstructionRecursion: return true;
+    case FunctionRecursionKind.F_DataConstructionRecursion: return false;
+    case FunctionRecursionKind.F_AddRecursion: return recursion.numbersConfirmed;
+    case FunctionRecursionKind.F_ConcatRecursion: return false;
+    case FunctionRecursionKind.F_ListRecursion:
+    case FunctionRecursionKind.F_StringConcatRecursion:
+      // We need to know for sure on which side there will be concatenation.
+      return recursion.left === true && recursion.right === true;
+  }
+}
+
+function refineRecursionType(
+  recursionType : FunctionRecursion | NotRecursiveFunction,
+  inferredType : PossibleReturnType,
+  recursion : Recursion | NotRecursive
+) : FunctionRecursion | NotRecursiveFunction {
+  switch (recursionType.kind) {
+    case FunctionRecursionKind.F_BooleanRecursion:
+      return recursionType;
+
+    case FunctionRecursionKind.F_MultiplyRecursion:
+      return recursionType;
+
+    case FunctionRecursionKind.F_MultipleDataConstructionRecursion:
+      return recursionType;
+
+    case FunctionRecursionKind.F_NotRecursive:
+      return toFunctionRecursion(recursion, inferredType, false);
+
+    case FunctionRecursionKind.F_PlainRecursion:
+      return toFunctionRecursion(recursion, inferredType, true);
+
+    case FunctionRecursionKind.F_DataConstructionRecursion: {
+      if (recursion.kind === RecursionTypeKind.DataConstructionRecursion && recursion.property !== recursionType.property) {
+        return { kind: FunctionRecursionKind.F_MultipleDataConstructionRecursion };
+      }
+      return recursionType;
+    }
+
+    case FunctionRecursionKind.F_AddRecursion: {
+      switch (recursion.kind) {
+        case RecursionTypeKind.AddRecursion: {
+          if (recursion.adds === "strings" || inferredType === "strings" || inferredType === "strings-or-lists") {
+            return {
+              kind: FunctionRecursionKind.F_StringConcatRecursion,
+              left: recursionType.left || !!recursion.left,
+              right: recursionType.right || !!recursion.right
+            };
+          }
+          return {
+            kind: FunctionRecursionKind.F_AddRecursion,
+            numbersConfirmed: recursion.adds === "numbers" || inferredType === "numbers",
+            left: recursionType.left || !!recursion.left,
+            right:
recursionType.right || !!recursion.right + }; + } + case RecursionTypeKind.ConcatRecursion: { + return { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + default: { + return recursionType; + } + } + }; + + case FunctionRecursionKind.F_ListRecursion: { + switch (recursion.kind) { + case RecursionTypeKind.ConcatRecursion: { + return { + kind: FunctionRecursionKind.F_ListRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + case RecursionTypeKind.ListOperationsRecursion: { + return { + kind: FunctionRecursionKind.F_ListRecursion, + left: recursionType.left || recursion.left.length > 0, + right: recursionType.right || !!recursion.right + }; + } + default: { + return recursionType; + } + } + } + + case FunctionRecursionKind.F_StringConcatRecursion: { + switch (recursion.kind) { + case RecursionTypeKind.AddRecursion: { + return { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + case RecursionTypeKind.ConcatRecursion: { + return { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + default: { + return recursionType; + } + } + } + + case FunctionRecursionKind.F_ConcatRecursion: { + switch (recursion.kind) { + case RecursionTypeKind.ConcatRecursion: { + if (recursion.concatenates === "strings") { + return { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + if (recursion.concatenates === "lists") { + return { + kind: FunctionRecursionKind.F_ListRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right + }; + } + 
return { + kind: FunctionRecursionKind.F_ConcatRecursion, + left: recursionType.left || !!recursion.left, + right: recursionType.right || !!recursion.right, + hasPlainRecursionCalls: recursionType.hasPlainRecursionCalls + }; + } + case RecursionTypeKind.ListOperationsRecursion: { + return { + kind: FunctionRecursionKind.F_ListRecursion, + left: recursionType.left || recursion.left.length > 0, + right: recursionType.right || !!recursion.right + }; + } + case RecursionTypeKind.PlainRecursion: { + return { + kind: FunctionRecursionKind.F_ConcatRecursion, + left: recursionType.left, + right: recursionType.right, + hasPlainRecursionCalls: true + }; + } + default: { + return recursionType; + } + } + } + } +} + +function refineTypeForExpression( + recursionType : FunctionRecursion | NotRecursiveFunction, + node : ts.Expression, + inferredType : PossibleReturnType +) : { recursionType : FunctionRecursion | NotRecursiveFunction, inferredType : PossibleReturnType } { + switch (recursionType.kind) { + case FunctionRecursionKind.F_BooleanRecursion: + case FunctionRecursionKind.F_MultiplyRecursion: + case FunctionRecursionKind.F_MultipleDataConstructionRecursion: + case FunctionRecursionKind.F_ListRecursion: + case FunctionRecursionKind.F_DataConstructionRecursion: + case FunctionRecursionKind.F_StringConcatRecursion: { + return {recursionType, inferredType }; + } + + case FunctionRecursionKind.F_NotRecursive: + case FunctionRecursionKind.F_PlainRecursion: { + return { + recursionType, + inferredType: determineType(node, inferredType) + }; + } + + case FunctionRecursionKind.F_AddRecursion: { + inferredType = determineType(node, inferredType); + if (inferredType === "strings" || inferredType === "strings-or-lists") { + recursionType = { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left, + right: recursionType.right + }; + } + else if (inferredType === "numbers") { + recursionType = { + kind: FunctionRecursionKind.F_AddRecursion, + 
numbersConfirmed: true, + left: recursionType.left, + right: recursionType.right + }; + } + return { recursionType, inferredType }; + } + + case FunctionRecursionKind.F_ConcatRecursion: { + inferredType = determineType(node, inferredType); + if (inferredType === "strings" || inferredType === "numbers-or-strings") { + recursionType = { + kind: FunctionRecursionKind.F_StringConcatRecursion, + left: recursionType.left, + right: recursionType.right + }; + } + else if (inferredType === "lists") { + recursionType = { + kind: FunctionRecursionKind.F_ListRecursion, + left: recursionType.left, + right: recursionType.right + }; + } + return { recursionType, inferredType }; + } + } +} + +function* findReturnStatements(label : string, body : ts.Node) : Generator { + let nodesToVisit : Array = [body]; + let node : ts.Node | undefined; + + loop: while (node = nodesToVisit.shift()) { + if (ts.isParenthesizedExpression(node)) { + nodesToVisit = [node.expression, ...nodesToVisit]; + continue loop; + } + + if (ts.isBlock(node)) { + nodesToVisit = [...node.statements, ...nodesToVisit]; + continue loop; + } + + if (ts.isLabeledStatement(node)) { + nodesToVisit.unshift(node.statement); + if (node.label.text === label) { + yield "has-while-loop"; + } + continue loop; + } + + if (ts.isWhileStatement(node)) { + nodesToVisit.unshift(node.statement); + continue loop; + } + + if (ts.isIfStatement(node)) { + if (node.elseStatement) { + nodesToVisit.unshift(node.elseStatement); + } + nodesToVisit.unshift(node.thenStatement); + continue loop; + } + + if (ts.isSwitchStatement(node)) { + nodesToVisit = [ + ...node.caseBlock.clauses.flatMap(clause => [...clause.statements]), + ...nodesToVisit + ]; + continue loop; + } + + if (ts.isReturnStatement(node) && node.expression) { + if (ts.isConditionalExpression(node.expression)) { + nodesToVisit = [ + ts.createReturn(node.expression.whenTrue), + ts.createReturn(node.expression.whenFalse), + ...nodesToVisit + ]; + continue loop; + } + + if 
(ts.isParenthesizedExpression(node.expression)) { + nodesToVisit.unshift(ts.createReturn(node.expression.expression)); + continue loop; + } + + yield node.expression; + continue loop; + } + } +} + +const START = ts.createIdentifier("$start"); +const END = ts.createIdentifier("$end"); +const TAIL = ts.createIdentifier("$tail"); +const FIELD = ts.createIdentifier("$field"); +const RESULT = ts.createIdentifier("$result"); +const LEFT = ts.createIdentifier("$left"); +const RIGHT = ts.createIdentifier("$right"); + +function consDeclarations(left : boolean, right : boolean) { + return [ + ...(left ? consLeftDeclarations : []), + ...(right ? [consrightDeclaration] : []) + ]; +} + +const consLeftDeclarations = + [ + // `var $start = _List_Cons(undefined, _List_Nil);` + ts.createVariableStatement( + undefined, + [ts.createVariableDeclaration( + START, + undefined, + consToList( + ts.createIdentifier("undefined"), + ts.createIdentifier(EMPTY_LIST) + ) + )] + ), + // `var $end = $start;` + ts.createVariableStatement( + undefined, + [ ts.createVariableDeclaration( + END, + undefined, + START + ) + ] + ) + ]; + +// `var $tail = _List_Nil;` +const consrightDeclaration = + ts.createVariableStatement( + undefined, + [ts.createVariableDeclaration( + TAIL, + undefined, + ts.createIdentifier(EMPTY_LIST) + )] + ); + +const multipleConstructorDeclarations = +[ + // `var $start = { a: null };` + ts.createVariableStatement( + undefined, + [ts.createVariableDeclaration( + START, + undefined, + ts.createObjectLiteral([ + ts.createPropertyAssignment("a", ts.createNull()) + ]) + )] + ), + // `var $end = $start;` + ts.createVariableStatement( + undefined, + [ ts.createVariableDeclaration( + END, + undefined, + START + ) + ] + ), + // `var $field = 'a';` + ts.createVariableStatement( + undefined, + [ ts.createVariableDeclaration( + FIELD, + undefined, + ts.createLiteral('a') + ) + ] + ) +]; + +// `var $result = ;` +function resultDeclaration(n : number) { + return ts.createVariableStatement( + 
undefined, + [ts.createVariableDeclaration( + RESULT, + undefined, + ts.createLiteral(n) + )] + ); +} + +function stringConsDeclaration(left : boolean, right: boolean) { + let declarations = []; + if (left) { + declarations.push( + //`$left = ""` + ts.createVariableDeclaration( + LEFT, + undefined, + ts.createStringLiteral("") + ) + ); + } + + if (right) { + declarations.push( + //`$right = ""` + ts.createVariableDeclaration( + RIGHT, + undefined, + ts.createStringLiteral("") + ) + ); + } + + return ts.createVariableStatement(undefined, declarations); +} + +function constructorDeclarations(property : string) { + return [ + // `var $start = { : null };` + ts.createVariableStatement( + undefined, + [ts.createVariableDeclaration( + START, + undefined, + ts.createObjectLiteral([ + ts.createPropertyAssignment(property, ts.createNull()) + ]) + )] + ), + // `var $end = $start;` + ts.createVariableStatement( + undefined, + [ ts.createVariableDeclaration( + END, + undefined, + START + ) + ] + ) + ]; +} + +function updateFunctionBody(extractCache : ExtractCache, functionsToInsert : Set, recursionType : FunctionRecursion, shouldAddWhileLoop : boolean, functionName : string, label : string, parameterNames : Array, body : ts.Block, context : Context) : ts.Block { + const updatedBlock = ts.visitEachChild(body, updateRecursiveCallVisitor, context); + + function updateRecursiveCallVisitor(node: ts.Node): ts.VisitResult { + if (ts.isBlock(node) + || ts.isLabeledStatement(node) + || ts.isWhileStatement(node) + || ts.isSwitchStatement(node) + || ts.isCaseClause(node) + || ts.isCaseBlock(node) + || ts.isDefaultClause(node) + ) { + return ts.visitEachChild(node, updateRecursiveCallVisitor, context); + } + + if (ts.isIfStatement(node)) { + return ts.updateIf( + node, + node.expression, + ts.visitNode(node.thenStatement, updateRecursiveCallVisitor), + ts.visitNode(node.elseStatement, updateRecursiveCallVisitor) + ) + } + + if (ts.isReturnStatement(node) + && node.expression + ) { + 
return updateReturnStatement(extractCache, functionsToInsert, recursionType, functionName, label, parameterNames, node.expression) || node; + } + + return node; + } + + const declarations = declarationsForRecursiveFunction(recursionType); + + if (shouldAddWhileLoop) { + return ts.createBlock( + [ + ...declarations, + // `