Skip to content

Commit 26f2ffe

Browse files
committed
[PECO-728] Add OAuth support
Signed-off-by: Levko Kravets <[email protected]>
1 parent 39ba279 commit 26f2ffe

File tree

8 files changed

+476
-34
lines changed

8 files changed

+476
-34
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import http, { Server, IncomingMessage, ServerResponse } from 'http';
2+
import { BaseClient, generators } from 'openid-client';
3+
import open from 'open';
4+
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
5+
6+
export interface AuthorizationCodeOptions {
7+
client: BaseClient;
8+
ports: Array<number>;
9+
logger?: IDBSQLLogger;
10+
}
11+
12+
const scopeDelimiter = ' ';
13+
14+
async function startServer(
15+
host: string,
16+
port: number,
17+
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
18+
): Promise<Server> {
19+
const server = http.createServer(requestHandler);
20+
21+
return new Promise((resolve, reject) => {
22+
const errorListener = (error: Error) => {
23+
server.off('error', errorListener);
24+
reject(error);
25+
};
26+
27+
server.on('error', errorListener);
28+
server.listen(port, host, () => {
29+
server.off('error', errorListener);
30+
resolve(server);
31+
});
32+
});
33+
}
34+
35+
async function stopServer(server: Server): Promise<void> {
36+
if (!server.listening) {
37+
return;
38+
}
39+
40+
return new Promise((resolve, reject) => {
41+
const errorListener = (error: Error) => {
42+
server.off('error', errorListener);
43+
reject(error);
44+
};
45+
46+
server.on('error', errorListener);
47+
server.close(() => {
48+
server.off('error', errorListener);
49+
resolve();
50+
});
51+
});
52+
}
53+
54+
export interface AuthorizationCodeFetchResult {
55+
code: string;
56+
verifier: string;
57+
redirectUri: string;
58+
}
59+
60+
export default class AuthorizationCode {
61+
private readonly client: BaseClient;
62+
private readonly host: string = 'localhost';
63+
private readonly ports: Array<number>;
64+
private readonly logger?: IDBSQLLogger;
65+
66+
constructor(options: AuthorizationCodeOptions) {
67+
this.client = options.client;
68+
this.ports = options.ports;
69+
this.logger = options.logger;
70+
}
71+
72+
public async fetch(scopes: Array<string>): Promise<AuthorizationCodeFetchResult> {
73+
const verifierString = generators.codeVerifier(32);
74+
const challengeString = generators.codeChallenge(verifierString);
75+
const state = generators.state(16);
76+
77+
let code: string | undefined = undefined;
78+
79+
const server = await this.startServer((req, res) => {
80+
const params = this.client.callbackParams(req);
81+
if (params.state === state) {
82+
code = params.code;
83+
res.writeHead(200);
84+
res.end('You can close this tab');
85+
server.stop();
86+
} else {
87+
res.writeHead(404);
88+
res.end();
89+
}
90+
});
91+
92+
let redirectUri = `http://${server.host}:${server.port}/`;
93+
const authUrl = this.client.authorizationUrl({
94+
response_type: 'code',
95+
response_mode: 'query',
96+
scope: scopes.join(scopeDelimiter),
97+
code_challenge: challengeString,
98+
code_challenge_method: 'S256',
99+
state,
100+
redirect_uri: redirectUri,
101+
});
102+
103+
await open(authUrl);
104+
await server.stopped();
105+
106+
if (!code) {
107+
throw new Error(`No path parameters were returned to the callback at ${redirectUri}`);
108+
}
109+
110+
return { code, verifier: verifierString, redirectUri };
111+
}
112+
113+
private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
114+
for (const port of this.ports) {
115+
const host = this.host;
116+
try {
117+
const server = await startServer(host, port, requestHandler);
118+
this.logger?.log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);
119+
120+
let resolveStopped = () => {};
121+
let rejectStopped = (reason?: any) => {};
122+
const stoppedPromise = new Promise<void>((resolve, reject) => {
123+
resolveStopped = resolve;
124+
rejectStopped = reject;
125+
});
126+
127+
return {
128+
host,
129+
port,
130+
server,
131+
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
132+
stopped: () => stoppedPromise,
133+
};
134+
} catch (error) {
135+
if (error instanceof Error && 'code' in error && error.code === 'EADDRINUSE') {
136+
this.logger?.log(LogLevel.debug, `Failed to start server at ${host}:${port}: ${error.code}`);
137+
continue; // try another port
138+
}
139+
throw error;
140+
}
141+
}
142+
143+
throw new Error('Failed to start server: all ports are in use');
144+
}
145+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import { Issuer, BaseClient } from 'openid-client';
2+
import HiveDriverError from '../../../errors/HiveDriverError';
3+
import IDBSQLLogger, { LogLevel } from '../../../contracts/IDBSQLLogger';
4+
import OAuthToken from "./OAuthToken";
5+
import AuthorizationCode from "./AuthorizationCode";
6+
7+
const oidcConfigPath = 'oidc/.well-known/oauth-authorization-server';
8+
9+
export interface OAuthManagerOptions {
10+
host: string;
11+
callbackPorts: Array<number>;
12+
clientId: string;
13+
logger?: IDBSQLLogger;
14+
}
15+
16+
export default class OAuthManager {
17+
private readonly options: OAuthManagerOptions;
18+
private readonly logger?: IDBSQLLogger;
19+
20+
private issuer?: Issuer;
21+
private client?: BaseClient;
22+
23+
constructor(options: OAuthManagerOptions) {
24+
this.options = options;
25+
this.logger = options.logger;
26+
}
27+
28+
private async getClient(): Promise<BaseClient> {
29+
if (!this.issuer) {
30+
const { host } = this.options;
31+
const schema = host.startsWith('https://') ? '' : 'https://';
32+
const trailingSlash = host.endsWith('/') ? '' : '/';
33+
this.issuer = await Issuer.discover(`${schema}${host}${trailingSlash}${oidcConfigPath}`);
34+
}
35+
36+
if (!this.client) {
37+
this.client = new this.issuer.Client({
38+
client_id: this.options.clientId,
39+
token_endpoint_auth_method: 'none',
40+
});
41+
}
42+
43+
return this.client;
44+
}
45+
46+
public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
47+
try {
48+
if (!token.hasExpired) {
49+
// The access token is fine. Just return it.
50+
return token;
51+
}
52+
} catch (error) {
53+
this.logger?.log(LogLevel.error, `${error}`);
54+
throw error;
55+
}
56+
57+
if (!token.refreshToken) {
58+
const message = `OAuth access token expired on ${token.expirationTime}.`;
59+
this.logger?.log(LogLevel.error, message);
60+
throw new HiveDriverError(message);
61+
}
62+
63+
// Try to refresh using the refresh token
64+
this.logger?.log(LogLevel.debug, `Attempting to refresh OAuth access token that expired on ${token.expirationTime}`);
65+
66+
const client = await this.getClient();
67+
const { access_token, refresh_token } = await client.refresh(token.refreshToken);
68+
if (!access_token || !refresh_token) {
69+
throw new Error('Failed to refresh token: invalid response');
70+
}
71+
return new OAuthToken(access_token, refresh_token);
72+
}
73+
74+
public async getToken(scopes: Array<string>): Promise<OAuthToken> {
75+
const client = await this.getClient();
76+
const authCode = new AuthorizationCode({
77+
client,
78+
ports: this.options.callbackPorts,
79+
logger: this.logger,
80+
});
81+
82+
const { code, verifier, redirectUri } = await authCode.fetch(scopes);
83+
84+
const { access_token, refresh_token } = await client.grant({
85+
grant_type: 'authorization_code',
86+
code,
87+
code_verifier: verifier,
88+
redirect_uri: redirectUri,
89+
});
90+
91+
if (!access_token) {
92+
throw new Error('Failed to fetch access token');
93+
}
94+
95+
return new OAuthToken(access_token, refresh_token);
96+
}
97+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import OAuthToken from "./OAuthToken";
2+
3+
export default class OAuthPersistence {
4+
public async persist(host: string, token: OAuthToken): Promise<void> {
5+
}
6+
7+
public async read(host: string): Promise<OAuthToken | undefined> {
8+
return undefined;
9+
}
10+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
export default class OAuthToken {
2+
private readonly _accessToken: string;
3+
private readonly _refreshToken?: string;
4+
private _expirationTime?: number;
5+
6+
constructor(accessToken: string, refreshToken?: string) {
7+
this._accessToken = accessToken;
8+
this._refreshToken = refreshToken;
9+
}
10+
11+
get accessToken(): string {
12+
return this._accessToken;
13+
}
14+
15+
get refreshToken(): string | undefined {
16+
return this._refreshToken;
17+
}
18+
19+
get expirationTime(): number {
20+
if (this._expirationTime === undefined) {
21+
const accessTokenPayload = Buffer.from(this._accessToken.split('.')[1], 'base64').toString('utf8');
22+
const decoded = JSON.parse(accessTokenPayload);
23+
this._expirationTime = Number(decoded['exp']);
24+
}
25+
return this._expirationTime;
26+
}
27+
28+
get hasExpired(): boolean {
29+
// This token has already been verified, and we are just parsing it.
30+
// If it has been tampered with, it will be rejected on the server side.
31+
// This avoids having to fetch the public key from the issuer and perform
32+
// an unnecessary signature verification.
33+
const now = Math.floor(Date.now() / 1000); // convert it to seconds
34+
return this.expirationTime <= now;
35+
}
36+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import IAuthentication from '../../contracts/IAuthentication';
2+
import ITransport from '../../contracts/ITransport';
3+
import IDBSQLLogger from '../../../contracts/IDBSQLLogger';
4+
import { AuthOptions } from '../../types/AuthOptions';
5+
import OAuthPersistence from './OAuthPersistence';
6+
import OAuthManager from "./OAuthManager";
7+
8+
interface DatabricksOAuthOptions extends AuthOptions {
9+
host: string;
10+
redirectPorts?: Array<number>;
11+
clientId?: string;
12+
scopes?: Array<string>;
13+
logger?: IDBSQLLogger;
14+
persistence?: OAuthPersistence;
15+
headers?: object;
16+
}
17+
18+
const defaultOAuthOptions = {
19+
clientId: 'databricks-sql-python',
20+
redirectPorts: [8020, 8021, 8022, 8023, 8024, 8025],
21+
scopes: ['sql', 'offline_access'],
22+
} satisfies Partial<DatabricksOAuthOptions>;
23+
24+
export default class DatabricksOAuth implements IAuthentication {
25+
private readonly host: string;
26+
private readonly redirectPorts: Array<number>;
27+
private readonly clientId: string;
28+
private readonly scopes: Array<string>;
29+
private readonly logger?: IDBSQLLogger;
30+
private readonly persistence?: OAuthPersistence;
31+
private readonly headers?: object;
32+
33+
private readonly manager: OAuthManager;
34+
35+
constructor(options: DatabricksOAuthOptions) {
36+
this.host = options.host;
37+
this.redirectPorts = options.redirectPorts || defaultOAuthOptions.redirectPorts;
38+
this.clientId = options.clientId || defaultOAuthOptions.clientId;
39+
this.scopes = options.scopes || defaultOAuthOptions.scopes;
40+
this.logger = options.logger;
41+
this.persistence = options.persistence;
42+
this.headers = options.headers;
43+
44+
this.manager = new OAuthManager({
45+
host: this.host,
46+
callbackPorts: this.redirectPorts,
47+
clientId: this.clientId,
48+
logger: this.logger,
49+
});
50+
}
51+
52+
async authenticate(transport: ITransport): Promise<ITransport> {
53+
let token = await this.persistence?.read(this.host);
54+
if (!token) {
55+
token = await this.manager.getToken(this.scopes);
56+
}
57+
58+
token = await this.manager.refreshAccessToken(token);
59+
await this.persistence?.persist(this.host, token);
60+
61+
transport.setOptions('headers', {
62+
...this.headers,
63+
Authorization: `Bearer ${token.accessToken}`,
64+
});
65+
66+
return transport;
67+
}
68+
}

lib/connection/auth/helpers/SaslPackageFactory.ts

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)