From 9771efe449e8696672a0213226daedddaad2db26 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 22 Jun 2019 12:25:50 -0400 Subject: [PATCH 01/12] save --- src/backends/backend.ts | 14 ++++----- src/backends/cpu/backend_cpu.ts | 18 +++++++---- src/backends/cpu/backend_cpu_test.ts | 23 ++++++++++---- src/backends/webgl/backend_webgl.ts | 12 ++++---- src/backends/webgl/backend_webgl_test.ts | 21 +++++++++---- src/backends/webgl/tex_util.ts | 4 +-- src/engine.ts | 11 +++---- src/engine_test.ts | 2 +- src/tensor.ts | 38 ++++++++++++++++-------- src/types.ts | 1 + src/util.ts | 4 +-- src/util_test.ts | 12 +++++--- 12 files changed, 105 insertions(+), 55 deletions(-) diff --git a/src/backends/backend.ts b/src/backends/backend.ts index 2b15d2e263..0819d0fcbd 100644 --- a/src/backends/backend.ts +++ b/src/backends/backend.ts @@ -18,7 +18,7 @@ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; import {Activation} from '../ops/fused_util'; import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; -import {DataType, DataValues, PixelData, Rank, ShapeMap} from '../types'; +import {BackendDataValues, DataType, PixelData, Rank, ShapeMap} from '../types'; export const EPSILON_FLOAT32 = 1e-7; export const EPSILON_FLOAT16 = 1e-4; @@ -31,10 +31,10 @@ export interface BackendTimingInfo { } export interface TensorStorage { - read(dataId: DataId): Promise; - readSync(dataId: DataId): DataValues; + read(dataId: DataId): Promise; + readSync(dataId: DataId): BackendDataValues; disposeData(dataId: DataId): void; - write(dataId: DataId, values: DataValues): void; + write(dataId: DataId, values: BackendDataValues): void; fromPixels( pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| HTMLVideoElement, @@ -92,16 +92,16 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { time(f: () => void): Promise { throw new Error('Not yet implemented.'); } - read(dataId: object): Promise { + read(dataId: object): Promise { throw new Error('Not yet implemented.'); } - readSync(dataId: object): DataValues { + readSync(dataId: object): BackendDataValues { throw new Error('Not yet implemented.'); } disposeData(dataId: object): void { throw new Error('Not yet implemented.'); } - write(dataId: object, values: DataValues): void { + write(dataId: object, values: BackendDataValues): void { throw new Error('Not yet implemented.'); } fromPixels( diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 2aae7a6395..ea36450930 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -34,7 +34,7 @@ import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; -import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; +import {BackendDataValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; @@ -58,7 +58,7 @@ function mapActivation( } interface TensorData { - values?: DataTypeMap[D]; + values?: BackendDataValues; dtype: D; // For complex numbers, the real and imaginary parts are stored as their own // individual tensors, with a parent joining the two with the @@ -116,7 +116,7 @@ export class MathBackendCPU implements KernelBackend { } this.data.set(dataId, {dtype}); } - write(dataId: DataId, values: DataValues): void { + write(dataId: DataId, values: BackendDataValues): void { if (values == null) { throw new Error('MathBackendCPU.write(): values can not be null'); } @@ -186,10 +186,10 @@ export class MathBackendCPU implements KernelBackend { [pixels.height, pixels.width, numChannels]; return tensor3d(values, outShape, 'int32'); } - async read(dataId: DataId): Promise { + async read(dataId: DataId): Promise { return this.readSync(dataId); } - readSync(dataId: DataId): DataValues { + readSync(dataId: DataId): BackendDataValues { const {dtype, complexTensors} = this.data.get(dataId); if (dtype === 'complex64') { const realValues = @@ -202,7 +202,13 @@ export class MathBackendCPU implements KernelBackend { } private bufferSync(t: Tensor): TensorBuffer { - return buffer(t.shape, t.dtype, this.readSync(t.dataId)) as TensorBuffer; + const data = this.readSync(t.dataId); + let decodedData = data as DataValues; + if (t.dtype === 'string') { + // Decode the bytes into string. + decodedData = (data as Uint8Array[]).map(d => ENV.platform.decodeUTF8(d)); + } + return buffer(t.shape, t.dtype, decodedData) as TensorBuffer; } disposeData(dataId: DataId): void { diff --git a/src/backends/cpu/backend_cpu_test.ts b/src/backends/cpu/backend_cpu_test.ts index 8a6c7cb4ce..61387c6c6e 100644 --- a/src/backends/cpu/backend_cpu_test.ts +++ b/src/backends/cpu/backend_cpu_test.ts @@ -23,6 +23,14 @@ import {expectArraysClose, expectArraysEqual} from '../../test_util'; import {MathBackendCPU} from './backend_cpu'; import {CPU_ENVS} from './backend_cpu_test_registry'; +function encode(data: string[]): Uint8Array[] { + return data.map(s => tf.ENV.platform.encodeUTF8(s)); +} + +function decode(data: Uint8Array[]): string[] { + return data.map(d => tf.ENV.platform.decodeUTF8(d)); +} + describeWithFlags('backendCPU', CPU_ENVS, () => { let backend: MathBackendCPU; beforeEach(() => { @@ -36,19 +44,22 @@ describeWithFlags('backendCPU', CPU_ENVS, () => { it('register empty string tensor and write', () => { const t = tf.Tensor.make([3], {}, 'string'); - backend.write(t.dataId, ['c', 'a', 'b']); - expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']); + backend.write(t.dataId, encode(['c', 'a', 'b'])); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); }); it('register string tensor with values', () => { const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['a', 'b', 'c']); }); it('register string tensor with values and overwrite', () => { const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - backend.write(t.dataId, ['c', 'a', 'b']); - expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']); + backend.write(t.dataId, encode(['c', 'a', 'b'])); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); }); it('register string tensor with values and mismatched shape', () => { @@ -129,7 +140,7 @@ describeWithFlags('memory cpu', CPU_ENVS, () => { const mem = tf.memory(); expect(mem.numTensors).toBe(2); expect(mem.numDataBuffers).toBe(2); - expect(mem.numBytes).toBe(6); + expect(mem.numBytes).toBe(5); expect(mem.unreliable).toBe(true); const expectedReasonGC = diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 284f8a5ff4..9d9e14056a 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -37,7 +37,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; -import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; +import {BackendDataValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util'; import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; @@ -338,7 +338,7 @@ export class MathBackendWebGL implements KernelBackend { return {dataId, shape, dtype}; } - write(dataId: DataId, values: DataValues): void { + write(dataId: DataId, values: BackendDataValues): void { if (values == null) { throw new Error('MathBackendWebGL.write(): values can not be null'); } @@ -365,7 +365,7 @@ export class MathBackendWebGL implements KernelBackend { texData.values = values; } - readSync(dataId: DataId): DataValues { + readSync(dataId: DataId): BackendDataValues { const texData = this.texData.get(dataId); const {values, dtype, complexTensors, slice, shape} = texData; if (slice != null) { @@ -402,7 +402,7 @@ export class MathBackendWebGL implements KernelBackend { return this.convertAndCacheOnCPU(dataId, result); } - async read(dataId: DataId): Promise { + async read(dataId: DataId): Promise { if (this.pendingRead.has(dataId)) { const subscribers = this.pendingRead.get(dataId); return new Promise(resolve => subscribers.push(resolve)); @@ -952,7 +952,9 @@ export class MathBackendWebGL implements KernelBackend { tile(x: T, reps: number[]): T { if (x.dtype === 'string') { - const buf = buffer(x.shape, x.dtype, this.readSync(x.dataId) as string[]); + const data = this.readSync(x.dataId) as Uint8Array[]; + const decodedData = data.map(d => ENV.platform.decodeUTF8(d)); + const buf = buffer(x.shape, x.dtype, decodedData); return tile(buf, reps) as T; } const program = new TileProgram(x.shape, reps); diff --git a/src/backends/webgl/backend_webgl_test.ts b/src/backends/webgl/backend_webgl_test.ts index bb6e906d2d..3aa048d498 100644 --- a/src/backends/webgl/backend_webgl_test.ts +++ b/src/backends/webgl/backend_webgl_test.ts @@ -99,6 +99,14 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { }); }); +function encode(data: string[]): Uint8Array[] { + return data.map(s => tf.ENV.platform.encodeUTF8(s)); +} + +function decode(data: Uint8Array[]): string[] { + return data.map(d => tf.ENV.platform.decodeUTF8(d)); +} + describeWithFlags('backendWebGL', WEBGL_ENVS, () => { let prevBackend: string; @@ -126,8 +134,9 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.setBackend('test-storage'); const t = tf.Tensor.make([3], {}, 'string'); - backend.write(t.dataId, ['c', 'a', 'b']); - expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']); + backend.write(t.dataId, encode(['c', 'a', 'b'])); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); }); it('register string tensor with values', () => { @@ -136,7 +145,8 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.setBackend('test-storage'); const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['a', 'b', 'c']); }); it('register string tensor with values and overwrite', () => { @@ -145,8 +155,9 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.setBackend('test-storage'); const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - backend.write(t.dataId, ['c', 'a', 'b']); - expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']); + backend.write(t.dataId, encode(['c', 'a', 'b'])); + expectArraysEqual( + decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); }); it('register string tensor with values and wrong shape throws error', () => { diff --git a/src/backends/webgl/tex_util.ts b/src/backends/webgl/tex_util.ts index 6518190e0b..b8e06f3000 100644 --- a/src/backends/webgl/tex_util.ts +++ b/src/backends/webgl/tex_util.ts @@ -16,7 +16,7 @@ */ import {DataId, Tensor} from '../../tensor'; -import {DataType, DataValues} from '../../types'; +import {BackendDataValues, DataType} from '../../types'; import * as util from '../../util'; export enum TextureUsage { @@ -40,7 +40,7 @@ export interface TextureData { dtype: DataType; // Optional. - values?: DataValues; + values?: BackendDataValues; texture?: WebGLTexture; // For complex numbers, the real and imaginary parts are stored as their own // individual tensors, with a parent joining the two with the diff --git a/src/engine.ts b/src/engine.ts index 46abda9d92..a018e3fbbb 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -22,7 +22,7 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode import {DataId, setTensorTracker, Tensor, Tensor3D, TensorTracker, Variable} from './tensor'; import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types'; import {getTensorsInContainer} from './tensor_util'; -import {DataType, DataValues, PixelData} from './types'; +import {BackendDataValues, DataType, PixelData} from './types'; import * as util from './util'; import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util'; @@ -830,7 +830,8 @@ export class Engine implements TensorManager, TensorTracker, DataMover { } // Forwarding to backend. - write(destBackend: KernelBackend, dataId: DataId, values: DataValues): void { + write(destBackend: KernelBackend, dataId: DataId, values: BackendDataValues): + void { const info = this.state.tensorInfo.get(dataId); const srcBackend = info.backend; @@ -838,7 +839,7 @@ export class Engine implements TensorManager, TensorTracker, DataMover { // Bytes for string tensors are counted when writing. if (info.dtype === 'string') { - const newBytes = bytesFromStringArray(values as string[]); + const newBytes = bytesFromStringArray(values as Uint8Array[]); this.state.numBytes += newBytes - info.bytes; info.bytes = newBytes; } @@ -852,12 +853,12 @@ export class Engine implements TensorManager, TensorTracker, DataMover { } destBackend.write(dataId, values); } - readSync(dataId: DataId): DataValues { + readSync(dataId: DataId): BackendDataValues { // Route the read to the correct backend. const info = this.state.tensorInfo.get(dataId); return info.backend.readSync(dataId); } - read(dataId: DataId): Promise { + read(dataId: DataId): Promise { // Route the read to the correct backend. const info = this.state.tensorInfo.get(dataId); return info.backend.read(dataId); diff --git a/src/engine_test.ts b/src/engine_test.ts index 735dc5cd38..d173614f1f 100644 --- a/src/engine_test.ts +++ b/src/engine_test.ts @@ -349,7 +349,7 @@ describeWithFlags('memory', ALL_ENVS, () => { const a = tf.tensor([['a', 'bb'], ['c', 'd']]); expect(tf.memory().numTensors).toBe(1); - expect(tf.memory().numBytes).toBe(10); // 5 letters, each 2 bytes. + expect(tf.memory().numBytes).toBe(5); // 5 letters, each 1 byte in utf8. a.dispose(); diff --git a/src/tensor.ts b/src/tensor.ts index a115f97585..e5b5b7c63a 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -15,8 +15,9 @@ * ============================================================================= */ +import {ENV} from './environment'; import {tensorToString} from './tensor_format'; -import {ArrayMap, DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; +import {ArrayMap, BackendDataValues, DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; import * as util from './util'; import {computeStrides, toNestedArray} from './util'; @@ -28,10 +29,10 @@ export interface TensorData { // This interface mimics KernelBackend (in backend.ts), which would create a // circular dependency if imported. export interface Backend { - read(dataId: object): Promise; - readSync(dataId: object): DataValues; + read(dataId: object): Promise; + readSync(dataId: object): BackendDataValues; disposeData(dataId: object): void; - write(dataId: object, values: DataValues): void; + write(dataId: object, values: BackendDataValues): void; } /** @@ -159,9 +160,9 @@ export interface TensorTracker { registerTensor(t: Tensor, backend?: Backend): void; disposeTensor(t: Tensor): void; disposeVariable(v: Variable): void; - write(backend: Backend, dataId: DataId, values: DataValues): void; - read(dataId: DataId): Promise; - readSync(dataId: DataId): DataValues; + write(backend: Backend, dataId: DataId, values: BackendDataValues): void; + read(dataId: DataId): Promise; + readSync(dataId: DataId): BackendDataValues; registerVariable(v: Variable): void; nextTensorId(): number; nextVariableId(): number; @@ -452,8 +453,8 @@ export class Tensor { readonly strides: number[]; protected constructor( - shape: ShapeMap[R], dtype: DataType, values?: DataValues, dataId?: DataId, - backend?: Backend) { + shape: ShapeMap[R], dtype: DataType, values?: DataValues|Uint8Array[], + dataId?: DataId, backend?: Backend) { this.shape = shape.slice() as ShapeMap[R]; this.dtype = dtype || 'float32'; this.size = util.sizeFromShape(shape); @@ -463,7 +464,11 @@ export class Tensor { this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher') as R; trackerFn().registerTensor(this, backend); if (values != null) { - trackerFn().write(backend, this.dataId, values); + if (dtype === 'string' && util.isString(values) || + util.isString(values[0])) { + values = (values as string[]).map(v => ENV.platform.encodeUTF8(v)); + } + trackerFn().write(backend, this.dataId, values as BackendDataValues); } } @@ -610,7 +615,12 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ async data(): Promise { this.throwIfDisposed(); - return trackerFn().read(this.dataId) as Promise; + const data = trackerFn().read(this.dataId); + if (this.dtype === 'string') { + const bytes = await data as Uint8Array[]; + return bytes.map(d => ENV.platform.decodeUTF8(d)); + } + return data as Promise; } /** @@ -620,7 +630,11 @@ export class Tensor { /** @doc {heading: 'Tensors', subheading: 'Classes'} */ dataSync(): DataTypeMap[D] { this.throwIfDisposed(); - return trackerFn().readSync(this.dataId) as DataTypeMap[D]; + const data = trackerFn().readSync(this.dataId); + if (this.dtype === 'string') { + return (data as Uint8Array[]).map(d => ENV.platform.decodeUTF8(d)); + } + return data as DataTypeMap[D]; } /** diff --git a/src/types.ts b/src/types.ts index 6cc8393096..c6fb3641f3 100644 --- a/src/types.ts +++ b/src/types.ts @@ -58,6 +58,7 @@ export type DataType = keyof DataTypeMap; export type NumericDataType = 'float32'|'int32'|'bool'|'complex64'; export type TypedArray = Float32Array|Int32Array|Uint8Array; export type DataValues = DataTypeMap[DataType]; +export type BackendDataValues = Float32Array|Int32Array|Uint8Array|Uint8Array[]; export enum Rank { R0 = 'R0', diff --git a/src/util.ts b/src/util.ts index ee5c328409..2aff5a82db 100644 --- a/src/util.ts +++ b/src/util.ts @@ -476,12 +476,12 @@ export function bytesPerElement(dtype: DataType): number { * not possible since it depends on the encoding of the html page that serves * the website. */ -export function bytesFromStringArray(arr: string[]): number { +export function bytesFromStringArray(arr: Uint8Array[]): number { if (arr == null) { return 0; } let bytes = 0; - arr.forEach(x => bytes += x.length * 2); + arr.forEach(x => bytes += x.length); return bytes; } diff --git a/src/util_test.ts b/src/util_test.ts index f131e93264..8b4d23be95 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -114,11 +114,15 @@ describe('util.flatten', () => { }); }); +function encode(data: string[]): Uint8Array[] { + return data.map(d => ENV.platform.encodeUTF8(d)); +} + describe('util.bytesFromStringArray', () => { - it('count each character as 2 bytes', () => { - expect(util.bytesFromStringArray(['a', 'bb', 'ccc'])).toBe(6 * 2); - expect(util.bytesFromStringArray(['a', 'bb', 'cccddd'])).toBe(9 * 2); - expect(util.bytesFromStringArray(['даниел'])).toBe(6 * 2); + it('count bytes after utf8 encoding', () => { + expect(util.bytesFromStringArray(encode(['a', 'bb', 'ccc']))).toBe(6); + expect(util.bytesFromStringArray(encode(['a', 'bb', 'cccddd']))).toBe(9); + expect(util.bytesFromStringArray(encode(['даниел']))).toBe(6 * 2); }); }); From ba4b224c2a8b532da2edc68b97c8b87d0f8c8102 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 22 Jun 2019 14:36:10 -0400 Subject: [PATCH 02/12] save --- src/backends/cpu/backend_cpu.ts | 4 ++-- src/backends/cpu/backend_cpu_test.ts | 22 +++++++---------- src/backends/webgl/backend_webgl.ts | 2 +- src/backends/webgl/backend_webgl_test.ts | 23 ++++++++---------- src/io/io_utils.ts | 3 ++- src/ops/tensor_ops.ts | 6 +++-- src/platforms/platform.ts | 4 ++-- src/platforms/platform_browser.ts | 9 ++++--- src/platforms/platform_browser_test.ts | 20 +++++++++------- src/platforms/platform_node.ts | 17 +++++++------- src/platforms/platform_node_test.ts | 20 +++++++++------- src/tensor.ts | 30 ++++++++++++++---------- src/tensor_test.ts | 9 +++++++ src/util.ts | 18 ++++++++++++++ src/util_test.ts | 13 +++++----- 15 files changed, 117 insertions(+), 83 deletions(-) diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index ea36450930..81ba3d5c62 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -36,7 +36,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendDataValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; -import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; +import {decodeStrings, getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; @@ -206,7 +206,7 @@ export class MathBackendCPU implements KernelBackend { let decodedData = data as DataValues; if (t.dtype === 'string') { // Decode the bytes into string. - decodedData = (data as Uint8Array[]).map(d => ENV.platform.decodeUTF8(d)); + decodedData = decodeStrings(data as Uint8Array[], t.encoding); } return buffer(t.shape, t.dtype, decodedData) as TensorBuffer; } diff --git a/src/backends/cpu/backend_cpu_test.ts b/src/backends/cpu/backend_cpu_test.ts index 61387c6c6e..77e48fc5d4 100644 --- a/src/backends/cpu/backend_cpu_test.ts +++ b/src/backends/cpu/backend_cpu_test.ts @@ -19,18 +19,11 @@ import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {tensor2d} from '../../ops/ops'; import {expectArraysClose, expectArraysEqual} from '../../test_util'; +import {decodeStrings, encodeStrings} from '../../util'; import {MathBackendCPU} from './backend_cpu'; import {CPU_ENVS} from './backend_cpu_test_registry'; -function encode(data: string[]): Uint8Array[] { - return data.map(s => tf.ENV.platform.encodeUTF8(s)); -} - -function decode(data: Uint8Array[]): string[] { - return data.map(d => tf.ENV.platform.decodeUTF8(d)); -} - describeWithFlags('backendCPU', CPU_ENVS, () => { let backend: MathBackendCPU; beforeEach(() => { @@ -44,22 +37,25 @@ describeWithFlags('backendCPU', CPU_ENVS, () => { it('register empty string tensor and write', () => { const t = tf.Tensor.make([3], {}, 'string'); - backend.write(t.dataId, encode(['c', 'a', 'b'])); + backend.write(t.dataId, encodeStrings(['c', 'a', 'b'])); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['c', 'a', 'b']); }); it('register string tensor with values', () => { const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['a', 'b', 'c']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['a', 'b', 'c']); }); it('register string tensor with values and overwrite', () => { const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - backend.write(t.dataId, encode(['c', 'a', 'b'])); + backend.write(t.dataId, encodeStrings(['c', 'a', 'b'])); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['c', 'a', 'b']); }); it('register string tensor with values and mismatched shape', () => { diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 9d9e14056a..76d1bca189 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -953,7 +953,7 @@ export class MathBackendWebGL implements KernelBackend { tile(x: T, reps: number[]): T { if (x.dtype === 'string') { const data = this.readSync(x.dataId) as Uint8Array[]; - const decodedData = data.map(d => ENV.platform.decodeUTF8(d)); + const decodedData = util.decodeStrings(data, x.encoding); const buf = buffer(x.shape, x.dtype, decodedData); return tile(buf, reps) as T; } diff --git a/src/backends/webgl/backend_webgl_test.ts b/src/backends/webgl/backend_webgl_test.ts index 3aa048d498..95e33be99e 100644 --- a/src/backends/webgl/backend_webgl_test.ts +++ b/src/backends/webgl/backend_webgl_test.ts @@ -18,6 +18,8 @@ import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../../test_util'; +import {decodeStrings, encodeStrings} from '../../util'; + import {MathBackendWebGL, WebGLMemoryInfo} from './backend_webgl'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; @@ -99,14 +101,6 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { }); }); -function encode(data: string[]): Uint8Array[] { - return data.map(s => tf.ENV.platform.encodeUTF8(s)); -} - -function decode(data: Uint8Array[]): string[] { - return data.map(d => tf.ENV.platform.decodeUTF8(d)); -} - describeWithFlags('backendWebGL', WEBGL_ENVS, () => { let prevBackend: string; @@ -134,9 +128,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.setBackend('test-storage'); const t = tf.Tensor.make([3], {}, 'string'); - backend.write(t.dataId, encode(['c', 'a', 'b'])); + backend.write(t.dataId, encodeStrings(['c', 'a', 'b'])); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['c', 'a', 'b']); }); it('register string tensor with values', () => { @@ -146,7 +141,8 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['a', 'b', 'c']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['a', 'b', 'c']); }); it('register string tensor with values and overwrite', () => { @@ -155,9 +151,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.setBackend('test-storage'); const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string'); - backend.write(t.dataId, encode(['c', 'a', 'b'])); + backend.write(t.dataId, encodeStrings(['c', 'a', 'b'])); expectArraysEqual( - decode(backend.readSync(t.dataId) as Uint8Array[]), ['c', 'a', 'b']); + decodeStrings(backend.readSync(t.dataId) as Uint8Array[]), + ['c', 'a', 'b']); }); it('register string tensor with values and wrong shape throws error', () => { diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index cd890cc6c9..6ede9083bb 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -141,7 +141,8 @@ export function decodeWeights( const stringSpec = spec as StringWeightsManifestEntry; const bytes = new Uint8Array(buffer.slice(offset, offset + stringSpec.byteLength)); - values = ENV.platform.decodeUTF8(bytes).split(stringSpec.delimiter); + // TODO(smilkov): Use encoding from metadata. + values = ENV.platform.decode(bytes, 'utf-8').split(stringSpec.delimiter); offset += stringSpec.byteLength; } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; diff --git a/src/ops/tensor_ops.ts b/src/ops/tensor_ops.ts index 54d4027e40..e31c36a8b2 100644 --- a/src/ops/tensor_ops.ts +++ b/src/ops/tensor_ops.ts @@ -52,7 +52,8 @@ import {op} from './operation'; */ /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor( - values: TensorLike, shape?: ShapeMap[R], dtype?: DataType): Tensor { + values: TensorLike, shape?: ShapeMap[R], dtype?: DataType, + encoding?: string): Tensor { if (dtype == null) { dtype = inferDtype(values); } @@ -101,7 +102,8 @@ function tensor( values = dtype !== 'string' ? toTypedArray(values, dtype, ENV.getBool('DEBUG')) : flatten(values as string[]) as string[]; - return Tensor.make(shape, {values: values as TypedArray}, dtype); + return Tensor.make( + shape, {values: values as TypedArray}, dtype, null, encoding); } /** diff --git a/src/platforms/platform.ts b/src/platforms/platform.ts index 31037d1521..a9da7c1d20 100644 --- a/src/platforms/platform.ts +++ b/src/platforms/platform.ts @@ -31,6 +31,6 @@ export interface Platform { /** UTF-8 encode the provided string into an array of bytes. */ encodeUTF8(text: string): Uint8Array; - /** UTF-8 decode the provided bytes into a string. */ - decodeUTF8(bytes: Uint8Array): string; + /** Decode the provided bytes into a string using the provided encoding. */ + decode(bytes: Uint8Array, encoding: string): string; } diff --git a/src/platforms/platform_browser.ts b/src/platforms/platform_browser.ts index 41968373a0..757a2e488a 100644 --- a/src/platforms/platform_browser.ts +++ b/src/platforms/platform_browser.ts @@ -19,19 +19,18 @@ import {Platform} from './platform'; export class PlatformBrowser implements Platform { private textEncoder: TextEncoder; - private textDecoder: TextDecoder; constructor() { - // The built-in encoder and the decoder use UTF-8 encoding. + // According to the spec, the built-in encoder can do only UTF-8 encoding. + // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder this.textEncoder = new TextEncoder(); - this.textDecoder = new TextDecoder(); } encodeUTF8(text: string): Uint8Array { return this.textEncoder.encode(text); } - decodeUTF8(bytes: Uint8Array): string { - return this.textDecoder.decode(bytes); + decode(bytes: Uint8Array, encoding: string): string { + return new TextDecoder(encoding).decode(bytes); } fetch(path: string, init?: RequestInit): Promise { return fetch(path, init); diff --git a/src/platforms/platform_browser_test.ts b/src/platforms/platform_browser_test.ts index 74549fe260..1e36e3d3ed 100644 --- a/src/platforms/platform_browser_test.ts +++ b/src/platforms/platform_browser_test.ts @@ -53,25 +53,29 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); }); - it('decodeUTF8 single string', () => { + it('decode single string', () => { const platform = new PlatformBrowser(); - const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + const s = + platform.decode(new Uint8Array([104, 101, 108, 108, 111]), 'utf-8'); expect(s.length).toBe(5); expect(s).toEqual('hello'); }); - it('decodeUTF8 two strings delimited', () => { + it('decode two strings delimited', () => { const platform = new PlatformBrowser(); - const s = platform.decodeUTF8( - new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + const s = platform.decode( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100]), + 'utf-8'); expect(s.length).toBe(11); expect(s).toEqual('hello\x00world'); }); - it('decodeUTF8 cyrillic', () => { + it('decode cyrillic', () => { const platform = new PlatformBrowser(); - const s = platform.decodeUTF8(new Uint8Array( - [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + const s = platform.decode( + new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190]), + 'utf-8'); expect(s.length).toBe(6); expect(s).toEqual('Здраво'); }); diff --git a/src/platforms/platform_node.ts b/src/platforms/platform_node.ts index e027ec8e77..ed93ae7463 100644 --- a/src/platforms/platform_node.ts +++ b/src/platforms/platform_node.ts @@ -28,24 +28,25 @@ export let systemFetch: (url: string, init?: RequestInit) => Promise; export class PlatformNode implements Platform { private textEncoder: TextEncoder; - private textDecoder: TextDecoder; + // tslint:disable-next-line:no-any + util: any; constructor() { - // tslint:disable-next-line: no-require-imports - const util = require('util'); - // The built-in encoder and the decoder use UTF-8 encoding. - this.textEncoder = new util.TextEncoder(); - this.textDecoder = new util.TextDecoder(); + // tslint:disable-next-line:no-require-imports + this.util = require('util'); + // According to the spec, the built-in encoder can do only UTF-8 encoding. + // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder + this.textEncoder = new this.util.TextEncoder(); } encodeUTF8(text: string): Uint8Array { return this.textEncoder.encode(text); } - decodeUTF8(bytes: Uint8Array): string { + decode(bytes: Uint8Array, encoding: string): string { if (bytes.length === 0) { return ''; } - return this.textDecoder.decode(bytes); + return new this.util.TextDecoder(encoding).decode(bytes); } fetch(path: string, requestInits?: RequestInit): Promise { if (ENV.global.fetch != null) { diff --git a/src/platforms/platform_node_test.ts b/src/platforms/platform_node_test.ts index bb9f1a438e..c9ee17a82c 100644 --- a/src/platforms/platform_node_test.ts +++ b/src/platforms/platform_node_test.ts @@ -91,25 +91,29 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); }); - it('decodeUTF8 single string', () => { + it('decode single string', () => { const platform = new PlatformNode(); - const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + const s = + platform.decode(new Uint8Array([104, 101, 108, 108, 111]), 'utf8'); expect(s.length).toBe(5); expect(s).toEqual('hello'); }); - it('decodeUTF8 two strings delimited', () => { + it('decode two strings delimited', () => { const platform = new PlatformNode(); - const s = platform.decodeUTF8( - new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + const s = platform.decode( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100]), + 'utf8'); expect(s.length).toBe(11); expect(s).toEqual('hello\x00world'); }); - it('decodeUTF8 cyrillic', () => { + it('decode cyrillic', () => { const platform = new PlatformNode(); - const s = platform.decodeUTF8(new Uint8Array( - [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + const s = platform.decode( + new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190]), + 'utf8'); expect(s.length).toBe(6); expect(s).toEqual('Здраво'); }); diff --git a/src/tensor.ts b/src/tensor.ts index e5b5b7c63a..472f04c403 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -15,7 +15,6 @@ * ============================================================================= */ -import {ENV} from './environment'; import {tensorToString} from './tensor_format'; import {ArrayMap, BackendDataValues, DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; import * as util from './util'; @@ -439,6 +438,8 @@ export class Tensor { readonly dtype: DataType; /** The rank type for the array (see `Rank` enum). */ readonly rankType: R; + /** The encoding used to encode strings. Defined only for string tensors. */ + readonly encoding: string; /** Whether this tensor has been globally kept. */ kept = false; @@ -453,8 +454,8 @@ export class Tensor { readonly strides: number[]; protected constructor( - shape: ShapeMap[R], dtype: DataType, values?: DataValues|Uint8Array[], - dataId?: DataId, backend?: Backend) { + shape: ShapeMap[R], dtype: DataType, values?: BackendDataValues, + dataId?: DataId, backend?: Backend, encoding?: string) { this.shape = shape.slice() as ShapeMap[R]; this.dtype = dtype || 'float32'; this.size = util.sizeFromShape(shape); @@ -462,13 +463,10 @@ export class Tensor { this.dataId = dataId != null ? dataId : {}; this.id = trackerFn().nextTensorId(); this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher') as R; + this.encoding = encoding || 'utf-8'; trackerFn().registerTensor(this, backend); if (values != null) { - if (dtype === 'string' && util.isString(values) || - util.isString(values[0])) { - values = (values as string[]).map(v => ENV.platform.encodeUTF8(v)); - } - trackerFn().write(backend, this.dataId, values as BackendDataValues); + trackerFn().write(backend, this.dataId, values); } } @@ -478,9 +476,15 @@ export class Tensor { */ static make, D extends DataType = 'float32', R extends Rank = Rank>( - shape: ShapeMap[R], data: TensorData, dtype?: D, - backend?: Backend): T { - return new Tensor(shape, dtype, data.values, data.dataId, backend) as T; + shape: ShapeMap[R], data: TensorData, dtype?: D, backend?: Backend, + encoding?: string): T { + let backendVals = data.values as BackendDataValues; + if (data.values != null && dtype === 'string' && + util.isString(data.values[0])) { + backendVals = util.encodeStrings(data.values as string[]); + } + return new Tensor( + shape, dtype, backendVals, data.dataId, backend, encoding) as T; } /** Flatten a Tensor to a 1D array. */ @@ -618,7 +622,7 @@ export class Tensor { const data = trackerFn().read(this.dataId); if (this.dtype === 'string') { const bytes = await data as Uint8Array[]; - return bytes.map(d => ENV.platform.decodeUTF8(d)); + return util.decodeStrings(bytes, this.encoding); } return data as Promise; } @@ -632,7 +636,7 @@ export class Tensor { this.throwIfDisposed(); const data = trackerFn().readSync(this.dataId); if (this.dtype === 'string') { - return (data as Uint8Array[]).map(d => ENV.platform.decodeUTF8(d)); + return util.decodeStrings(data as Uint8Array[], this.encoding); } return data as DataTypeMap[D]; } diff --git a/src/tensor_test.ts b/src/tensor_test.ts index 3a19831875..e11e3878c1 100644 --- a/src/tensor_test.ts +++ b/src/tensor_test.ts @@ -20,6 +20,7 @@ import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor'; import {expectArraysClose, expectArraysEqual, expectNumbersClose} from './test_util'; import {Rank} from './types'; +import {encodeStrings} from './util'; describeWithFlags('tensor', ALL_ENVS, () => { it('Tensors of arbitrary size', async () => { @@ -620,6 +621,14 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['даниел']); }); + it('default dtype from Uint8Array[]', async () => { + const bytes = encodeStrings(['a', 'b', 'c']); + const a = tf.tensor(bytes); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([3]); + expectArraysEqual(await a.data(), ['a', 'b', 'c']); + }); + it('default dtype from empty string', async () => { const a = tf.tensor(''); expect(a.dtype).toBe('string'); diff --git a/src/util.ts b/src/util.ts index 2aff5a82db..3e781c7a5c 100644 --- a/src/util.ts +++ b/src/util.ts @@ -692,3 +692,21 @@ export function fetch( path: string, requestInits?: RequestInit): Promise { return ENV.platform.fetch(path, requestInits); } + +export function encodeString(s: string): Uint8Array { + return ENV.platform.encodeUTF8(s); +} + +export function encodeStrings( + strings: string[]): Uint8Array[] { + return strings.map(s => encodeString(s)); +} + +export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { + return ENV.platform.decode(bytes, encoding); +} + +export function decodeStrings( + bytes: Uint8Array[], encoding = 'utf-8'): string[] { + return bytes.map(b => decodeString(b, encoding)); +} diff --git a/src/util_test.ts b/src/util_test.ts index 8b4d23be95..90fcf2e7bf 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -114,15 +114,14 @@ describe('util.flatten', () => { }); }); -function encode(data: string[]): Uint8Array[] { - return data.map(d => ENV.platform.encodeUTF8(d)); -} - describe('util.bytesFromStringArray', () => { it('count bytes after utf8 encoding', () => { - expect(util.bytesFromStringArray(encode(['a', 'bb', 'ccc']))).toBe(6); - expect(util.bytesFromStringArray(encode(['a', 'bb', 'cccddd']))).toBe(9); - expect(util.bytesFromStringArray(encode(['даниел']))).toBe(6 * 2); + expect(util.bytesFromStringArray(util.encodeStrings(['a', 'bb', 'ccc']))) + .toBe(6); + expect(util.bytesFromStringArray(util.encodeStrings(['a', 'bb', 'cccddd']))) + .toBe(9); + expect(util.bytesFromStringArray(util.encodeStrings(['даниел']))) + .toBe(6 * 2); }); }); From 26fadb5591cf733829fb9049115ec8cb221f21f5 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 22 Jun 2019 14:36:32 -0400 Subject: [PATCH 03/12] save --- src/tensor.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensor.ts b/src/tensor.ts index 472f04c403..8226e42a9c 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -16,7 +16,7 @@ */ import {tensorToString} from './tensor_format'; -import {ArrayMap, BackendDataValues, DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; +import {ArrayMap, BackendDataValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; import * as util from './util'; import {computeStrides, toNestedArray} from './util'; From db7f5b7423f15301d720f5b976f5e408becacec5 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Mon, 24 Jun 2019 15:09:52 -0400 Subject: [PATCH 04/12] save --- src/tensor_test.ts | 3 ++- src/util.ts | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tensor_test.ts b/src/tensor_test.ts index e11e3878c1..48e709304c 100644 --- a/src/tensor_test.ts +++ b/src/tensor_test.ts @@ -621,7 +621,8 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['даниел']); }); - it('default dtype from Uint8Array[]', async () => { + // tslint:disable-next-line: ban + fit('default dtype from Uint8Array[]', async () => { const bytes = encodeStrings(['a', 'b', 'c']); const a = tf.tensor(bytes); expect(a.dtype).toBe('string'); diff --git a/src/util.ts b/src/util.ts index 3e781c7a5c..3730681a1c 100644 --- a/src/util.ts +++ b/src/util.ts @@ -697,8 +697,7 @@ export function encodeString(s: string): Uint8Array { return ENV.platform.encodeUTF8(s); } -export function encodeStrings( - strings: string[]): Uint8Array[] { +export function encodeStrings(strings: string[]): Uint8Array[] { return strings.map(s => encodeString(s)); } From 26c9579b810904bf291e8f6c3cb465c4a2e1bc41 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Mon, 24 Jun 2019 18:33:40 -0400 Subject: [PATCH 05/12] save --- src/backends/cpu/backend_cpu.ts | 4 +-- src/backends/cpu/backend_cpu_test.ts | 10 +++++++- src/backends/webgl/backend_webgl.ts | 2 +- src/backends/webgl/backend_webgl_test.ts | 10 +++++++- src/io/io_utils.ts | 4 +-- src/io/io_utils_test.ts | 9 ++++--- src/ops/tensor_ops.ts | 6 ++--- src/platforms/platform.ts | 7 ++++-- src/platforms/platform_browser.ts | 6 ++++- src/platforms/platform_browser_test.ts | 6 ++--- src/platforms/platform_node.ts | 6 ++++- src/platforms/platform_node_test.ts | 6 ++--- src/tensor.ts | 31 +++++++++++++++--------- src/tensor_test.ts | 19 +++++++-------- src/types.ts | 24 +++++++++--------- src/util.ts | 15 +++--------- src/util_test.ts | 11 ++++++--- 17 files changed, 103 insertions(+), 73 deletions(-) diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 81ba3d5c62..798de03762 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -36,7 +36,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendDataValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; -import {decodeStrings, getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; +import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; @@ -206,7 +206,7 @@ export class MathBackendCPU implements KernelBackend { let decodedData = data as DataValues; if (t.dtype === 'string') { // Decode the bytes into string. - decodedData = decodeStrings(data as Uint8Array[], t.encoding); + decodedData = (data as Uint8Array[]).map(d => util.decodeString(d)); } return buffer(t.shape, t.dtype, decodedData) as TensorBuffer; } diff --git a/src/backends/cpu/backend_cpu_test.ts b/src/backends/cpu/backend_cpu_test.ts index 77e48fc5d4..85324a915b 100644 --- a/src/backends/cpu/backend_cpu_test.ts +++ b/src/backends/cpu/backend_cpu_test.ts @@ -19,11 +19,19 @@ import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {tensor2d} from '../../ops/ops'; import {expectArraysClose, expectArraysEqual} from '../../test_util'; -import {decodeStrings, encodeStrings} from '../../util'; +import {decodeString, encodeString} from '../../util'; import {MathBackendCPU} from './backend_cpu'; import {CPU_ENVS} from './backend_cpu_test_registry'; +function encodeStrings(a: string[]): Uint8Array[] { + return a.map(s => encodeString(s)); +} + +function decodeStrings(bytes: Uint8Array[]): string[] { + return bytes.map(b => decodeString(b)); +} + describeWithFlags('backendCPU', CPU_ENVS, () => { let backend: MathBackendCPU; beforeEach(() => { diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 76d1bca189..00ef5cae23 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -953,7 +953,7 @@ export class MathBackendWebGL implements KernelBackend { tile(x: T, reps: number[]): T { if (x.dtype === 'string') { const data = this.readSync(x.dataId) as Uint8Array[]; - const decodedData = util.decodeStrings(data, x.encoding); + const decodedData = data.map(d => util.decodeString(d)); const buf = buffer(x.shape, x.dtype, decodedData); return tile(buf, reps) as T; } diff --git a/src/backends/webgl/backend_webgl_test.ts b/src/backends/webgl/backend_webgl_test.ts index 95e33be99e..5d5bf06842 100644 --- a/src/backends/webgl/backend_webgl_test.ts +++ b/src/backends/webgl/backend_webgl_test.ts @@ -18,11 +18,19 @@ import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../../test_util'; -import {decodeStrings, encodeStrings} from '../../util'; +import {decodeString, encodeString} from '../../util'; import {MathBackendWebGL, WebGLMemoryInfo} from './backend_webgl'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; +function encodeStrings(a: string[]): Uint8Array[] { + return a.map(s => encodeString(s)); +} + +function decodeStrings(bytes: Uint8Array[]): string[] { + return bytes.map(b => decodeString(b)); +} + describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { let webglLazilyUnpackFlagSaved: boolean; let webglCpuForwardFlagSaved: boolean; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 6ede9083bb..40f856248f 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -19,7 +19,7 @@ import {ENV} from '../environment'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; -import {sizeFromShape} from '../util'; +import {encodeString, sizeFromShape} from '../util'; import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, StringWeightsManifestEntry, WeightGroup, WeightsManifestEntry} from './types'; @@ -66,7 +66,7 @@ export async function encodeWeights( const utf8bytes = new Promise(async resolve => { const stringSpec = spec as StringWeightsManifestEntry; const data = await t.data(); - const bytes = ENV.platform.encodeUTF8(data.join(STRING_DELIMITER)); + const bytes = encodeString(data.join(STRING_DELIMITER)); stringSpec.byteLength = bytes.length; stringSpec.delimiter = STRING_DELIMITER; resolve(bytes); diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 0a56f48bd2..60dff7db88 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -351,20 +351,21 @@ describe('encodeWeights', () => { x5ByteLength); let delim = specs[0].delimiter; expect(new Uint8Array(data, 0, x1ByteLength)) - .toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`)); + .toEqual( + tf.ENV.platform.encode(`a${delim}bc${delim}def${delim}g`, 'utf-8')); // The middle string takes up 0 bytes. delim = specs[2].delimiter; expect(new Uint8Array(data, x1ByteLength + x2ByteLength, x3ByteLength)) - .toEqual(tf.ENV.platform.encodeUTF8(`здраво${delim}поздрав`)); + .toEqual(tf.ENV.platform.encode(`здраво${delim}поздрав`, 'utf-8')); delim = specs[3].delimiter; expect(new Uint8Array( data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength)) - .toEqual(tf.ENV.platform.encodeUTF8('正常')); + .toEqual(tf.ENV.platform.encode('正常', 'utf-8')); delim = specs[4].delimiter; expect(new Uint8Array( data, x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength, x5ByteLength)) - .toEqual(tf.ENV.platform.encodeUTF8('hello')); + .toEqual(tf.ENV.platform.encode('hello', 'utf-8')); expect(specs).toEqual([ { name: 'x1', diff --git a/src/ops/tensor_ops.ts b/src/ops/tensor_ops.ts index e31c36a8b2..54d4027e40 100644 --- a/src/ops/tensor_ops.ts +++ b/src/ops/tensor_ops.ts @@ -52,8 +52,7 @@ import {op} from './operation'; */ /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor( - values: TensorLike, shape?: ShapeMap[R], dtype?: DataType, - encoding?: string): Tensor { + values: TensorLike, shape?: ShapeMap[R], dtype?: DataType): Tensor { if (dtype == null) { dtype = inferDtype(values); } @@ -102,8 +101,7 @@ function tensor( values = dtype !== 'string' ? toTypedArray(values, dtype, ENV.getBool('DEBUG')) : flatten(values as string[]) as string[]; - return Tensor.make( - shape, {values: values as TypedArray}, dtype, null, encoding); + return Tensor.make(shape, {values: values as TypedArray}, dtype); } /** diff --git a/src/platforms/platform.ts b/src/platforms/platform.ts index a9da7c1d20..98aaa396d9 100644 --- a/src/platforms/platform.ts +++ b/src/platforms/platform.ts @@ -29,8 +29,11 @@ export interface Platform { */ fetch(path: string, requestInits?: RequestInit): Promise; - /** UTF-8 encode the provided string into an array of bytes. */ - encodeUTF8(text: string): Uint8Array; + /** + * Encode the provided string into an array of bytes using the provided + * encoding. + */ + encode(text: string, encoding: string): Uint8Array; /** Decode the provided bytes into a string using the provided encoding. */ decode(bytes: Uint8Array, encoding: string): string; } diff --git a/src/platforms/platform_browser.ts b/src/platforms/platform_browser.ts index 757a2e488a..224e0c87d7 100644 --- a/src/platforms/platform_browser.ts +++ b/src/platforms/platform_browser.ts @@ -26,7 +26,11 @@ export class PlatformBrowser implements Platform { this.textEncoder = new TextEncoder(); } - encodeUTF8(text: string): Uint8Array { + encode(text: string, encoding: string): Uint8Array { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error( + `Browser's encoder only supports utf-8, but got ${encoding}`); + } return this.textEncoder.encode(text); } decode(bytes: Uint8Array, encoding: string): string { diff --git a/src/platforms/platform_browser_test.ts b/src/platforms/platform_browser_test.ts index 1e36e3d3ed..54976f7ef4 100644 --- a/src/platforms/platform_browser_test.ts +++ b/src/platforms/platform_browser_test.ts @@ -32,14 +32,14 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { it('encodeUTF8 single string', () => { const platform = new PlatformBrowser(); - const bytes = platform.encodeUTF8('hello'); + const bytes = platform.encode('hello', 'utf-8'); expect(bytes.length).toBe(5); expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); }); it('encodeUTF8 two strings delimited', () => { const platform = new PlatformBrowser(); - const bytes = platform.encodeUTF8('hello\x00world'); + const bytes = platform.encode('hello\x00world', 'utf-8'); expect(bytes.length).toBe(11); expect(bytes).toEqual( new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); @@ -47,7 +47,7 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { it('encodeUTF8 cyrillic', () => { const platform = new PlatformBrowser(); - const bytes = platform.encodeUTF8('Здраво'); + const bytes = platform.encode('Здраво', 'utf-8'); expect(bytes.length).toBe(12); expect(bytes).toEqual(new Uint8Array( [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); diff --git a/src/platforms/platform_node.ts b/src/platforms/platform_node.ts index ed93ae7463..07af5f8401 100644 --- a/src/platforms/platform_node.ts +++ b/src/platforms/platform_node.ts @@ -39,7 +39,11 @@ export class PlatformNode implements Platform { this.textEncoder = new this.util.TextEncoder(); } - encodeUTF8(text: string): Uint8Array { + encode(text: string, encoding: string): Uint8Array { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error( + `Node built-in encoder only supports utf-8, but got ${encoding}`); + } return this.textEncoder.encode(text); } decode(bytes: Uint8Array, encoding: string): string { diff --git a/src/platforms/platform_node_test.ts b/src/platforms/platform_node_test.ts index c9ee17a82c..5bc6825213 100644 --- a/src/platforms/platform_node_test.ts +++ b/src/platforms/platform_node_test.ts @@ -70,14 +70,14 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { it('encodeUTF8 single string', () => { const platform = new PlatformNode(); - const bytes = platform.encodeUTF8('hello'); + const bytes = platform.encode('hello', 'utf-8'); expect(bytes.length).toBe(5); expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); }); it('encodeUTF8 two strings delimited', () => { const platform = new PlatformNode(); - const bytes = platform.encodeUTF8('hello\x00world'); + const bytes = platform.encode('hello\x00world', 'utf-8'); expect(bytes.length).toBe(11); expect(bytes).toEqual( new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); @@ -85,7 +85,7 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { it('encodeUTF8 cyrillic', () => { const platform = new PlatformNode(); - const bytes = platform.encodeUTF8('Здраво'); + const bytes = platform.encode('Здраво', 'utf-8'); expect(bytes.length).toBe(12); expect(bytes).toEqual(new Uint8Array( [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); diff --git a/src/tensor.ts b/src/tensor.ts index 8226e42a9c..6e093029e0 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -16,7 +16,7 @@ */ import {tensorToString} from './tensor_format'; -import {ArrayMap, BackendDataValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D} from './types'; +import {ArrayMap, BackendDataValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D, TypedArray} from './types'; import * as util from './util'; import {computeStrides, toNestedArray} from './util'; @@ -438,8 +438,6 @@ export class Tensor { readonly dtype: DataType; /** The rank type for the array (see `Rank` enum). */ readonly rankType: R; - /** The encoding used to encode strings. Defined only for string tensors. */ - readonly encoding: string; /** Whether this tensor has been globally kept. */ kept = false; @@ -455,7 +453,7 @@ export class Tensor { protected constructor( shape: ShapeMap[R], dtype: DataType, values?: BackendDataValues, - dataId?: DataId, backend?: Backend, encoding?: string) { + dataId?: DataId, backend?: Backend) { this.shape = shape.slice() as ShapeMap[R]; this.dtype = dtype || 'float32'; this.size = util.sizeFromShape(shape); @@ -463,7 +461,6 @@ export class Tensor { this.dataId = dataId != null ? dataId : {}; this.id = trackerFn().nextTensorId(); this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher') as R; - this.encoding = encoding || 'utf-8'; trackerFn().registerTensor(this, backend); if (values != null) { trackerFn().write(backend, this.dataId, values); @@ -476,15 +473,14 @@ export class Tensor { */ static make, D extends DataType = 'float32', R extends Rank = Rank>( - shape: ShapeMap[R], data: TensorData, dtype?: D, backend?: Backend, - encoding?: string): T { + shape: ShapeMap[R], data: TensorData, dtype?: D, + backend?: Backend): T { let backendVals = data.values as BackendDataValues; if (data.values != null && dtype === 'string' && util.isString(data.values[0])) { - backendVals = util.encodeStrings(data.values as string[]); + backendVals = (data.values as string[]).map(d => util.encodeString(d)); } - return new Tensor( - shape, dtype, backendVals, data.dataId, backend, encoding) as T; + return new Tensor(shape, dtype, backendVals, data.dataId, backend) as T; } /** Flatten a Tensor to a 1D array. */ @@ -622,7 +618,7 @@ export class Tensor { const data = trackerFn().read(this.dataId); if (this.dtype === 'string') { const bytes = await data as Uint8Array[]; - return util.decodeStrings(bytes, this.encoding); + return bytes.map(b => util.decodeString(b)); } return data as Promise; } @@ -636,11 +632,22 @@ export class Tensor { this.throwIfDisposed(); const data = trackerFn().readSync(this.dataId); if (this.dtype === 'string') { - return util.decodeStrings(data as Uint8Array[], this.encoding); + return (data as Uint8Array[]).map(b => util.decodeString(b)); } return data as DataTypeMap[D]; } + /** Returns the underlying bytes of the tensor's data. */ + async bytes(): Promise { + this.throwIfDisposed(); + const data = await trackerFn().read(this.dataId); + if (this.dtype === 'string') { + return data as Uint8Array[]; + } else { + return new Uint8Array((data as TypedArray).buffer); + } + } + /** * Disposes `tf.Tensor` from memory. */ diff --git a/src/tensor_test.ts b/src/tensor_test.ts index 48e709304c..71a33d6d80 100644 --- a/src/tensor_test.ts +++ b/src/tensor_test.ts @@ -20,7 +20,6 @@ import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor'; import {expectArraysClose, expectArraysEqual, expectNumbersClose} from './test_util'; import {Rank} from './types'; -import {encodeStrings} from './util'; describeWithFlags('tensor', ALL_ENVS, () => { it('Tensors of arbitrary size', async () => { @@ -267,6 +266,15 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['aa', 'bb', 'cc']); }); + // tslint:disable-next-line: ban + fit('tf.tensor1d() from encoded string', async () => { + const bytes = ['aa', 'bb', 'cc'].map(s => tf.util.encodeString(s)); + const a = tf.tensor1d(bytes, 'string'); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([3]); + expectArraysEqual(await a.data(), ['aa', 'bb', 'cc']); + }); + it('tf.tensor1d() from number[][], shape mismatch', () => { // tslint:disable-next-line:no-any expect(() => tf.tensor1d([[1], [2], [3]] as any)).toThrowError(); @@ -621,15 +629,6 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['даниел']); }); - // tslint:disable-next-line: ban - fit('default dtype from Uint8Array[]', async () => { - const bytes = encodeStrings(['a', 'b', 'c']); - const a = tf.tensor(bytes); - expect(a.dtype).toBe('string'); - expect(a.shape).toEqual([3]); - expectArraysEqual(await a.data(), ['a', 'b', 'c']); - }); - it('default dtype from empty string', async () => { const a = tf.tensor(''); expect(a.dtype).toBe('string'); diff --git a/src/types.ts b/src/types.ts index c6fb3641f3..6c48f8ead2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -134,25 +134,27 @@ export function sumOutType(type: DataType): DataType { /** @docalias TypedArray|Array */ export type TensorLike = TypedArray|number|boolean|string|RecursiveArray| - RecursiveArray|RecursiveArray; -export type ScalarLike = number|boolean|string; + RecursiveArray|RecursiveArray|Uint8Array[]; +export type ScalarLike = number|boolean|string|Uint8Array; /** @docalias TypedArray|Array */ -export type TensorLike1D = TypedArray|number[]|boolean[]|string[]; +export type TensorLike1D = TypedArray|number[]|boolean[]|string[]|Uint8Array[]; /** @docalias TypedArray|Array */ -export type TensorLike2D = - TypedArray|number[]|number[][]|boolean[]|boolean[][]|string[]|string[][]; +export type TensorLike2D = TypedArray|number[]|number[][]|boolean[]|boolean[][]| + string[]|string[][]|Uint8Array[]|Uint8Array[][]; /** @docalias TypedArray|Array */ export type TensorLike3D = TypedArray|number[]|number[][][]|boolean[]| - boolean[][][]|string[]|string[][][]; + boolean[][][]|string[]|string[][][]|Uint8Array[]|Uint8Array[][][]; /** @docalias TypedArray|Array */ export type TensorLike4D = TypedArray|number[]|number[][][][]|boolean[]| - boolean[][][][]|string[]|string[][][][]; + boolean[][][][]|string[]|string[][][][]|Uint8Array[]|Uint8Array[][][][]; /** @docalias TypedArray|Array */ -export type TensorLike5D = TypedArray|number[]|number[][][][][]|boolean[]| - boolean[][][][][]|string[]|string[][][][][]; +export type TensorLike5D = + TypedArray|number[]|number[][][][][]|boolean[]|boolean[][][][][]|string[]| + string[][][][][]|Uint8Array[]|Uint8Array[][][][][]; /** @docalias TypedArray|Array */ -export type TensorLike6D = TypedArray|number[]|number[][][][][][]|boolean[]| - boolean[][][][][][]|string[]|string[][][][][][]; +export type TensorLike6D = + TypedArray|number[]|number[][][][][][]|boolean[]|boolean[][][][][][]| + string[]|string[][][][][][]|Uint8Array[]|Uint8Array[][][][][]; /** Type for representing image dat in Uint8Array type. */ export interface PixelData { diff --git a/src/util.ts b/src/util.ts index 3730681a1c..d0390ed033 100644 --- a/src/util.ts +++ b/src/util.ts @@ -693,19 +693,12 @@ export function fetch( return ENV.platform.fetch(path, requestInits); } -export function encodeString(s: string): Uint8Array { - return ENV.platform.encodeUTF8(s); -} - -export function encodeStrings(strings: string[]): Uint8Array[] { - return strings.map(s => encodeString(s)); +export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { + encoding = encoding || 'utf-8'; + return ENV.platform.encode(s, encoding); } export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { + encoding = encoding || 'utf-8'; return ENV.platform.decode(bytes, encoding); } - -export function decodeStrings( - bytes: Uint8Array[], encoding = 'utf-8'): string[] { - return bytes.map(b => decodeString(b, encoding)); -} diff --git a/src/util_test.ts b/src/util_test.ts index 90fcf2e7bf..573acb879a 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -114,14 +114,17 @@ describe('util.flatten', () => { }); }); +function encodeStrings(a: string[]): Uint8Array[] { + return a.map(s => util.encodeString(s)); +} + describe('util.bytesFromStringArray', () => { it('count bytes after utf8 encoding', () => { - expect(util.bytesFromStringArray(util.encodeStrings(['a', 'bb', 'ccc']))) + expect(util.bytesFromStringArray(encodeStrings(['a', 'bb', 'ccc']))) .toBe(6); - expect(util.bytesFromStringArray(util.encodeStrings(['a', 'bb', 'cccddd']))) + expect(util.bytesFromStringArray(encodeStrings(['a', 'bb', 'cccddd']))) .toBe(9); - expect(util.bytesFromStringArray(util.encodeStrings(['даниел']))) - .toBe(6 * 2); + expect(util.bytesFromStringArray(encodeStrings(['даниел']))).toBe(6 * 2); }); }); From d3f4cce5d2a062ff31870bd6994acd2694e2f811 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 25 Jun 2019 09:53:59 -0400 Subject: [PATCH 06/12] save --- src/ops/tensor_ops.ts | 9 +++++++-- src/tensor_util_env.ts | 2 +- src/util.ts | 6 ++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/ops/tensor_ops.ts b/src/ops/tensor_ops.ts index 54d4027e40..325d132295 100644 --- a/src/ops/tensor_ops.ts +++ b/src/ops/tensor_ops.ts @@ -100,7 +100,7 @@ function tensor( shape = shape || inferredShape; values = dtype !== 'string' ? toTypedArray(values, dtype, ENV.getBool('DEBUG')) : - flatten(values as string[]) as string[]; + flatten(values as string[], [], true) as string[]; return Tensor.make(shape, {values: values as TypedArray}, dtype); } @@ -145,9 +145,14 @@ function scalar(value: number|boolean|string, dtype?: DataType): Scalar { function tensor1d(values: TensorLike1D, dtype?: DataType): Tensor1D { assertNonNull(values); const inferredShape = inferShape(values); - if (inferredShape.length !== 1) { + if (dtype !== 'string' && inferredShape.length !== 1) { throw new Error('tensor1d() requires values to be a flat/TypedArray'); } + if (dtype === 'string' && inferredShape.length !== 2) { + throw new Error( + 'tensor1d() of dtype string requires values to be ' + + 'string[] or Uint8Array[]'); + } return tensor(values, inferredShape as [number], dtype); } diff --git a/src/tensor_util_env.ts b/src/tensor_util_env.ts index ad5cb6b881..7ea293b2d3 100644 --- a/src/tensor_util_env.ts +++ b/src/tensor_util_env.ts @@ -110,7 +110,7 @@ export function convertToTensor( } const values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype as DataType, ENV.getBool('DEBUG')) : - flatten(x as string[]) as string[]; + flatten(x as string[], [], true) as string[]; return Tensor.make(inferredShape, {values}, inferredDtype); } diff --git a/src/util.ts b/src/util.ts index d0390ed033..f8defc0d9e 100644 --- a/src/util.ts +++ b/src/util.ts @@ -135,11 +135,13 @@ export function assertNonNull(a: TensorLike): void { * * @param arr The nested array to flatten. * @param result The destination array which holds the elements. + * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults + * to false. */ /** @doc {heading: 'Util', namespace: 'util'} */ export function flatten|TypedArray>( - arr: T|RecursiveArray, result: T[] = []): T[] { + arr: T|RecursiveArray, result: T[] = [], skipTypedArray = false): T[] { if (result == null) { result = []; } @@ -551,7 +553,7 @@ export function toTypedArray( throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { - a = flatten(a as number[]); + a = flatten(a); } if (debugMode) { checkConversionForErrors(a as number[], dtype); From 2238f122887d0ec85c4cfcbd6073551a43180738 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 25 Jun 2019 12:59:15 -0400 Subject: [PATCH 07/12] save --- src/ops/tensor_ops.ts | 33 +++++---- src/tensor_test.ts | 152 +++++++++++++++++++++++++++++++++++++++-- src/tensor_util_env.ts | 7 +- src/util.ts | 4 +- src/util_test.ts | 32 +++++++++ 5 files changed, 204 insertions(+), 24 deletions(-) diff --git a/src/ops/tensor_ops.ts b/src/ops/tensor_ops.ts index 93c3e25dcd..0ed49eac5b 100644 --- a/src/ops/tensor_ops.ts +++ b/src/ops/tensor_ops.ts @@ -53,7 +53,7 @@ import {op} from './operation'; /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor( values: TensorLike, shape?: ShapeMap[R], dtype?: DataType): Tensor { - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); return makeTensor(values, shape, inferredShape, dtype) as Tensor; } @@ -125,12 +125,20 @@ function makeTensor( * @param dtype The data type. */ /** @doc {heading: 'Tensors', subheading: 'Creation'} */ -function scalar(value: number|boolean|string, dtype?: DataType): Scalar { - if ((isTypedArray(value) || Array.isArray(value)) && dtype !== 'complex64') { +function scalar( + value: number|boolean|string|Uint8Array, dtype?: DataType): Scalar { + if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && + dtype !== 'complex64') { throw new Error( 'Error creating a new Scalar: value must be a primitive ' + '(number|boolean|string)'); } + if (dtype === 'string' && isTypedArray(value) && + !(value instanceof Uint8Array)) { + throw new Error( + 'When making a scalar from encoded string, ' + + 'the value must be Uint8Array'); + } const shape: number[] = []; const inferredShape: number[] = []; return makeTensor(value, shape, inferredShape, dtype) as Scalar; @@ -153,15 +161,10 @@ function scalar(value: number|boolean|string, dtype?: DataType): Scalar { /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor1d(values: TensorLike1D, dtype?: DataType): Tensor1D { assertNonNull(values); - const inferredShape = inferShape(values); - if (dtype !== 'string' && inferredShape.length !== 1) { + const inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 1) { throw new Error('tensor1d() requires values to be a flat/TypedArray'); } - if (dtype === 'string' && inferredShape.length !== 2) { - throw new Error( - 'tensor1d() of dtype string requires values to be ' + - 'string[] or Uint8Array[]'); - } const shape: number[] = null; return makeTensor(values, shape, inferredShape, dtype) as Tensor1D; } @@ -195,7 +198,7 @@ function tensor2d( if (shape != null && shape.length !== 2) { throw new Error('tensor2d() requires shape to have two numbers'); } - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 2 && inferredShape.length !== 1) { throw new Error( 'tensor2d() requires values to be number[][] or flat/TypedArray'); @@ -237,7 +240,7 @@ function tensor3d( if (shape != null && shape.length !== 3) { throw new Error('tensor3d() requires shape to have three numbers'); } - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 3 && inferredShape.length !== 1) { throw new Error( 'tensor3d() requires values to be number[][][] or flat/TypedArray'); @@ -279,7 +282,7 @@ function tensor4d( if (shape != null && shape.length !== 4) { throw new Error('tensor4d() requires shape to have four numbers'); } - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 4 && inferredShape.length !== 1) { throw new Error( 'tensor4d() requires values to be number[][][][] or flat/TypedArray'); @@ -321,7 +324,7 @@ function tensor5d( if (shape != null && shape.length !== 5) { throw new Error('tensor5d() requires shape to have five numbers'); } - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 5 && inferredShape.length !== 1) { throw new Error( 'tensor5d() requires values to be ' + @@ -365,7 +368,7 @@ function tensor6d( if (shape != null && shape.length !== 6) { throw new Error('tensor6d() requires shape to have six numbers'); } - const inferredShape = inferShape(values); + const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 6 && inferredShape.length !== 1) { throw new Error( 'tensor6d() requires values to be number[][][][][][] or ' + diff --git a/src/tensor_test.ts b/src/tensor_test.ts index 44b6cdbb0e..8f3edd8137 100644 --- a/src/tensor_test.ts +++ b/src/tensor_test.ts @@ -19,7 +19,21 @@ import * as tf from './index'; import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor'; import {expectArraysClose, expectArraysEqual, expectNumbersClose} from './test_util'; -import {Rank} from './types'; +import {Rank, RecursiveArray, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TypedArray} from './types'; +import {encodeString} from './util'; + +/** Util method used for tests. It encodes each string into utf-8 bytes. */ +function encodeStrings(a: RecursiveArray<{}>): RecursiveArray { + for (let i = 0; i < (a as Array<{}>).length; i++) { + const val = a[i]; + if (Array.isArray(val)) { + encodeStrings(val); + } else { + a[i] = encodeString(val as string); + } + } + return a; +} describeWithFlags('tensor', ALL_ENVS, () => { it('Tensors of arbitrary size', async () => { @@ -266,15 +280,26 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['aa', 'bb', 'cc']); }); - // tslint:disable-next-line: ban - fit('tf.tensor1d() from encoded string', async () => { - const bytes = ['aa', 'bb', 'cc'].map(s => tf.util.encodeString(s)); + it('tf.tensor1d() from encoded strings', async () => { + const bytes = encodeStrings(['aa', 'bb', 'cc']) as TensorLike1D; const a = tf.tensor1d(bytes, 'string'); expect(a.dtype).toBe('string'); expect(a.shape).toEqual([3]); expectArraysEqual(await a.data(), ['aa', 'bb', 'cc']); }); + it('tf.tensor1d() from encoded strings without dtype errors', async () => { + // We do not want to infer 'string' when the user passes Uint8Array in order + // to be forward compatible in the future when we add uint8 dtype. + const bytes = encodeStrings(['aa', 'bb', 'cc']) as TensorLike1D; + expect(() => tf.tensor1d(bytes)).toThrowError(); + }); + + it('tf.tensor1d() from encoded strings, shape mismatch', () => { + const bytes = encodeStrings([['aa'], ['bb'], ['cc']]) as TensorLike1D; + expect(() => tf.tensor1d(bytes)).toThrowError(); + }); + it('tf.tensor1d() from number[][], shape mismatch', () => { // tslint:disable-next-line:no-any expect(() => tf.tensor1d([[1], [2], [3]] as any)).toThrowError(); @@ -297,6 +322,26 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['aa', 'bb', 'cc', 'dd']); }); + it('tf.tensor2d() from encoded strings', async () => { + const bytes = encodeStrings([['aa', 'bb'], ['cc', 'dd']]) as TensorLike2D; + const a = tf.tensor2d(bytes, [2, 2], 'string'); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([2, 2]); + expectArraysEqual(await a.data(), ['aa', 'bb', 'cc', 'dd']); + }); + + it('tf.tensor2d() from encoded strings without dtype errors', async () => { + // We do not want to infer 'string' when the user passes Uint8Array in order + // to be forward compatible in the future when we add uint8 dtype. + const bytes = encodeStrings([['aa', 'bb'], ['cc', 'dd']]) as TensorLike2D; + expect(() => tf.tensor2d(bytes)).toThrowError(); + }); + + it('tf.tensor2d() from encoded strings, shape mismatch', () => { + const bytes = encodeStrings([['aa', 'bb'], ['cc', 'dd']]) as TensorLike2D; + expect(() => tf.tensor2d(bytes, [3, 2], 'string')).toThrowError(); + }); + it('tf.tensor2d() requires shape to be of length 2', () => { // tslint:disable-next-line:no-any const shape: any = [4]; @@ -342,6 +387,28 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['a', 'b', 'c', 'd', 'e', 'f']); }); + it('tf.tensor3d() from encoded strings', async () => { + const bytes = encodeStrings([[['a'], ['b'], ['c']], [['d'], ['e'], ['f']]]); + const a = tf.tensor3d(bytes as TensorLike3D, [2, 3, 1], 'string'); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([2, 3, 1]); + expectArraysEqual(await a.data(), ['a', 'b', 'c', 'd', 'e', 'f']); + }); + + it('tf.tensor3d() from encoded strings without dtype errors', async () => { + // We do not want to infer 'string' when the user passes Uint8Array in order + // to be forward compatible in the future when we add uint8 dtype. + const bytes = encodeStrings([[['a'], ['b'], ['c']], [['d'], ['e'], ['f']]]); + expect(() => tf.tensor3d(bytes as TensorLike3D)).toThrowError(); + }); + + it('tf.tensor3d() from encoded strings, shape mismatch', () => { + const bytes = encodeStrings([[['a'], ['b'], ['c']], [['d'], ['e'], ['f']]]); + // Actual shape is [2, 3, 1]. + expect(() => tf.tensor3d(bytes as TensorLike3D, [3, 2, 1], 'string')) + .toThrowError(); + }); + it('tensor3d() from number[][][], but shape does not match', () => { const values = [[[1], [2], [3]], [[4], [5], [6]]]; // Actual shape is [2, 3, 1]. @@ -378,6 +445,28 @@ describeWithFlags('tensor', ALL_ENVS, () => { expectArraysEqual(await a.data(), ['a', 'b', 'c', 'd']); }); + it('tf.tensor4d() from encoded strings', async () => { + const bytes = encodeStrings([[[['a']], [['b']]], [[['c']], [['d']]]]); + const a = tf.tensor4d(bytes as TensorLike4D, [2, 2, 1, 1], 'string'); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([2, 2, 1, 1]); + expectArraysEqual(await a.data(), ['a', 'b', 'c', 'd']); + }); + + it('tf.tensor4d() from encoded strings without dtype errors', async () => { + // We do not want to infer 'string' when the user passes Uint8Array in order + // to be forward compatible in the future when we add uint8 dtype. + const bytes = encodeStrings([[[['a']], [['b']]], [[['c']], [['d']]]]); + expect(() => tf.tensor4d(bytes as TensorLike4D)).toThrowError(); + }); + + it('tf.tensor4d() from encoded strings, shape mismatch', () => { + const bytes = encodeStrings([[[['a']], [['b']]], [[['c']], [['d']]]]); + // Actual shape is [2, 2, 1. 1]. + expect(() => tf.tensor4d(bytes as TensorLike4D, [2, 1, 2, 1], 'string')) + .toThrowError(); + }); + it('tensor4d() from string[][][][] infer shape', async () => { const vals = [[[['a']], [['b']]], [[['c']], [['d']]]]; const a = tf.tensor4d(vals); @@ -884,6 +973,24 @@ describeWithFlags('tensor', ALL_ENVS, () => { expect(b.shape).toEqual([1, 1]); }); + it('scalar from encoded string', async () => { + const a = tf.scalar(encodeString('hello'), 'string'); + expect(a.dtype).toBe('string'); + expect(a.shape).toEqual([]); + expectArraysEqual(await a.data(), ['hello']); + }); + + it('scalar from encoded string, but missing dtype', async () => { + // We do not want to infer 'string' when the user passes Uint8Array in order + // to be forward compatible in the future when we add uint8 dtype. + expect(() => tf.scalar(encodeString('hello'))).toThrowError(); + }); + + it('scalar from encoded string, but value is not uint8array', async () => { + // tslint:disable-next-line:no-any + expect(() => tf.scalar(new Float32Array([1, 2, 3]) as any)).toThrowError(); + }); + it('Scalar inferred dtype from bool', async () => { const a = tf.scalar(true); expect(a.dtype).toBe('bool'); @@ -2153,3 +2260,40 @@ describeWithFlags('tensor with 0 in shape', ALL_ENVS, () => { expectArraysEqual(await a.data(), []); }); }); + +describeWithFlags('tensor.bytes()', ALL_ENVS, () => { + /** Helper method to get the bytes from a typed array. */ + function getBytes(a: TypedArray): Uint8Array { + return new Uint8Array(a.buffer); + } + + it('float32 tensor', async () => { + const a = tf.tensor([1.1, 3.2, 7], [3], 'float32'); + expect(await a.bytes()).toEqual(getBytes(new Float32Array([1.1, 3.2, 7]))); + }); + + it('int32 tensor', async () => { + const a = tf.tensor([1.1, 3.2, 7], [3], 'int32'); + expect(await a.bytes()).toEqual(getBytes(new Int32Array([1, 3, 7]))); + }); + + it('bool tensor', async () => { + const a = tf.tensor([true, true, false], [3], 'bool'); + expect(await a.bytes()).toEqual(new Uint8Array([1, 1, 0])); + }); + + it('string tensor from native strings', async () => { + const a = tf.tensor(['hello', 'world'], [2], 'string'); + expect(await a.bytes()).toEqual([ + encodeString('hello'), encodeString('world') + ]); + }); + + it('string tensor from encoded bytes', async () => { + const a = tf.tensor( + [encodeString('hello'), encodeString('world')], [2], 'string'); + expect(await a.bytes()).toEqual([ + encodeString('hello'), encodeString('world') + ]); + }); +}); diff --git a/src/tensor_util_env.ts b/src/tensor_util_env.ts index 7ea293b2d3..8414272206 100644 --- a/src/tensor_util_env.ts +++ b/src/tensor_util_env.ts @@ -20,7 +20,7 @@ import {Tensor} from './tensor'; import {DataType, TensorLike, TypedArray} from './types'; import {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util'; -export function inferShape(val: TensorLike): number[] { +export function inferShape(val: TensorLike, dtype?: DataType): number[] { let firstElem: typeof val = val; if (isTypedArray(val)) { @@ -31,7 +31,8 @@ export function inferShape(val: TensorLike): number[] { } const shape: number[] = []; - while (Array.isArray(firstElem) || isTypedArray(firstElem)) { + while (Array.isArray(firstElem) || + isTypedArray(firstElem) && dtype !== 'string') { shape.push(firstElem.length); firstElem = firstElem[0]; } @@ -104,7 +105,7 @@ export function convertToTensor( `Argument '${argName}' passed to '${functionName}' must be a ` + `Tensor or TensorLike, but got '${type}'`); } - const inferredShape = inferShape(x); + const inferredShape = inferShape(x, inferredDtype); if (!isTypedArray(x) && !Array.isArray(x)) { x = [x] as number[]; } diff --git a/src/util.ts b/src/util.ts index f8defc0d9e..569655224e 100644 --- a/src/util.ts +++ b/src/util.ts @@ -145,9 +145,9 @@ flatten|TypedArray>( if (result == null) { result = []; } - if (Array.isArray(arr) || isTypedArray(arr)) { + if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { for (let i = 0; i < arr.length; ++i) { - flatten(arr[i], result); + flatten(arr[i], result, skipTypedArray); } } else { result.push(arr as T); diff --git a/src/util_test.ts b/src/util_test.ts index 573acb879a..d51ea1ab9e 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -87,6 +87,27 @@ describe('Util', () => { const a = new Float32Array([1, 2, 3, 4, 5]); expect(inferShape(a)).toEqual([5]); }); + + it('infer shape of Uint8Array[], string tensor', () => { + const a = [new Uint8Array([1, 2]), new Uint8Array([3, 4])]; + expect(inferShape(a, 'string')).toEqual([2]); + }); + + it('infer shape of Uint8Array[][], string tensor', () => { + const a = [ + [new Uint8Array([1]), new Uint8Array([2])], + [new Uint8Array([1]), new Uint8Array([2])] + ]; + expect(inferShape(a, 'string')).toEqual([2, 2]); + }); + + it('infer shape of Uint8Array[][][], string tensor', () => { + const a = [ + [[new Uint8Array([1, 2])], [new Uint8Array([2, 1])]], + [[new Uint8Array([1, 2])], [new Uint8Array([2, 1])]] + ]; + expect(inferShape(a, 'string')).toEqual([2, 2, 1]); + }); }); describe('util.flatten', () => { @@ -112,6 +133,17 @@ describe('util.flatten', () => { [new Float32Array([1, 2]), 3, [4, 5, new Float32Array([6, 7])]]; expect(util.flatten(data)).toEqual([1, 2, 3, 4, 5, 6, 7]); }); + + it('nested Uint8Arrays, skipTypedArray=true', () => { + const data = [ + [new Uint8Array([1, 2]), new Uint8Array([3, 4])], + [new Uint8Array([5, 6]), new Uint8Array([7, 8])] + ]; + expect(util.flatten(data, [], true)).toEqual([ + new Uint8Array([1, 2]), new Uint8Array([3, 4]), new Uint8Array([5, 6]), + new Uint8Array([7, 8]) + ]); + }); }); function encodeStrings(a: string[]): Uint8Array[] { From 703ff5ef893e52b11a60dd956a7ed911b8d7bfad Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 25 Jun 2019 17:38:41 -0400 Subject: [PATCH 08/12] save --- src/backends/backend.ts | 14 ++++---- src/backends/cpu/backend_cpu.ts | 18 ++++++---- src/backends/cpu/backend_cpu_test.ts | 2 ++ src/backends/webgl/backend_webgl.ts | 8 ++--- src/backends/webgl/tex_util.ts | 4 +-- src/engine.ts | 8 ++--- src/tensor.ts | 18 +++++----- src/types.ts | 4 ++- src/util.ts | 13 +++++++ src/util_test.ts | 52 ++++++++++++++++++++++++++++ 10 files changed, 107 insertions(+), 34 deletions(-) diff --git a/src/backends/backend.ts b/src/backends/backend.ts index 0819d0fcbd..434d0e5fd8 100644 --- a/src/backends/backend.ts +++ b/src/backends/backend.ts @@ -18,7 +18,7 @@ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; import {Activation} from '../ops/fused_util'; import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; -import {BackendDataValues, DataType, PixelData, Rank, ShapeMap} from '../types'; +import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types'; export const EPSILON_FLOAT32 = 1e-7; export const EPSILON_FLOAT16 = 1e-4; @@ -31,10 +31,10 @@ export interface BackendTimingInfo { } export interface TensorStorage { - read(dataId: DataId): Promise; - readSync(dataId: DataId): BackendDataValues; + read(dataId: DataId): Promise; + readSync(dataId: DataId): BackendValues; disposeData(dataId: DataId): void; - write(dataId: DataId, values: BackendDataValues): void; + write(dataId: DataId, values: BackendValues): void; fromPixels( pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| HTMLVideoElement, @@ -92,16 +92,16 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { time(f: () => void): Promise { throw new Error('Not yet implemented.'); } - read(dataId: object): Promise { + read(dataId: object): Promise { throw new Error('Not yet implemented.'); } - readSync(dataId: object): BackendDataValues { + readSync(dataId: object): BackendValues { throw new Error('Not yet implemented.'); } disposeData(dataId: object): void { throw new Error('Not yet implemented.'); } - write(dataId: object, values: BackendDataValues): void { + write(dataId: object, values: BackendValues): void { throw new Error('Not yet implemented.'); } fromPixels( diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 798de03762..4f341395b2 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -34,7 +34,7 @@ import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; -import {BackendDataValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; +import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; @@ -58,7 +58,7 @@ function mapActivation( } interface TensorData { - values?: BackendDataValues; + values?: BackendValues; dtype: D; // For complex numbers, the real and imaginary parts are stored as their own // individual tensors, with a parent joining the two with the @@ -116,7 +116,7 @@ export class MathBackendCPU implements KernelBackend { } this.data.set(dataId, {dtype}); } - write(dataId: DataId, values: BackendDataValues): void { + write(dataId: DataId, values: BackendValues): void { if (values == null) { throw new Error('MathBackendCPU.write(): values can not be null'); } @@ -186,10 +186,10 @@ export class MathBackendCPU implements KernelBackend { [pixels.height, pixels.width, numChannels]; return tensor3d(values, outShape, 'int32'); } - async read(dataId: DataId): Promise { + async read(dataId: DataId): Promise { return this.readSync(dataId); } - readSync(dataId: DataId): BackendDataValues { + readSync(dataId: DataId): BackendValues { const {dtype, complexTensors} = this.data.get(dataId); if (dtype === 'complex64') { const realValues = @@ -205,8 +205,12 @@ export class MathBackendCPU implements KernelBackend { const data = this.readSync(t.dataId); let decodedData = data as DataValues; if (t.dtype === 'string') { - // Decode the bytes into string. - decodedData = (data as Uint8Array[]).map(d => util.decodeString(d)); + try { + // Decode the bytes into string. + decodedData = (data as Uint8Array[]).map(d => util.decodeString(d)); + } catch (e) { + throw new Error('Failed to decode encoded string bytes into utf-8'); + } } return buffer(t.shape, t.dtype, decodedData) as TensorBuffer; } diff --git a/src/backends/cpu/backend_cpu_test.ts b/src/backends/cpu/backend_cpu_test.ts index 85324a915b..e3f9579eb3 100644 --- a/src/backends/cpu/backend_cpu_test.ts +++ b/src/backends/cpu/backend_cpu_test.ts @@ -24,10 +24,12 @@ import {decodeString, encodeString} from '../../util'; import {MathBackendCPU} from './backend_cpu'; import {CPU_ENVS} from './backend_cpu_test_registry'; +/** Private test util for encoding array of strings in utf-8. */ function encodeStrings(a: string[]): Uint8Array[] { return a.map(s => encodeString(s)); } +/** Private test util for decoding array of strings in utf-8. */ function decodeStrings(bytes: Uint8Array[]): string[] { return bytes.map(b => decodeString(b)); } diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 79d40ec9c3..bdad8fb662 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -37,7 +37,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; -import {BackendDataValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; +import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util'; import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; @@ -339,7 +339,7 @@ export class MathBackendWebGL implements KernelBackend { return {dataId, shape, dtype}; } - write(dataId: DataId, values: BackendDataValues): void { + write(dataId: DataId, values: BackendValues): void { if (values == null) { throw new Error('MathBackendWebGL.write(): values can not be null'); } @@ -366,7 +366,7 @@ export class MathBackendWebGL implements KernelBackend { texData.values = values; } - readSync(dataId: DataId): BackendDataValues { + readSync(dataId: DataId): BackendValues { const texData = this.texData.get(dataId); const {values, dtype, complexTensors, slice, shape} = texData; if (slice != null) { @@ -403,7 +403,7 @@ export class MathBackendWebGL implements KernelBackend { return this.convertAndCacheOnCPU(dataId, result); } - async read(dataId: DataId): Promise { + async read(dataId: DataId): Promise { if (this.pendingRead.has(dataId)) { const subscribers = this.pendingRead.get(dataId); return new Promise(resolve => subscribers.push(resolve)); diff --git a/src/backends/webgl/tex_util.ts b/src/backends/webgl/tex_util.ts index b8e06f3000..a425685b23 100644 --- a/src/backends/webgl/tex_util.ts +++ b/src/backends/webgl/tex_util.ts @@ -16,7 +16,7 @@ */ import {DataId, Tensor} from '../../tensor'; -import {BackendDataValues, DataType} from '../../types'; +import {BackendValues, DataType} from '../../types'; import * as util from '../../util'; export enum TextureUsage { @@ -40,7 +40,7 @@ export interface TextureData { dtype: DataType; // Optional. - values?: BackendDataValues; + values?: BackendValues; texture?: WebGLTexture; // For complex numbers, the real and imaginary parts are stored as their own // individual tensors, with a parent joining the two with the diff --git a/src/engine.ts b/src/engine.ts index a018e3fbbb..0240b31fb0 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -22,7 +22,7 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode import {DataId, setTensorTracker, Tensor, Tensor3D, TensorTracker, Variable} from './tensor'; import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types'; import {getTensorsInContainer} from './tensor_util'; -import {BackendDataValues, DataType, PixelData} from './types'; +import {BackendValues, DataType, PixelData} from './types'; import * as util from './util'; import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util'; @@ -830,7 +830,7 @@ export class Engine implements TensorManager, TensorTracker, DataMover { } // Forwarding to backend. - write(destBackend: KernelBackend, dataId: DataId, values: BackendDataValues): + write(destBackend: KernelBackend, dataId: DataId, values: BackendValues): void { const info = this.state.tensorInfo.get(dataId); @@ -853,12 +853,12 @@ export class Engine implements TensorManager, TensorTracker, DataMover { } destBackend.write(dataId, values); } - readSync(dataId: DataId): BackendDataValues { + readSync(dataId: DataId): BackendValues { // Route the read to the correct backend. const info = this.state.tensorInfo.get(dataId); return info.backend.readSync(dataId); } - read(dataId: DataId): Promise { + read(dataId: DataId): Promise { // Route the read to the correct backend. const info = this.state.tensorInfo.get(dataId); return info.backend.read(dataId); diff --git a/src/tensor.ts b/src/tensor.ts index 6e093029e0..611833ff34 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -16,7 +16,7 @@ */ import {tensorToString} from './tensor_format'; -import {ArrayMap, BackendDataValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D, TypedArray} from './types'; +import {ArrayMap, BackendValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D, TypedArray} from './types'; import * as util from './util'; import {computeStrides, toNestedArray} from './util'; @@ -28,10 +28,10 @@ export interface TensorData { // This interface mimics KernelBackend (in backend.ts), which would create a // circular dependency if imported. export interface Backend { - read(dataId: object): Promise; - readSync(dataId: object): BackendDataValues; + read(dataId: object): Promise; + readSync(dataId: object): BackendValues; disposeData(dataId: object): void; - write(dataId: object, values: BackendDataValues): void; + write(dataId: object, values: BackendValues): void; } /** @@ -159,9 +159,9 @@ export interface TensorTracker { registerTensor(t: Tensor, backend?: Backend): void; disposeTensor(t: Tensor): void; disposeVariable(v: Variable): void; - write(backend: Backend, dataId: DataId, values: BackendDataValues): void; - read(dataId: DataId): Promise; - readSync(dataId: DataId): BackendDataValues; + write(backend: Backend, dataId: DataId, values: BackendValues): void; + read(dataId: DataId): Promise; + readSync(dataId: DataId): BackendValues; registerVariable(v: Variable): void; nextTensorId(): number; nextVariableId(): number; @@ -452,7 +452,7 @@ export class Tensor { readonly strides: number[]; protected constructor( - shape: ShapeMap[R], dtype: DataType, values?: BackendDataValues, + shape: ShapeMap[R], dtype: DataType, values?: BackendValues, dataId?: DataId, backend?: Backend) { this.shape = shape.slice() as ShapeMap[R]; this.dtype = dtype || 'float32'; @@ -475,7 +475,7 @@ export class Tensor { R extends Rank = Rank>( shape: ShapeMap[R], data: TensorData, dtype?: D, backend?: Backend): T { - let backendVals = data.values as BackendDataValues; + let backendVals = data.values as BackendValues; if (data.values != null && dtype === 'string' && util.isString(data.values[0])) { backendVals = (data.values as string[]).map(d => util.encodeString(d)); diff --git a/src/types.ts b/src/types.ts index 6c48f8ead2..8e0848d839 100644 --- a/src/types.ts +++ b/src/types.ts @@ -57,8 +57,10 @@ export interface SingleValueMap { export type DataType = keyof DataTypeMap; export type NumericDataType = 'float32'|'int32'|'bool'|'complex64'; export type TypedArray = Float32Array|Int32Array|Uint8Array; +/** Tensor data used in tensor creation and user-facing API. */ export type DataValues = DataTypeMap[DataType]; -export type BackendDataValues = Float32Array|Int32Array|Uint8Array|Uint8Array[]; +/** The underlying tensor data that gets stored in a backend. */ +export type BackendValues = Float32Array|Int32Array|Uint8Array|Uint8Array[]; export enum Rank { R0 = 'R0', diff --git a/src/util.ts b/src/util.ts index 569655224e..c018b80928 100644 --- a/src/util.ts +++ b/src/util.ts @@ -695,11 +695,24 @@ export function fetch( return ENV.platform.fetch(path, requestInits); } +/** + * Encodes the provided string into bytes using the provided encoding scheme. + * + * @param s The string to encode. + * @param encoding The encoding scheme. Defaults to utf-8. + * + */ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { encoding = encoding || 'utf-8'; return ENV.platform.encode(s, encoding); } +/** + * Decodes the provided bytes into a string using the provided encoding scheme. + * @param bytes The bytes to decode. + * + * @param encoding The encoding scheme. Defaults to utf-8. + */ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { encoding = encoding || 'utf-8'; return ENV.platform.decode(bytes, encoding); diff --git a/src/util_test.ts b/src/util_test.ts index d51ea1ab9e..01f0394664 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -549,3 +549,55 @@ describe('util.fetch', () => { }); }); }); + +describe('util.encodeString', () => { + it('Encodes an empty string, default encoding', () => { + const res = util.encodeString(''); + expect(res).toEqual(new Uint8Array([])); + }); + + it('Encodes an empty string, utf-8 encoding', () => { + const res = util.encodeString('', 'utf-8'); + expect(res).toEqual(new Uint8Array([])); + }); + + it('Encodes an empty string, encoding must be utf-8', () => { + expect(() => util.encodeString('', 'utf-16')) + .toThrowError(/Browser's encoder only supports utf-8, but got utf-16/); + }); + + it('Encodes a cyrillic string', () => { + const res = util.encodeString('Kaкo стe'); + expect(res).toEqual( + new Uint8Array([75, 97, 208, 186, 111, 32, 209, 129, 209, 130, 101])); + }); + + it('Encodes ascii', () => { + const res = util.encodeString('hello'); + expect(res).toEqual(new Uint8Array([104, 101, 108, 108, 111])); + }); +}); + +describe('util.decodeString', () => { + it('decode an empty string', () => { + const s = util.decodeString(new Uint8Array([])); + expect(s).toEqual(''); + }); + + it('decode ascii', () => { + const s = util.decodeString(new Uint8Array([104, 101, 108, 108, 111])); + expect(s).toEqual('hello'); + }); + + it('decode cyrillic', () => { + const s = util.decodeString( + new Uint8Array([75, 97, 208, 186, 111, 32, 209, 129, 209, 130, 101])); + expect(s).toEqual('Kaкo стe'); + }); + + // tslint:disable-next-line: ban + fit('decode utf-16-be', + () => { + // TODO: implement. + }); +}); From b1ca5ee2d58b75c9c8cf09d51fbd6fa41b0c4a56 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 25 Jun 2019 21:22:52 -0400 Subject: [PATCH 09/12] save --- src/io/io.ts | 3 +- src/io/io_utils.ts | 44 +++++++++------- src/io/io_utils_test.ts | 112 ++++++++++++++++------------------------ src/io/types.ts | 16 ------ src/util_test.ts | 20 +++---- 5 files changed, 82 insertions(+), 113 deletions(-) diff --git a/src/io/io.ts b/src/io/io.ts index 3e6fa1a067..5a44417b9f 100644 --- a/src/io/io.ts +++ b/src/io/io.ts @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http'; import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils'; import {fromMemory, withSaveHandler} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; -import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; @@ -54,7 +54,6 @@ export { SaveConfig, SaveHandler, SaveResult, - StringWeightsManifestEntry, WeightGroup, weightsLoaderFactory, WeightsManifestConfig, diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 40f856248f..bcf16a8b77 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -15,16 +15,12 @@ * ============================================================================= */ -import {ENV} from '../environment'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; -import {encodeString, sizeFromShape} from '../util'; +import {sizeFromShape} from '../util'; -import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, StringWeightsManifestEntry, WeightGroup, WeightsManifestEntry} from './types'; - -/** Used to delimit neighboring strings when encoding string tensors. */ -export const STRING_DELIMITER = '\x00'; +import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types'; /** * Encode a map from names to weight values as an ArrayBuffer, along with an @@ -64,11 +60,20 @@ export async function encodeWeights( const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype}; if (t.dtype === 'string') { const utf8bytes = new Promise(async resolve => { - const stringSpec = spec as StringWeightsManifestEntry; - const data = await t.data(); - const bytes = encodeString(data.join(STRING_DELIMITER)); - stringSpec.byteLength = bytes.length; - stringSpec.delimiter = STRING_DELIMITER; + const vals = await t.bytes() as Uint8Array[]; + const totalNumBytes = + vals.reduce((p, c) => p + c.length, 0) + 4 * vals.length; + const bytes = new Uint8Array(totalNumBytes); + let offset = 0; + for (let i = 0; i < vals.length; i++) { + const val = vals[i]; + const bytesOfLength = + new Uint8Array(new Int32Array([val.length]).buffer); + bytes.set(bytesOfLength, offset); + offset += 4; + bytes.set(val, offset); + offset += val.length; + } resolve(bytes); }); dataPromises.push(utf8bytes); @@ -110,7 +115,7 @@ export function decodeWeights( const dtype = spec.dtype; const shape = spec.shape; const size = sizeFromShape(shape); - let values: TypedArray|string[]; + let values: TypedArray|string[]|Uint8Array[]; if ('quantization' in spec) { const quantization = spec.quantization; @@ -138,12 +143,15 @@ export function decodeWeights( } offset += size * quantizationSizeFactor; } else if (dtype === 'string') { - const stringSpec = spec as StringWeightsManifestEntry; - const bytes = - new Uint8Array(buffer.slice(offset, offset + stringSpec.byteLength)); - // TODO(smilkov): Use encoding from metadata. - values = ENV.platform.decode(bytes, 'utf-8').split(stringSpec.delimiter); - offset += stringSpec.byteLength; + const size = sizeFromShape(spec.shape); + values = []; + for (let i = 0; i < size; i++) { + const byteLength = new Int32Array(buffer.slice(offset, offset + 4))[0]; + offset += 4; + const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); + (values as Uint8Array[]).push(bytes); + offset += byteLength; + } } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 60dff7db88..3f6a57b6bf 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -21,8 +21,8 @@ import {scalar, tensor1d, tensor2d} from '../ops/ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {expectArraysEqual} from '../test_util'; import {expectArraysClose} from '../test_util'; - -import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, STRING_DELIMITER, stringByteLength} from './io_utils'; +import {encodeString} from '../util'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils'; import {WeightsManifestEntry} from './types'; describe('concatenateTypedArrays', () => { @@ -339,69 +339,54 @@ describe('encodeWeights', () => { }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; - const specs = dataAndSpecs.specs as tf.io.StringWeightsManifestEntry[]; - const x1ByteLength = 7 + 3; // 7 ascii chars + 3 delimiters. - const x2ByteLength = 0; // No chars. - const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter. - const x4ByteLength = 6; // 2 chinese letters. - const x5ByteLength = 5; // 5 ascii chars. + const specs = dataAndSpecs.specs; + const x1ByteLength = 7 + 4 * 4; // 7 ascii chars + 4 ints. + const x2ByteLength = 4; // No chars + 1 int. + const x3ByteLength = 13 * 2 + 2 * 4; // 13 cyrillic letters + 2 ints. + const x4ByteLength = 6 + 1 * 4; // 2 east asian letters + 1 int. + const x5ByteLength = 5 + 1 * 4; // 5 ascii chars + 1 int. expect(data.byteLength) .toEqual( x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength + x5ByteLength); - let delim = specs[0].delimiter; - expect(new Uint8Array(data, 0, x1ByteLength)) - .toEqual( - tf.ENV.platform.encode(`a${delim}bc${delim}def${delim}g`, 'utf-8')); - // The middle string takes up 0 bytes. - delim = specs[2].delimiter; - expect(new Uint8Array(data, x1ByteLength + x2ByteLength, x3ByteLength)) - .toEqual(tf.ENV.platform.encode(`здраво${delim}поздрав`, 'utf-8')); - delim = specs[3].delimiter; - expect(new Uint8Array( - data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength)) - .toEqual(tf.ENV.platform.encode('正常', 'utf-8')); - delim = specs[4].delimiter; - expect(new Uint8Array( - data, x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength, - x5ByteLength)) - .toEqual(tf.ENV.platform.encode('hello', 'utf-8')); + // x1 'a'. + expect(new Int32Array(data, 0, 1)[0]).toBe(1); + expect(new Uint8Array(data, 4, 1)).toEqual(encodeString('a')); + // x1 'bc'. + expect(new Int32Array(data.slice(5, 9))[0]).toBe(2); + expect(new Uint8Array(data, 9, 2)).toEqual(encodeString('bc')); + // x1 'def'. + expect(new Int32Array(data.slice(11, 15))[0]).toBe(3); + expect(new Uint8Array(data, 15, 3)).toEqual(encodeString('def')); + // x1 'g'. + expect(new Int32Array(data.slice(18, 22))[0]).toBe(1); + expect(new Uint8Array(data, 22, 1)).toEqual(encodeString('g')); + + // x2 is empty string. + expect(new Int32Array(data.slice(23, 27))[0]).toBe(0); + + // x3 'здраво'. + expect(new Int32Array(data.slice(27, 31))[0]).toBe(12); + expect(new Uint8Array(data, 31, 12)).toEqual(encodeString('здраво')); + + // x3 'поздрав'. + expect(new Int32Array(data.slice(43, 47))[0]).toBe(14); + expect(new Uint8Array(data, 47, 14)).toEqual(encodeString('поздрав')); + + // x4 '正常'. + expect(new Int32Array(data.slice(61, 65))[0]).toBe(6); + expect(new Uint8Array(data, 65, 6)).toEqual(encodeString('正常')); + + // x5 'hello'. + expect(new Int32Array(data.slice(71, 75))[0]).toBe(5); + expect(new Uint8Array(data, 75, 5)).toEqual(encodeString('hello')); + expect(specs).toEqual([ - { - name: 'x1', - dtype: 'string', - shape: [2, 2], - byteLength: x1ByteLength, - delimiter: STRING_DELIMITER, - }, - { - name: 'x2', - dtype: 'string', - shape: [], - byteLength: x2ByteLength, - delimiter: STRING_DELIMITER, - }, - { - name: 'x3', - dtype: 'string', - shape: [2], - byteLength: x3ByteLength, - delimiter: STRING_DELIMITER, - }, - { - name: 'x4', - dtype: 'string', - shape: [], - byteLength: x4ByteLength, - delimiter: STRING_DELIMITER, - }, - { - name: 'x5', - dtype: 'string', - shape: [], - byteLength: x5ByteLength, - delimiter: STRING_DELIMITER, - } + {name: 'x1', dtype: 'string', shape: [2, 2]}, + {name: 'x2', dtype: 'string', shape: []}, + {name: 'x3', dtype: 'string', shape: [2]}, + {name: 'x4', dtype: 'string', shape: []}, + {name: 'x5', dtype: 'string', shape: []} ]); }); @@ -454,13 +439,6 @@ describeWithFlags('decodeWeights', {}, () => { const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; - // 12 bytes from cyrillic (6 letters) + 3 bytes from ascii + 3 delimiters. - const x4Bytes = 12 + 3 + 3; - const x5Bytes = 0; - // 5 bytes from ascii. - const x6Bytes = 5; - expect(data.byteLength) - .toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + x6Bytes + 4 * 3); const decoded = tf.io.decodeWeights(data, specs); expect(Object.keys(decoded).length).toEqual(7); expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); diff --git a/src/io/types.ts b/src/io/types.ts index fde040fcd2..275308dd54 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -107,22 +107,6 @@ export declare interface WeightsManifestEntry { }; } -export declare interface StringWeightsManifestEntry extends - WeightsManifestEntry { - dtype: 'string'; - /** - * Used for delimiting neighboring strings. If the tensor has no strings or - * only 1 string, there will be no delimiter. If the tensor has N strings - * (N>0), there will be N-1 delimiters. - */ - delimiter: string; - /** - * Number of bytes used by the whole tensor, including the delimiters (N-1 - * delimiters for N strings). - */ - byteLength: number; -} - /** * Options for saving a model. * @innamespace io diff --git a/src/util_test.ts b/src/util_test.ts index 01f0394664..6ebf64b609 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -551,28 +551,28 @@ describe('util.fetch', () => { }); describe('util.encodeString', () => { - it('Encodes an empty string, default encoding', () => { + it('Encode an empty string, default encoding', () => { const res = util.encodeString(''); expect(res).toEqual(new Uint8Array([])); }); - it('Encodes an empty string, utf-8 encoding', () => { + it('Encode an empty string, utf-8 encoding', () => { const res = util.encodeString('', 'utf-8'); expect(res).toEqual(new Uint8Array([])); }); - it('Encodes an empty string, encoding must be utf-8', () => { + it('Encode an empty string, encoding must be utf-8', () => { expect(() => util.encodeString('', 'utf-16')) .toThrowError(/Browser's encoder only supports utf-8, but got utf-16/); }); - it('Encodes a cyrillic string', () => { + it('Encode cyrillic letters', () => { const res = util.encodeString('Kaкo стe'); expect(res).toEqual( new Uint8Array([75, 97, 208, 186, 111, 32, 209, 129, 209, 130, 101])); }); - it('Encodes ascii', () => { + it('Encode ascii letters', () => { const res = util.encodeString('hello'); expect(res).toEqual(new Uint8Array([104, 101, 108, 108, 111])); }); @@ -595,9 +595,9 @@ describe('util.decodeString', () => { expect(s).toEqual('Kaкo стe'); }); - // tslint:disable-next-line: ban - fit('decode utf-16-be', - () => { - // TODO: implement. - }); + it('decode utf-16', () => { + const s = util.decodeString( + new Uint8Array([255, 254, 237, 139, 0, 138, 4, 89, 6, 116]), 'utf-16'); + expect(s).toEqual('语言处理'); + }); }); From 825adb2870672ebb19ff00f390c4266626fd6b47 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 26 Jun 2019 09:42:48 -0400 Subject: [PATCH 10/12] save --- src/backends/cpu/backend_cpu.ts | 2 +- src/tensor.ts | 16 ++++++++++++++-- src/tensor_test.ts | 2 +- src/tensor_util_env.ts | 2 +- src/util.ts | 2 ++ src/util_test.ts | 2 +- 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 4f341395b2..3bcda167bf 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -208,7 +208,7 @@ export class MathBackendCPU implements KernelBackend { try { // Decode the bytes into string. decodedData = (data as Uint8Array[]).map(d => util.decodeString(d)); - } catch (e) { + } catch { throw new Error('Failed to decode encoded string bytes into utf-8'); } } diff --git a/src/tensor.ts b/src/tensor.ts index 611833ff34..0cd8fdbe57 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -618,7 +618,13 @@ export class Tensor { const data = trackerFn().read(this.dataId); if (this.dtype === 'string') { const bytes = await data as Uint8Array[]; - return bytes.map(b => util.decodeString(b)); + try { + return bytes.map(b => util.decodeString(b)); + } catch { + throw new Error( + 'Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } } return data as Promise; } @@ -632,7 +638,13 @@ export class Tensor { this.throwIfDisposed(); const data = trackerFn().readSync(this.dataId); if (this.dtype === 'string') { - return (data as Uint8Array[]).map(b => util.decodeString(b)); + try { + return (data as Uint8Array[]).map(b => util.decodeString(b)); + } catch { + throw new Error( + 'Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } } return data as DataTypeMap[D]; } diff --git a/src/tensor_test.ts b/src/tensor_test.ts index 8f3edd8137..5f92b68ae7 100644 --- a/src/tensor_test.ts +++ b/src/tensor_test.ts @@ -22,7 +22,7 @@ import {expectArraysClose, expectArraysEqual, expectNumbersClose} from './test_u import {Rank, RecursiveArray, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TypedArray} from './types'; import {encodeString} from './util'; -/** Util method used for tests. It encodes each string into utf-8 bytes. */ +/** Private method used by these tests. Encodes strings into utf-8 bytes. */ function encodeStrings(a: RecursiveArray<{}>): RecursiveArray { for (let i = 0; i < (a as Array<{}>).length; i++) { const val = a[i]; diff --git a/src/tensor_util_env.ts b/src/tensor_util_env.ts index 8414272206..f411418fa8 100644 --- a/src/tensor_util_env.ts +++ b/src/tensor_util_env.ts @@ -24,7 +24,7 @@ export function inferShape(val: TensorLike, dtype?: DataType): number[] { let firstElem: typeof val = val; if (isTypedArray(val)) { - return [(val as TypedArray).length]; + return dtype === 'string' ? [] : [(val as TypedArray).length]; } if (!Array.isArray(val)) { return []; // Scalar. diff --git a/src/util.ts b/src/util.ts index c018b80928..3e91f41093 100644 --- a/src/util.ts +++ b/src/util.ts @@ -702,6 +702,7 @@ export function fetch( * @param encoding The encoding scheme. Defaults to utf-8. * */ +/** @doc {heading: 'Util'} */ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { encoding = encoding || 'utf-8'; return ENV.platform.encode(s, encoding); @@ -713,6 +714,7 @@ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { * * @param encoding The encoding scheme. Defaults to utf-8. */ +/** @doc {heading: 'Util'} */ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { encoding = encoding || 'utf-8'; return ENV.platform.decode(bytes, encoding); diff --git a/src/util_test.ts b/src/util_test.ts index 6ebf64b609..7ac5ab1821 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -563,7 +563,7 @@ describe('util.encodeString', () => { it('Encode an empty string, encoding must be utf-8', () => { expect(() => util.encodeString('', 'utf-16')) - .toThrowError(/Browser's encoder only supports utf-8, but got utf-16/); + .toThrowError(/only supports utf-8, but got utf-16/); }); it('Encode cyrillic letters', () => { From 3af5b724319e28f3b74d5042b92aaaf6a51ee6b9 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 26 Jun 2019 12:08:21 -0400 Subject: [PATCH 11/12] save --- src/io/io_utils.ts | 14 +++++++++----- src/ops/tensor_ops.ts | 5 +++-- src/tensor_util_env.ts | 3 ++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index bcf16a8b77..20e24497ce 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -22,6 +22,9 @@ import {sizeFromShape} from '../util'; import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types'; +/** Number of bytes reserved for the length of the string. (32bit integer). */ +const NUM_BYTES_STRING_LENGTH = 4; + /** * Encode a map from names to weight values as an ArrayBuffer, along with an * `Array` of `WeightsManifestEntry` as specification of the encoded weights. @@ -61,8 +64,8 @@ export async function encodeWeights( if (t.dtype === 'string') { const utf8bytes = new Promise(async resolve => { const vals = await t.bytes() as Uint8Array[]; - const totalNumBytes = - vals.reduce((p, c) => p + c.length, 0) + 4 * vals.length; + const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) + + NUM_BYTES_STRING_LENGTH * vals.length; const bytes = new Uint8Array(totalNumBytes); let offset = 0; for (let i = 0; i < vals.length; i++) { @@ -70,7 +73,7 @@ export async function encodeWeights( const bytesOfLength = new Uint8Array(new Int32Array([val.length]).buffer); bytes.set(bytesOfLength, offset); - offset += 4; + offset += NUM_BYTES_STRING_LENGTH; bytes.set(val, offset); offset += val.length; } @@ -146,8 +149,9 @@ export function decodeWeights( const size = sizeFromShape(spec.shape); values = []; for (let i = 0; i < size; i++) { - const byteLength = new Int32Array(buffer.slice(offset, offset + 4))[0]; - offset += 4; + const byteLength = new Int32Array( + buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + offset += NUM_BYTES_STRING_LENGTH; const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); (values as Uint8Array[]).push(bytes); offset += byteLength; diff --git a/src/ops/tensor_ops.ts b/src/ops/tensor_ops.ts index 0ed49eac5b..71bbb92f07 100644 --- a/src/ops/tensor_ops.ts +++ b/src/ops/tensor_ops.ts @@ -45,7 +45,8 @@ import {op} from './operation'; * ``` * * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. + * or a flat array, or a `TypedArray`. If the values are strings, + * they will be encoded as utf-8 and kept as `Uint8Array[]`. * @param shape The shape of the tensor. Optional. If not provided, * it is inferred from `values`. * @param dtype The data type. @@ -137,7 +138,7 @@ function scalar( !(value instanceof Uint8Array)) { throw new Error( 'When making a scalar from encoded string, ' + - 'the value must be Uint8Array'); + 'the value must be `Uint8Array`.'); } const shape: number[] = []; const inferredShape: number[] = []; diff --git a/src/tensor_util_env.ts b/src/tensor_util_env.ts index f411418fa8..aa3109217f 100644 --- a/src/tensor_util_env.ts +++ b/src/tensor_util_env.ts @@ -109,9 +109,10 @@ export function convertToTensor( if (!isTypedArray(x) && !Array.isArray(x)) { x = [x] as number[]; } + const skipTypedArray = true; const values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype as DataType, ENV.getBool('DEBUG')) : - flatten(x as string[], [], true) as string[]; + flatten(x as string[], [], skipTypedArray) as string[]; return Tensor.make(inferredShape, {values}, inferredDtype); } From 301215fd4065a417bf752975e0296897cc72a636 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 26 Jun 2019 14:00:36 -0400 Subject: [PATCH 12/12] save --- src/io/io_utils.ts | 4 ++-- src/io/io_utils_test.ts | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 20e24497ce..a4a3b26763 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -71,7 +71,7 @@ export async function encodeWeights( for (let i = 0; i < vals.length; i++) { const val = vals[i]; const bytesOfLength = - new Uint8Array(new Int32Array([val.length]).buffer); + new Uint8Array(new Uint32Array([val.length]).buffer); bytes.set(bytesOfLength, offset); offset += NUM_BYTES_STRING_LENGTH; bytes.set(val, offset); @@ -149,7 +149,7 @@ export function decodeWeights( const size = sizeFromShape(spec.shape); values = []; for (let i = 0; i < size; i++) { - const byteLength = new Int32Array( + const byteLength = new Uint32Array( buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 3f6a57b6bf..135eeecb98 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -350,35 +350,35 @@ describe('encodeWeights', () => { x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength + x5ByteLength); // x1 'a'. - expect(new Int32Array(data, 0, 1)[0]).toBe(1); + expect(new Uint32Array(data, 0, 1)[0]).toBe(1); expect(new Uint8Array(data, 4, 1)).toEqual(encodeString('a')); // x1 'bc'. - expect(new Int32Array(data.slice(5, 9))[0]).toBe(2); + expect(new Uint32Array(data.slice(5, 9))[0]).toBe(2); expect(new Uint8Array(data, 9, 2)).toEqual(encodeString('bc')); // x1 'def'. - expect(new Int32Array(data.slice(11, 15))[0]).toBe(3); + expect(new Uint32Array(data.slice(11, 15))[0]).toBe(3); expect(new Uint8Array(data, 15, 3)).toEqual(encodeString('def')); // x1 'g'. - expect(new Int32Array(data.slice(18, 22))[0]).toBe(1); + expect(new Uint32Array(data.slice(18, 22))[0]).toBe(1); expect(new Uint8Array(data, 22, 1)).toEqual(encodeString('g')); // x2 is empty string. - expect(new Int32Array(data.slice(23, 27))[0]).toBe(0); + expect(new Uint32Array(data.slice(23, 27))[0]).toBe(0); // x3 'здраво'. - expect(new Int32Array(data.slice(27, 31))[0]).toBe(12); + expect(new Uint32Array(data.slice(27, 31))[0]).toBe(12); expect(new Uint8Array(data, 31, 12)).toEqual(encodeString('здраво')); // x3 'поздрав'. - expect(new Int32Array(data.slice(43, 47))[0]).toBe(14); + expect(new Uint32Array(data.slice(43, 47))[0]).toBe(14); expect(new Uint8Array(data, 47, 14)).toEqual(encodeString('поздрав')); // x4 '正常'. - expect(new Int32Array(data.slice(61, 65))[0]).toBe(6); + expect(new Uint32Array(data.slice(61, 65))[0]).toBe(6); expect(new Uint8Array(data, 65, 6)).toEqual(encodeString('正常')); // x5 'hello'. - expect(new Int32Array(data.slice(71, 75))[0]).toBe(5); + expect(new Uint32Array(data.slice(71, 75))[0]).toBe(5); expect(new Uint8Array(data, 75, 5)).toEqual(encodeString('hello')); expect(specs).toEqual([