Skip to content
Merged
1 change: 1 addition & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"@typescript-eslint/restrict-template-expressions": "off",
"@typescript-eslint/unbound-method": "off",
"no-empty": "off",
"prefer-const": ["error", { "destructuring": "all" }],
"prefer-rest-params": "off",
"prefer-spread": "off"
},
Expand Down
35 changes: 35 additions & 0 deletions src/Disposable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

/**
* Based off of VS Code
* https://github.com/microsoft/vscode/blob/a64e8e5673a44e5b9c2d493666bde684bd5a135c/src/vs/workbench/api/common/extHostTypes.ts#L32
*/
export class Disposable {
static from(...inDisposables: { dispose(): any }[]): Disposable {
let disposables: ReadonlyArray<{ dispose(): any }> | undefined = inDisposables;
return new Disposable(function () {
if (disposables) {
for (const disposable of disposables) {
if (disposable && typeof disposable.dispose === 'function') {
disposable.dispose();
}
}
disposables = undefined;
}
});
}

#callOnDispose?: () => any;

constructor(callOnDispose: () => any) {
this.#callOnDispose = callOnDispose;
}

dispose(): any {
if (this.#callOnDispose instanceof Function) {
this.#callOnDispose();
this.#callOnDispose = undefined;
}
}
}
2 changes: 2 additions & 0 deletions src/Worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import * as parseArgs from 'minimist';
import { FunctionLoader } from './FunctionLoader';
import { CreateGrpcEventStream } from './GrpcClient';
import { setupCoreModule } from './setupCoreModule';
import { setupEventStream } from './setupEventStream';
import { ensureErrorType } from './utils/ensureErrorType';
import { InternalException } from './utils/InternalException';
Expand Down Expand Up @@ -42,6 +43,7 @@ export function startNodeWorker(args) {

const channel = new WorkerChannel(eventStream, new FunctionLoader());
setupEventStream(workerId, channel);
setupCoreModule(channel);

eventStream.write({
requestId: requestId,
Expand Down
54 changes: 25 additions & 29 deletions src/WorkerChannel.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { Context } from '@azure/functions';
import { HookCallback, HookContext } from '@azure/functions-core';
import { readJson } from 'fs-extra';
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
import { Disposable } from './Disposable';
import { IFunctionLoader } from './FunctionLoader';
import { IEventStream } from './GrpcClient';
import { ensureErrorType } from './utils/ensureErrorType';
import path = require('path');
import LogLevel = rpc.RpcLog.Level;
import LogCategory = rpc.RpcLog.RpcLogCategory;

type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
type InvocationRequestAfter = (context: Context) => void;

export interface PackageJson {
type?: string;
}
Expand All @@ -22,15 +20,13 @@ export class WorkerChannel {
public eventStream: IEventStream;
public functionLoader: IFunctionLoader;
public packageJson: PackageJson;
private _invocationRequestBefore: InvocationRequestBefore[];
private _invocationRequestAfter: InvocationRequestAfter[];
#preInvocationHooks: HookCallback[] = [];
#postInvocationHooks: HookCallback[] = [];

constructor(eventStream: IEventStream, functionLoader: IFunctionLoader) {
this.eventStream = eventStream;
this.functionLoader = functionLoader;
this.packageJson = {};
this._invocationRequestBefore = [];
this._invocationRequestAfter = [];
}

/**
Expand All @@ -44,32 +40,32 @@ export class WorkerChannel {
});
}

/**
* Register a patching function to be run before User Function is executed.
* Hook should return a patched version of User Function.
*/
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
this._invocationRequestBefore.push(beforeCb);
}

/**
* Register a function to be run after User Function resolves.
*/
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
this._invocationRequestAfter.push(afterCb);
public registerHook(hookName: string, callback: HookCallback): Disposable {
const hooks = this.#getHooks(hookName);
hooks.push(callback);
return new Disposable(() => {
const index = hooks.indexOf(callback);
if (index > -1) {
hooks.splice(index, 1);
}
});
}

public runInvocationRequestBefore(context: Context, userFunction: Function): Function {
let wrappedFunction = userFunction;
for (const before of this._invocationRequestBefore) {
wrappedFunction = before(context, wrappedFunction);
public async executeHooks(hookName: string, context: HookContext): Promise<void> {
const callbacks = this.#getHooks(hookName);
for (const callback of callbacks) {
await callback(context);
}
return wrappedFunction;
}

public runInvocationRequestAfter(context: Context) {
for (const after of this._invocationRequestAfter) {
after(context);
#getHooks(hookName: string): HookCallback[] {
switch (hookName) {
case 'preInvocation':
return this.#preInvocationHooks;
case 'postInvocation':
return this.#postInvocationHooks;
default:
throw new RangeError(`Unrecognized hook "${hookName}"`);
}
}

Expand Down
33 changes: 27 additions & 6 deletions src/eventHandlers/invocationRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { HookData, PostInvocationContext, PreInvocationContext } from '@azure/functions-core';
import { format } from 'util';
import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc';
import { CreateContextAndInputs } from '../Context';
Expand Down Expand Up @@ -67,7 +68,7 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
isDone = true;
}

const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
let { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
try {
const legacyDoneTask = new Promise((resolve, reject) => {
doneEmitter.on('done', (err?: unknown, result?: any) => {
Expand All @@ -80,8 +81,13 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
});
});

let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
userFunction = channel.runInvocationRequestBefore(context, userFunction);
const hookData: HookData = {};
const userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
const preInvocContext: PreInvocationContext = { hookData, invocationContext: context, inputs };

await channel.executeHooks('preInvocation', preInvocContext);
inputs = preInvocContext.inputs;

let rawResult = userFunction(context, ...inputs);
resultIsPromise = rawResult && typeof rawResult.then === 'function';
let resultTask: Promise<any>;
Expand All @@ -95,7 +101,24 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
resultTask = legacyDoneTask;
}

const result = await resultTask;
const postInvocContext: PostInvocationContext = {
hookData,
invocationContext: context,
inputs,
result: null,
error: null,
};
try {
postInvocContext.result = await resultTask;
} catch (err) {
postInvocContext.error = err;
}
await channel.executeHooks('postInvocation', postInvocContext);

if (isError(postInvocContext.error)) {
throw postInvocContext.error;
}
const result = postInvocContext.result;

// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) {
Expand Down Expand Up @@ -164,6 +187,4 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
requestId: requestId,
invocationResponse: response,
});

channel.runInvocationRequestAfter(context);
}
29 changes: 29 additions & 0 deletions src/setupCoreModule.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { HookCallback } from '@azure/functions-core';
import { Disposable } from './Disposable';
import { WorkerChannel } from './WorkerChannel';
import Module = require('module');

/**
* Intercepts the default "require" method so that we can provide our own "built-in" module
* This module is essentially the publicly accessible API for our worker
* This module is available to users only at runtime, not as an installable npm package
*/
export function setupCoreModule(channel: WorkerChannel): void {
const coreApi = {
registerHook: (hookName: string, callback: HookCallback) => channel.registerHook(hookName, callback),
Disposable,
};

Module.prototype.require = new Proxy(Module.prototype.require, {
apply(target, thisArg, argArray) {
if (argArray[0] === '@azure/functions-core') {
return coreApi;
} else {
return Reflect.apply(target, thisArg, argArray);
}
},
});
}
2 changes: 2 additions & 0 deletions test/eventHandlers/beforeEventHandlerSuite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import * as sinon from 'sinon';
import { FunctionLoader } from '../../src/FunctionLoader';
import { setupCoreModule } from '../../src/setupCoreModule';
import { setupEventStream } from '../../src/setupEventStream';
import { WorkerChannel } from '../../src/WorkerChannel';
import { TestEventStream } from './TestEventStream';
Expand All @@ -12,5 +13,6 @@ export function beforeEventHandlerSuite() {
const loader = sinon.createStubInstance<FunctionLoader>(FunctionLoader);
const channel = new WorkerChannel(stream, loader);
setupEventStream('workerId', channel);
setupCoreModule(channel);
return { stream, loader, channel };
}
Loading