Update Node to 1.1.0 and update to reflect changes from core. by nsthorat · Pull Request #238 · tensorflow/tfjs-node · GitHub
Skip to content
This repository was archived by the owner on Sep 17, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions package.json
11 changes: 4 additions & 7 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import * as tf from '@tensorflow/tfjs';

import {nodeFileSystemRouter} from './io/file_system';
import * as nodeIo from './io/index';
import {nodeHTTPRequestRouter} from './io/node_http';
import {NodeJSKernelBackend} from './nodejs_kernel_backend';
import * as nodeVersion from './version';

Expand All @@ -43,21 +42,19 @@ export * from '@tensorflow/tfjs';
// tslint:disable-next-line:no-require-imports
const pjson = require('../package.json');

tf.ENV.registerBackend('tensorflow', () => {
tf.registerBackend('tensorflow', () => {
return new NodeJSKernelBackend(
bindings('tfjs_binding.node') as TFJSBinding, pjson.name);
}, 3 /* priority */);

// If registration succeeded, set the backend.
if (tf.ENV.findBackend('tensorflow') != null) {
tf.setBackend('tensorflow');
const success = tf.setBackend('tensorflow');
if (!success) {
throw new Error(`Could not initialize TensorFlow backend.`);
}

// Register the model saving and loading handlers for the 'file://' URL scheme.
tf.io.registerLoadRouter(nodeFileSystemRouter);
tf.io.registerSaveRouter(nodeFileSystemRouter);
tf.io.registerLoadRouter(nodeHTTPRequestRouter);
// TODO(cais): Make HTTP-based save work from Node.js.

import {ProgbarLogger} from './callbacks';
// Register the ProgbarLogger for Model.fit() at verbosity level 1.
Expand Down
10 changes: 1 addition & 9 deletions src/io/node_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@

import {io} from '@tensorflow/tfjs-core';

// tslint:disable-next-line:no-require-imports
const fetch = require('node-fetch');

// For testing: Enables jasmine `spyOn()` with `fetch`.
export const fetchWrapper = {fetch};

/**
* Factory function for HTTP IO Handler in Node.js.
*
Expand All @@ -34,9 +28,7 @@ export const fetchWrapper = {fetch};
export function nodeHTTPRequest(
path: string, requestInit?: RequestInit,
weightPathPrefix?: string): io.IOHandler {
return io.browserHTTPRequest(
path as string,
{requestInit, weightPathPrefix, fetchFunc: fetchWrapper.fetch});
return io.browserHTTPRequest(path as string, {requestInit, weightPathPrefix});
}

export const nodeHTTPRequestRouter = (url: string) => {
Expand Down
38 changes: 18 additions & 20 deletions src/io/node_http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import * as tfl from '@tensorflow/tfjs-layers';

import * as tfn from '../index';

import {fetchWrapper} from './node_http';

// We still need node-fetch so that we can mock the core tfc.util.fetch call and
// return a valid response.
// tslint:disable-next-line:no-require-imports
const fetch = require('node-fetch');

Expand Down Expand Up @@ -68,24 +68,22 @@ describe('nodeHTTPRequest-load', () => {
[filename: string]: string|Float32Array|Int32Array|ArrayBuffer|Uint8Array|
Uint16Array
}) => {
spyOn(fetchWrapper, 'fetch')
.and.callFake((path: string, init: RequestInit) => {
return new Promise((resolve, reject) => {
let contentType = '';
if (path.endsWith('model.json')) {
contentType = JSON_TYPE;
} else if (
path.endsWith('weightfile0') || path.endsWith('weightfile1')) {
contentType = OCTET_STREAM_TYPE;
} else {
reject(new Error(`Invalid path: ${path}`));
}
requestInits.push(init);
resolve(new fetch.Response(
fileBufferMap[path],
{'headers': {'Content-Type': contentType}}));
});
});
spyOn(tfc.util, 'fetch').and.callFake((path: string, init: RequestInit) => {
return new Promise((resolve, reject) => {
let contentType = '';
if (path.endsWith('model.json')) {
contentType = JSON_TYPE;
} else if (
path.endsWith('weightfile0') || path.endsWith('weightfile1')) {
contentType = OCTET_STREAM_TYPE;
} else {
reject(new Error(`Invalid path: ${path}`));
}
requestInits.push(init);
resolve(new fetch.Response(
fileBufferMap[path], {'headers': {'Content-Type': contentType}}));
});
});
};

beforeEach(() => {
Expand Down
2 changes: 1 addition & 1 deletion src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ export class NodeJSKernelBackend extends KernelBackend {
return this.executeSingleOutput(name, opAttrs, [input]);
}

floatPrecision(): number {
floatPrecision(): 16|32 {
return 32;
}

Expand Down
2 changes: 1 addition & 1 deletion src/nodejs_kernel_backend_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ describe('conv3d dilations', () => {
it('GPU should handle dilations >1', () => {
// This test can only run locally with CUDA bindings and GPU package
// installed.
if ((tf.ENV.backend as NodeJSKernelBackend).isGPUPackage) {
if ((tf.backend() as NodeJSKernelBackend).isGPUPackage) {
const input = tf.ones([1, 2, 2, 2, 1]) as Tensor5D;
const filter = tf.ones([1, 1, 1, 1, 1]) as Tensor5D;
tf.conv3d(input, filter, 1, 'same', 'NHWC', [2, 2, 2]);
Expand Down
2 changes: 1 addition & 1 deletion src/ops/op_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ let gBackend: NodeJSKernelBackend = null;
/** Returns an instance of the Node.js backend. */
export function nodeBackend(): NodeJSKernelBackend {
if (gBackend === null) {
gBackend = (tfc.ENV.findBackend('tensorflow') as NodeJSKernelBackend);
gBackend = (tfc.findBackend('tensorflow') as NodeJSKernelBackend);
}
return gBackend;
}
Expand Down
2 changes: 1 addition & 1 deletion src/run_tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ process.on('unhandledRejection', e => {
});

jasmine_util.setTestEnvs(
[{name: 'test-tensorflow', factory: () => nodeBackend(), features: {}}]);
[{name: 'test-tensorflow', backendName: 'tensorflow', flags: {}}]);

const IGNORE_LIST: string[] = [
// Always ignore version tests:
Expand Down
2 changes: 1 addition & 1 deletion src/version.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/** @license See the LICENSE file. */

// This code is auto-generated, do not modify this file!
const version = '1.0.3';
const version = '1.1.0';
export {version};
89 changes: 59 additions & 30 deletions yarn.lock