Skip to content

Do precise subtype tests in instanceof #2588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Dec 14, 2022
42 changes: 21 additions & 21 deletions src/bindings/js.ts
Original file line number Diff line number Diff line change
Expand Up @@ -950,29 +950,29 @@ export class JSBuilder extends ExportsWalker {
if (type.isInternalReference) {
// Lift reference types
const clazz = assert(type.getClassOrWrapper(this.program));
if (clazz.extends(this.program.arrayBufferInstance.prototype)) {
if (clazz.extendsPrototype(this.program.arrayBufferInstance.prototype)) {
sb.push("__liftBuffer(");
this.needsLiftBuffer = true;
} else if (clazz.extends(this.program.stringInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.stringInstance.prototype)) {
sb.push("__liftString(");
this.needsLiftString = true;
} else if (clazz.extends(this.program.arrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.arrayPrototype)) {
let valueType = clazz.getArrayValueType();
sb.push("__liftArray(");
this.makeLiftFromMemory(valueType, sb);
sb.push(", ");
sb.push(valueType.alignLog2.toString());
sb.push(", ");
this.needsLiftArray = true;
} else if (clazz.extends(this.program.staticArrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.staticArrayPrototype)) {
let valueType = clazz.getArrayValueType();
sb.push("__liftStaticArray(");
this.makeLiftFromMemory(valueType, sb);
sb.push(", ");
sb.push(valueType.alignLog2.toString());
sb.push(", ");
this.needsLiftStaticArray = true;
} else if (clazz.extends(this.program.arrayBufferViewInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.arrayBufferViewInstance.prototype)) {
sb.push("__liftTypedArray(");
if (clazz.name == "Uint64Array") {
sb.push("BigUint64Array");
Expand Down Expand Up @@ -1021,13 +1021,13 @@ export class JSBuilder extends ExportsWalker {
if (type.isInternalReference) {
// Lower reference types
const clazz = assert(type.getClassOrWrapper(this.program));
if (clazz.extends(this.program.arrayBufferInstance.prototype)) {
if (clazz.extendsPrototype(this.program.arrayBufferInstance.prototype)) {
sb.push("__lowerBuffer(");
this.needsLowerBuffer = true;
} else if (clazz.extends(this.program.stringInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.stringInstance.prototype)) {
sb.push("__lowerString(");
this.needsLowerString = true;
} else if (clazz.extends(this.program.arrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.arrayPrototype)) {
let valueType = clazz.getArrayValueType();
sb.push("__lowerArray(");
this.makeLowerToMemory(valueType, sb);
Expand All @@ -1037,7 +1037,7 @@ export class JSBuilder extends ExportsWalker {
sb.push(clazz.getArrayValueType().alignLog2.toString());
sb.push(", ");
this.needsLowerArray = true;
} else if (clazz.extends(this.program.staticArrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.staticArrayPrototype)) {
let valueType = clazz.getArrayValueType();
sb.push("__lowerStaticArray(");
this.makeLowerToMemory(valueType, sb);
Expand All @@ -1047,7 +1047,7 @@ export class JSBuilder extends ExportsWalker {
sb.push(valueType.alignLog2.toString());
sb.push(", ");
this.needsLowerStaticArray = true;
} else if (clazz.extends(this.program.arrayBufferViewInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.arrayBufferViewInstance.prototype)) {
let valueType = clazz.getArrayValueType();
sb.push("__lowerTypedArray(");
if (valueType == Type.u64) {
Expand Down Expand Up @@ -1079,7 +1079,7 @@ export class JSBuilder extends ExportsWalker {
this.needsLowerInternref = true;
}
sb.push(name);
if (clazz.extends(this.program.staticArrayPrototype)) {
if (clazz.extendsPrototype(this.program.staticArrayPrototype)) {
// optional last argument for __lowerStaticArray
let valueType = clazz.getArrayValueType();
if (valueType.isNumericValue) {
Expand Down Expand Up @@ -1397,16 +1397,16 @@ export function liftRequiresExportRuntime(type: Type): bool {
let program = clazz.program;
// flat collections lift via memory copy
if (
clazz.extends(program.arrayBufferInstance.prototype) ||
clazz.extends(program.stringInstance.prototype) ||
clazz.extends(program.arrayBufferViewInstance.prototype)
clazz.extendsPrototype(program.arrayBufferInstance.prototype) ||
clazz.extendsPrototype(program.stringInstance.prototype) ||
clazz.extendsPrototype(program.arrayBufferViewInstance.prototype)
) {
return false;
}
// nested collections lift depending on element type
if (
clazz.extends(program.arrayPrototype) ||
clazz.extends(program.staticArrayPrototype)
clazz.extendsPrototype(program.arrayPrototype) ||
clazz.extendsPrototype(program.staticArrayPrototype)
) {
return liftRequiresExportRuntime(clazz.getArrayValueType());
}
Expand All @@ -1428,11 +1428,11 @@ export function lowerRequiresExportRuntime(type: Type): bool {
// lowers using __new
let program = clazz.program;
if (
clazz.extends(program.arrayBufferInstance.prototype) ||
clazz.extends(program.stringInstance.prototype) ||
clazz.extends(program.arrayBufferViewInstance.prototype) ||
clazz.extends(program.arrayPrototype) ||
clazz.extends(program.staticArrayPrototype)
clazz.extendsPrototype(program.arrayBufferInstance.prototype) ||
clazz.extendsPrototype(program.stringInstance.prototype) ||
clazz.extendsPrototype(program.arrayBufferViewInstance.prototype) ||
clazz.extendsPrototype(program.arrayPrototype) ||
clazz.extendsPrototype(program.staticArrayPrototype)
) {
return true;
}
Expand Down
12 changes: 6 additions & 6 deletions src/bindings/tsd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,26 +249,26 @@ export class TSDBuilder extends ExportsWalker {
if (type.isInternalReference) {
const sb = new Array<string>();
const clazz = assert(type.getClassOrWrapper(this.program));
if (clazz.extends(this.program.arrayBufferInstance.prototype)) {
if (clazz.extendsPrototype(this.program.arrayBufferInstance.prototype)) {
sb.push("ArrayBuffer");
} else if (clazz.extends(this.program.stringInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.stringInstance.prototype)) {
sb.push("string");
} else if (clazz.extends(this.program.arrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.arrayPrototype)) {
const valueType = clazz.getArrayValueType();
sb.push("Array<");
sb.push(this.toTypeScriptType(valueType, mode));
sb.push(">");
} else if (clazz.extends(this.program.staticArrayPrototype)) {
} else if (clazz.extendsPrototype(this.program.staticArrayPrototype)) {
const valueType = clazz.getArrayValueType();
sb.push("ArrayLike<");
sb.push(this.toTypeScriptType(valueType, mode));
sb.push(">");
} else if (clazz.extends(this.program.arrayBufferViewInstance.prototype)) {
} else if (clazz.extendsPrototype(this.program.arrayBufferViewInstance.prototype)) {
const valueType = clazz.getArrayValueType();
if (valueType == Type.i8) {
sb.push("Int8Array");
} else if (valueType == Type.u8) {
if (clazz.extends(this.program.uint8ClampedArrayPrototype)) {
if (clazz.extendsPrototype(this.program.uint8ClampedArrayPrototype)) {
sb.push("Uint8ClampedArray");
} else {
sb.push("Uint8Array");
Expand Down
12 changes: 6 additions & 6 deletions src/builtins.ts
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ function builtin_isArray(ctx: BuiltinContext): ExpressionRef {
let classReference = type.getClass();
return reifyConstantType(ctx,
module.i32(
classReference && classReference.extends(compiler.program.arrayPrototype)
classReference && classReference.extendsPrototype(compiler.program.arrayPrototype)
? 1
: 0
)
Expand Down Expand Up @@ -10420,26 +10420,26 @@ export function compileRTTI(compiler: Compiler): void {
assert(instanceId == lastId++);
let flags: TypeinfoFlags = 0;
if (instance.isPointerfree) flags |= TypeinfoFlags.POINTERFREE;
if (instance != abvInstance && instance.extends(abvPrototype)) {
if (instance != abvInstance && instance.extendsPrototype(abvPrototype)) {
let valueType = instance.getArrayValueType();
flags |= TypeinfoFlags.ARRAYBUFFERVIEW;
flags |= TypeinfoFlags.VALUE_ALIGN_0 * typeToRuntimeFlags(valueType);
} else if (instance.extends(arrayPrototype)) {
} else if (instance.extendsPrototype(arrayPrototype)) {
let valueType = instance.getArrayValueType();
flags |= TypeinfoFlags.ARRAY;
flags |= TypeinfoFlags.VALUE_ALIGN_0 * typeToRuntimeFlags(valueType);
} else if (instance.extends(setPrototype)) {
} else if (instance.extendsPrototype(setPrototype)) {
let typeArguments = assert(instance.getTypeArgumentsTo(setPrototype));
assert(typeArguments.length == 1);
flags |= TypeinfoFlags.SET;
flags |= TypeinfoFlags.VALUE_ALIGN_0 * typeToRuntimeFlags(typeArguments[0]);
} else if (instance.extends(mapPrototype)) {
} else if (instance.extendsPrototype(mapPrototype)) {
let typeArguments = assert(instance.getTypeArgumentsTo(mapPrototype));
assert(typeArguments.length == 2);
flags |= TypeinfoFlags.MAP;
flags |= TypeinfoFlags.KEY_ALIGN_0 * typeToRuntimeFlags(typeArguments[0]);
flags |= TypeinfoFlags.VALUE_ALIGN_0 * typeToRuntimeFlags(typeArguments[1]);
} else if (instance.extends(staticArrayPrototype)) {
} else if (instance.extendsPrototype(staticArrayPrototype)) {
let valueType = instance.getArrayValueType();
flags |= TypeinfoFlags.STATICARRAY;
flags |= TypeinfoFlags.VALUE_ALIGN_0 * typeToRuntimeFlags(valueType);
Expand Down
114 changes: 78 additions & 36 deletions src/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,18 @@ export class Compiler extends DiagnosticEmitter {
for (let _keys = Map_keys(this.pendingInstanceOf), i = 0, k = _keys.length; i < k; ++i) {
let elem = _keys[i];
let name = assert(this.pendingInstanceOf.get(elem));
if (elem.kind == ElementKind.Class) {
this.finalizeInstanceOf(<Class>elem, name);
} else if (elem.kind == ElementKind.ClassPrototype) {
this.finalizeAnyInstanceOf(<ClassPrototype>elem, name);
} else {
assert(false);
switch (elem.kind) {
case ElementKind.Class:
case ElementKind.Interface: {
this.finalizeInstanceOf(<Class>elem, name);
break;
}
case ElementKind.ClassPrototype:
case ElementKind.InterfacePrototype: {
this.finalizeAnyInstanceOf(<ClassPrototype>elem, name);
break;
}
default: assert(false);
}
}

Expand Down Expand Up @@ -6569,19 +6575,25 @@ export class Compiler extends DiagnosticEmitter {
}
let classInstance = assert(overrideInstance.getBoundClassOrInterface());
builder.addCase(classInstance.id, stmts);
// Also alias each extendee inheriting this exact overload
let extendees = classInstance.getAllExtendees(instance.declaration.name.text); // without get:/set:
for (let _values = Set_values(extendees), a = 0, b = _values.length; a < b; ++a) {
let extendee = _values[a];
builder.addCase(extendee.id, stmts);
// Also alias each extender inheriting this exact overload
let extenders = classInstance.extenders;
if (extenders) {
for (let _values = Set_values(extenders), i = 0, k = _values.length; i < k; ++i) {
let extender = _values[i];
let instanceMembers = extender.prototype.instanceMembers;
if (instanceMembers && instanceMembers.has(instance.declaration.name.text)) {
continue; // skip those not inheriting
}
builder.addCase(extender.id, stmts);
}
}
}
}

// Call the original function if no other id matches and the method is not
// abstract or part of an interface. Note that doing so will not catch an
// invalid id, but can reduce code size significantly since we also don't
// have to add branches for extendees inheriting the original function.
// have to add branches for extenders inheriting the original function.
let body: ExpressionRef;
let instanceClass = instance.getBoundClassOrInterface();
if (!instance.is(CommonFlags.Abstract) && !(instanceClass && instanceClass.kind == ElementKind.Interface)) {
Expand Down Expand Up @@ -7432,7 +7444,7 @@ export class Compiler extends DiagnosticEmitter {
// <nullable> instanceof <nonNullable> - LHS must be != 0
if (actualType.isNullableReference && !expectedType.isNullableReference) {

// upcast - check statically
// same or upcast - check statically
if (actualType.nonNullableType.isAssignableTo(expectedType)) {
return module.binary(
sizeTypeRef == TypeRef.I64
Expand All @@ -7443,8 +7455,8 @@ export class Compiler extends DiagnosticEmitter {
);
}

// downcast - check dynamically
if (expectedType.isAssignableTo(actualType)) {
// potential downcast - check dynamically
if (actualType.nonNullableType.hasSubtypeAssignableTo(expectedType)) {
if (!(actualType.isUnmanaged || expectedType.isUnmanaged)) {
if (this.options.pedantic) {
this.pedantic(
Expand Down Expand Up @@ -7477,12 +7489,13 @@ export class Compiler extends DiagnosticEmitter {
// either none or both nullable
} else {

// upcast - check statically
// same or upcast - check statically
if (actualType.isAssignableTo(expectedType)) {
return module.maybeDropCondition(expr, module.i32(1));
}

// downcast - check dynamically
} else if (expectedType.isAssignableTo(actualType)) {
// potential downcast - check dynamically
if (actualType.hasSubtypeAssignableTo(expectedType)) {
if (!(actualType.isUnmanaged || expectedType.isUnmanaged)) {
let temp = flow.getTempLocal(actualType);
let tempIndex = temp.index;
Expand Down Expand Up @@ -7558,19 +7571,32 @@ export class Compiler extends DiagnosticEmitter {
), false // managedness is irrelevant here, isn't interrupted
)
);
let allInstances = new Set<Class>();
allInstances.add(instance);
instance.getAllExtendeesAndImplementers(allInstances);
for (let _values = Set_values(allInstances), i = 0, k = _values.length; i < k; ++i) {
let instance = _values[i];
stmts.push(
module.br("is_instance",
module.binary(BinaryOp.EqI32,
module.local_get(1, TypeRef.I32),
module.i32(instance.id)
let allInstances: Set<Class> | null;
if (instance.isInterface) {
allInstances = instance.implementers;
} else {
allInstances = new Set();
allInstances.add(instance);
let extenders = instance.extenders;
if (extenders) {
for (let _values = Set_values(extenders), i = 0, k = _values.length; i < k; ++i) {
let extender = _values[i];
allInstances.add(extender);
}
}
}
if (allInstances) {
for (let _values = Set_values(allInstances), i = 0, k = _values.length; i < k; ++i) {
let instance = _values[i];
stmts.push(
module.br("is_instance",
module.binary(BinaryOp.EqI32,
module.local_get(1, TypeRef.I32),
module.i32(instance.id)
)
)
)
);
);
}
}
stmts.push(
module.return(
Expand Down Expand Up @@ -7599,7 +7625,7 @@ export class Compiler extends DiagnosticEmitter {
if (classReference) {

// static check
if (classReference.extends(prototype)) {
if (classReference.extendsPrototype(prototype)) {

// <nullable> instanceof <PROTOTYPE> - LHS must be != 0
if (actualType.isNullableReference) {
Expand Down Expand Up @@ -7688,8 +7714,24 @@ export class Compiler extends DiagnosticEmitter {
let allInstances = new Set<Class>();
for (let _values = Map_values(instances), i = 0, k = _values.length; i < k; ++i) {
let instance = _values[i];
allInstances.add(instance);
instance.getAllExtendeesAndImplementers(allInstances);
if (instance.isInterface) {
let implementers = instance.implementers;
if (implementers) {
for (let _values = Set_values(implementers), i = 0, k = _values.length; i < k; ++i) {
let implementer = _values[i];
allInstances.add(implementer);
}
}
} else {
allInstances.add(instance);
let extenders = instance.extenders;
if (extenders) {
for (let _values = Set_values(extenders), i = 0, k = _values.length; i < k; ++i) {
let extender = _values[i];
allInstances.add(extender);
}
}
}
}
for (let _values = Set_values(allInstances), i = 0, k = _values.length; i < k; ++i) {
let instance = _values[i];
Expand Down Expand Up @@ -7946,7 +7988,7 @@ export class Compiler extends DiagnosticEmitter {
let parameterTypes = instance.signature.parameterTypes;
if (parameterTypes.length) {
let first = parameterTypes[0].getClass();
if (first && !first.extends(tsaArrayInstance.prototype)) {
if (first && !first.extendsPrototype(tsaArrayInstance.prototype)) {
arrayInstance = assert(this.resolver.resolveClass(this.program.arrayPrototype, [ stringType ]));
}
}
Expand Down Expand Up @@ -8014,7 +8056,7 @@ export class Compiler extends DiagnosticEmitter {

// handle static arrays
let contextualClass = contextualType.getClass();
if (contextualClass && contextualClass.extends(program.staticArrayPrototype)) {
if (contextualClass && contextualClass.extendsPrototype(program.staticArrayPrototype)) {
return this.compileStaticArrayLiteral(expression, contextualType, constraints);
}

Expand Down Expand Up @@ -8147,7 +8189,7 @@ export class Compiler extends DiagnosticEmitter {
): ExpressionRef {
let program = this.program;
let module = this.module;
assert(!arrayInstance.extends(program.staticArrayPrototype));
assert(!arrayInstance.extendsPrototype(program.staticArrayPrototype));
let elementType = arrayInstance.getArrayValueType(); // asserts

// __newArray(length, alignLog2, classId, staticBuffer)
Expand Down
Loading