Skip to content

Improve inferring generic parameters #839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,37 @@ export abstract class TypeNode extends Node {

/** Whether nullable or not. */
isNullable: bool;

/** Tests if this type has a generic component matching one of the given type parameters. */
hasGenericComponent(typeParameterNodes: TypeParameterNode[]): bool {
var self = <TypeNode>this; // TS otherwise complains
if (this.kind == NodeKind.NAMEDTYPE) {
if (!(<NamedTypeNode>self).name.next) {
let typeArgumentNodes = (<NamedTypeNode>self).typeArguments;
if (typeArgumentNodes !== null && typeArgumentNodes.length) {
for (let i = 0, k = typeArgumentNodes.length; i < k; ++i) {
if (typeArgumentNodes[i].hasGenericComponent(typeParameterNodes)) return true;
}
} else {
let name = (<NamedTypeNode>self).name.identifier.text;
for (let i = 0, k = typeParameterNodes.length; i < k; ++i) {
if (typeParameterNodes[i].name.text == name) return true;
}
}
}
} else if (this.kind == NodeKind.FUNCTIONTYPE) {
let parameterNodes = (<FunctionTypeNode>self).parameters;
for (let i = 0, k = parameterNodes.length; i < k; ++i) {
if (parameterNodes[i].type.hasGenericComponent(typeParameterNodes)) return true;
}
if ((<FunctionTypeNode>self).returnType.hasGenericComponent(typeParameterNodes)) return true;
let explicitThisType = (<FunctionTypeNode>self).explicitThisType;
if (explicitThisType !== null && explicitThisType.hasGenericComponent(typeParameterNodes)) return true;
} else {
assert(false);
}
return false;
}
}

/** Represents a type name. */
Expand Down
70 changes: 20 additions & 50 deletions src/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5843,70 +5843,46 @@ export class Compiler extends DiagnosticEmitter {

// infer generic call if type arguments have been omitted
} else if (prototype.is(CommonFlags.GENERIC)) {
let inferredTypes = new Map<string,Type | null>();
let contextualTypeArguments = makeMap<string,Type>(flow.contextualTypeArguments);

// fill up contextual types with auto for each generic component
let typeParameterNodes = assert(prototype.typeParameterNodes);
let numTypeParameters = typeParameterNodes.length;
let typeParameterNames = new Set<string>();
for (let i = 0; i < numTypeParameters; ++i) {
inferredTypes.set(typeParameterNodes[i].name.text, null);
let name = typeParameterNodes[i].name.text;
contextualTypeArguments.set(name, Type.auto);
typeParameterNames.add(name);
}
// let numInferred = 0;

let parameterNodes = prototype.functionTypeNode.parameters;
let numParameters = parameterNodes.length;
let argumentNodes = expression.arguments;
let numArguments = argumentNodes.length;
let argumentExprs = new Array<ExpressionRef>(numArguments);

// infer types with generic components while updating contextual types
for (let i = 0; i < numParameters; ++i) {
let typeNode = parameterNodes[i].type;
let templateName = typeNode.kind == NodeKind.NAMEDTYPE && !(<NamedTypeNode>typeNode).name.next
? (<NamedTypeNode>typeNode).name.identifier.text
: null;
let argumentExpression = i < numArguments
? argumentNodes[i]
: parameterNodes[i].initializer;
let argumentExpression = i < numArguments ? argumentNodes[i] : parameterNodes[i].initializer;
if (!argumentExpression) { // missing initializer -> too few arguments
this.error(
DiagnosticCode.Expected_0_arguments_but_got_1,
expression.range, numParameters.toString(10), numArguments.toString(10)
);
return module.unreachable();
}
if (templateName !== null && inferredTypes.has(templateName)) {
let inferredType = inferredTypes.get(templateName);
if (inferredType) {
argumentExprs[i] = this.compileExpression(argumentExpression, inferredType);
let commonType: Type | null;
if (!(commonType = Type.commonDenominator(inferredType, this.currentType, true))) {
if (!(commonType = Type.commonDenominator(inferredType, this.currentType, false))) {
this.error(
DiagnosticCode.Type_0_is_not_assignable_to_type_1,
parameterNodes[i].type.range, this.currentType.toString(), inferredType.toString()
);
return module.unreachable();
}
}
inferredType = commonType;
} else {
argumentExprs[i] = this.compileExpression(argumentExpression, Type.auto);
inferredType = this.currentType;
// ++numInferred;
}
inferredTypes.set(templateName, inferredType);
} else {
let concreteType = this.resolver.resolveType(
parameterNodes[i].type,
flow.actualFunction,
flow.contextualTypeArguments
);
if (!concreteType) return module.unreachable();
argumentExprs[i] = this.compileExpression(argumentExpression, concreteType, Constraints.CONV_IMPLICIT);
let typeNode = parameterNodes[i].type;
if (typeNode.hasGenericComponent(typeParameterNodes)) {
this.resolver.inferGenericType(typeNode, argumentExpression, flow, contextualTypeArguments, typeParameterNames);
}
}
let resolvedTypeArguments = new Array<Type>(numTypeParameters);

// apply concrete types to the generic function signature
let resolvedTypeArguments = Array.create<Type>(numTypeParameters);
for (let i = 0; i < numTypeParameters; ++i) {
let name = typeParameterNodes[i].name.text;
if (inferredTypes.has(name)) {
let inferredType = inferredTypes.get(name);
if (inferredType) {
if (contextualTypeArguments.has(name)) {
let inferredType = contextualTypeArguments.get(name)!;
if (inferredType != Type.auto) {
resolvedTypeArguments[i] = inferredType;
continue;
}
Expand All @@ -5924,12 +5900,6 @@ export class Compiler extends DiagnosticEmitter {
resolvedTypeArguments,
makeMap<string,Type>(flow.contextualTypeArguments)
);
if (!instance) return this.module.unreachable();
if (prototype.hasDecorator(DecoratorFlags.UNSAFE)) this.checkUnsafe(expression);
return this.makeCallDirect(instance, argumentExprs, expression, contextualType == Type.void);
// TODO: this skips inlining because inlining requires compiling its temporary locals in
// the scope of the inlined flow. might need another mechanism to lock temp. locals early,
// so inlining can be performed in `makeCallDirect` instead?

// otherwise resolve the non-generic call as usual
} else {
Expand Down
93 changes: 91 additions & 2 deletions src/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ import {
TernaryExpression,
isTypeOmitted,
FunctionExpression,
NewExpression
NewExpression,
ParameterNode
} from "./ast";

import {
Expand Down Expand Up @@ -691,6 +692,82 @@ export class Resolver extends DiagnosticEmitter {
return typeArguments;
}

/** Infers the generic type(s) of an argument expression and updates `ctxTypes`. */
inferGenericType(
/** The generic type being inferred. */
typeNode: TypeNode,
/** The respective argument expression. */
exprNode: Expression,
/** Contextual flow. */
ctxFlow: Flow,
/** Contextual types, i.e. `T`, with unknown types initialized to `auto`. */
ctxTypes: Map<string,Type>,
/** The names of the type parameters being inferred. */
typeParameterNames: Set<string>
): void {
var type = this.resolveExpression(exprNode, ctxFlow, Type.auto, ReportMode.SWALLOW);
if (type) this.propagateInferredGenericTypes(typeNode, type, ctxFlow, ctxTypes, typeParameterNames);
}

/** Updates contextual types with a possibly encapsulated inferred type. */
private propagateInferredGenericTypes(
/** The inferred type node. */
node: TypeNode,
/** The inferred type. */
type: Type,
/** Contextual flow. */
ctxFlow: Flow,
/** Contextual types, i.e. `T`, with unknown types initialized to `auto`. */
ctxTypes: Map<string,Type>,
/** The names of the type parameters being inferred. */
typeParameterNames: Set<string>
): void {
if (node.kind == NodeKind.NAMEDTYPE) {
let typeArgumentNodes = (<NamedTypeNode>node).typeArguments;
if (typeArgumentNodes !== null && typeArgumentNodes.length) { // foo<T>(bar: Array<T>)
let classReference = type.classReference;
if (classReference) {
let classPrototype = this.resolveTypeName((<NamedTypeNode>node).name, ctxFlow.actualFunction);
if (!classPrototype || classPrototype.kind != ElementKind.CLASS_PROTOTYPE) return;
if (classReference.prototype == <ClassPrototype>classPrototype) {
let typeArguments = classReference.typeArguments;
if (typeArguments !== null && typeArguments.length == typeArgumentNodes.length) {
for (let i = 0, k = typeArguments.length; i < k; ++i) {
this.propagateInferredGenericTypes(typeArgumentNodes[i], typeArguments[i], ctxFlow, ctxTypes, typeParameterNames);
}
return;
}
}
}
} else { // foo<T>(bar: T)
let name = (<NamedTypeNode>node).name.identifier.text;
if (ctxTypes.has(name)) {
let currentType = ctxTypes.get(name)!;
if (currentType == Type.auto || (typeParameterNames.has(name) && currentType.isAssignableTo(type))) {
ctxTypes.set(name, type);
}
}
}
} else if (node.kind == NodeKind.FUNCTIONTYPE) { // foo<T>(bar: (baz: T) => i32))
let parameterNodes = (<FunctionTypeNode>node).parameters;
if (parameterNodes !== null && parameterNodes.length) {
let signatureReference = type.signatureReference;
if (signatureReference) {
let parameterTypes = signatureReference.parameterTypes;
let thisType = signatureReference.thisType;
if (parameterTypes.length == parameterNodes.length && !thisType == !(<FunctionTypeNode>node).explicitThisType) {
for (let i = 0, k = parameterTypes.length; i < k; ++i) {
this.propagateInferredGenericTypes(parameterNodes[i].type, parameterTypes[i], ctxFlow, ctxTypes, typeParameterNames);
}
this.propagateInferredGenericTypes((<FunctionTypeNode>node).returnType, signatureReference.returnType, ctxFlow, ctxTypes, typeParameterNames);
if (thisType) this.propagateInferredGenericTypes((<FunctionTypeNode>node).explicitThisType!, thisType, ctxFlow, ctxTypes, typeParameterNames);
return;
}
}
}
}
}

/** Gets the concrete type of an element. */
getTypeOfElement(element: Element): Type | null {
var kind = element.kind;
Expand Down Expand Up @@ -908,7 +985,7 @@ export class Resolver extends DiagnosticEmitter {
case NodeKind.TRUE: {
return this.resolveIdentifierExpression(
<IdentifierExpression>node,
ctxFlow, ctxFlow.actualFunction, reportMode
ctxFlow, ctxType, ctxFlow.actualFunction, reportMode
);
}
case NodeKind.THIS: {
Expand Down Expand Up @@ -1018,11 +1095,23 @@ export class Resolver extends DiagnosticEmitter {
node: IdentifierExpression,
/** Flow to search for scoped locals. */
ctxFlow: Flow,
/** Contextual type. */
ctxType: Type = Type.auto,
/** Element to search. */
ctxElement: Element = ctxFlow.actualFunction, // differs for enums and namespaces
/** How to proceed with eventual diagnostics. */
reportMode: ReportMode = ReportMode.REPORT
): Type | null {
switch (node.kind) {
case NodeKind.TRUE:
case NodeKind.FALSE: return Type.bool;
case NodeKind.NULL: {
let classReference = ctxType.classReference;
return ctxType.is(TypeFlags.REFERENCE) && classReference !== null
? classReference.type.asNullable()
: this.program.options.usizeType; // TODO: anyref context?
}
}
var element = this.lookupIdentifierExpression(node, ctxFlow, ctxElement, reportMode);
if (!element) return null;
if (element.kind == ElementKind.FUNCTION_PROTOTYPE) {
Expand Down
36 changes: 35 additions & 1 deletion tests/compiler/call-inferred.optimized.wat
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
(module
(type $FUNCSIG$viiii (func (param i32 i32 i32 i32)))
(type $FUNCSIG$v (func))
(import "env" "abort" (func $~lib/builtins/abort (param i32 i32 i32 i32)))
(memory $0 1)
(data (i32.const 8) " \00\00\00\01\00\00\00\01\00\00\00 \00\00\00c\00a\00l\00l\00-\00i\00n\00f\00e\00r\00r\00e\00d\00.\00t\00s")
(global $~lib/argc (mut i32) (i32.const 0))
(export "memory" (memory $0))
(func $start (; 0 ;) (type $FUNCSIG$v)
(start $start)
(func $start:call-inferred (; 1 ;) (type $FUNCSIG$v)
(local $0 f32)
i32.const 0
global.set $~lib/argc
block $1of1
block $0of1
block $outOfRange
global.get $~lib/argc
br_table $0of1 $1of1 $outOfRange
end
unreachable
end
f32.const 42
local.set $0
end
local.get $0
f32.const 42
f32.ne
if
i32.const 0
i32.const 24
i32.const 13
i32.const 0
call $~lib/builtins/abort
unreachable
end
)
(func $start (; 2 ;) (type $FUNCSIG$v)
call $start:call-inferred
)
(func $null (; 3 ;) (type $FUNCSIG$v)
nop
)
)
28 changes: 23 additions & 5 deletions tests/compiler/call-inferred.untouched.wat
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
(data (i32.const 8) " \00\00\00\01\00\00\00\01\00\00\00 \00\00\00c\00a\00l\00l\00-\00i\00n\00f\00e\00r\00r\00e\00d\00.\00t\00s\00")
(table $0 1 funcref)
(elem (i32.const 0) $null)
(global $~lib/argc (mut i32) (i32.const 0))
(export "memory" (memory $0))
(start $start)
(func $call-inferred/foo<i32> (; 1 ;) (type $FUNCSIG$ii) (param $0 i32) (result i32)
Expand All @@ -23,7 +24,22 @@
(func $call-inferred/bar<f32> (; 4 ;) (type $FUNCSIG$ff) (param $0 f32) (result f32)
local.get $0
)
(func $start:call-inferred (; 5 ;) (type $FUNCSIG$v)
(func $call-inferred/bar<f32>|trampoline (; 5 ;) (type $FUNCSIG$ff) (param $0 f32) (result f32)
block $1of1
block $0of1
block $outOfRange
global.get $~lib/argc
br_table $0of1 $1of1 $outOfRange
end
unreachable
end
f32.const 42
local.set $0
end
local.get $0
call $call-inferred/bar<f32>
)
(func $start:call-inferred (; 6 ;) (type $FUNCSIG$v)
i32.const 42
call $call-inferred/foo<i32>
i32.const 42
Expand Down Expand Up @@ -63,8 +79,10 @@
call $~lib/builtins/abort
unreachable
end
f32.const 42
call $call-inferred/bar<f32>
i32.const 0
global.set $~lib/argc
f32.const 0
call $call-inferred/bar<f32>|trampoline
f32.const 42
f32.eq
i32.eqz
Expand All @@ -77,9 +95,9 @@
unreachable
end
)
(func $start (; 6 ;) (type $FUNCSIG$v)
(func $start (; 7 ;) (type $FUNCSIG$v)
call $start:call-inferred
)
(func $null (; 7 ;) (type $FUNCSIG$v)
(func $null (; 8 ;) (type $FUNCSIG$v)
)
)
5 changes: 5 additions & 0 deletions tests/compiler/infer-generic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"asc_flags": [
"--runtime none"
]
}
Loading