Hello!
Do you know how I can access the SignatureDef, so that I can feed the outputs back in as inputs, as the MoViNet streaming architecture requires? I think I need the real names of the outputs in order to map them to the inputs.
Here is a related issue: https://stackoverflow.com/questions/68180174/outputs-of-tensorflow-lite-model-look-like-statefulpartitionedcalln-what-is-no
Is there any way to access an API like `interpreter.get_signature_runner()`?
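If the SignatureDef names were reachable from JavaScript, pairing outputs with inputs would reduce to a name lookup. In TensorFlow's Python streaming example for MoViNet, the output dictionary (minus `logits`) is passed straight back as the next call's inputs. That suggests each state output shares its signature name with the matching state input. The sketch below works under that assumption. `pairStates`, `nextStates` and the `outputNames` list are hypothetical: the output names would have to come from the model's SignatureDef, not from the runner.

```javascript
// Hypothetical helper: pair each recurrent state input with the output that
// carries the same signature name. `inputNames` are input names with the
// "serving_default_" prefix and ":0" suffix stripped; `outputNames` are ASSUMED
// to come from the model's SignatureDef (not exposed by the runner).
function pairStates(inputNames, outputNames, skip = ["image", "logits"]) {
  const outputIndex = new Map(outputNames.map((name, j) => [name, j]));
  const pairs = []; // entries are [inputPosition, outputPosition]
  const missing = []; // state inputs with no same-named output
  inputNames.forEach((name, i) => {
    if (skip.includes(name)) return;
    const j = outputIndex.get(name);
    if (j === undefined) missing.push(name);
    else pairs.push([i, j]);
  });
  return { pairs, missing };
}

// Build the next states array (aligned with the inputs) from the last outputs.
function nextStates(pairs, outputs, previous) {
  const states = previous.slice();
  for (const [i, j] of pairs) states[i] = outputs[j];
  return states;
}

const { pairs, missing } = pairStates(
  ["image", "state_a", "state_b"],
  ["logits", "state_b", "state_a"]
);
console.log(pairs, missing); // [ [ 1, 2 ], [ 2, 1 ] ] []
```

Any name that lands in `missing` signals that the assumption of matching names does not hold for that state.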
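```javascript
import { resolve } from "path";
import "@tensorflow/tfjs-backend-cpu";
import * as tf from "@tensorflow/tfjs-core";
import * as tflite from "tfjs-tflite-node";

const model = resolve(
  "static/lite-model_movinet_a1_stream_kinetics-600_classification_tflite_float16_2.tflite"
);

// Map the runner's dataType string to a tfjs dtype by substring match.
const find = (d, obj) => Object.entries(obj).find(([k]) => d.includes(k))[1];

// Top-level await: run this file as an ES module (e.g. with a .mjs extension).
const tfliteModel = await tflite.loadTFLiteModel(model);

// "serving_default_image:0" -> "image"
const getName = (name) => name.slice("serving_default_".length, -":0".length);

const inputs = tfliteModel.modelRunner
  .getInputs()
  .map((input) => ({ input }))
  .map(({ input, input: { name, shape, dataType: d, quantization } }) => ({
    input,
    shape,
    quantization,
    dtype: find(d, { int: "int32", float: "float32", bool: "bool" }),
    name: getName(name),
  }));
console.log(inputs.map(({ name }) => name));

// Affine quantization: q = round(x / scale) + zeroPoint.
// Skipped for the frame counter, float inputs and unquantized tensors (scale 0).
const quantizedScale = (name, state) => {
  const { dtype, quantization: [scale, zeroPoint] = [0, 0] } = inputs.find(
    ({ name: v }) => v === name
  );
  return name.includes("frame_count") || dtype === "float32" || scale === 0
    ? state
    : tf.cast(tf.add(tf.round(tf.div(state, scale)), zeroPoint), dtype);
};

// Zero-initialised state for every input, copied out as typed arrays.
const initStates = inputs.map(({ name, shape, dtype }) =>
  quantizedScale(name, tf.zeros(shape, dtype)).dataSync()
);

// Dummy video: 1 batch, 5 frames, 172x172 RGB, fed to the model one frame at a time.
const video = tf.ones([1, 5, 172, 172, 3]);
const clips = tf.split(video, video.shape[1], 1);

let states = initStates;
for (const clip of clips) {
  const frame = quantizedScale("image", clip);
  // TypedArray.set() needs array-like data, so copy tensors out with dataSync().
  inputs.forEach(({ name, input }, i) =>
    name === "image"
      ? input.data().set(frame.dataSync())
      : input.data().set(states[i])
  );
  tfliteModel.modelRunner.infer();
  // Copy each output buffer, since the runner may reuse it on the next infer().
  // NOTE: this still assumes output i feeds input i, which is the open question.
  states = tfliteModel.modelRunner
    .getOutputs()
    .map((output) => output.data().slice());
}
```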
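The `quantizedScale` helper applies TFLite's affine quantization, where `real = scale * (q - zeroPoint)`. The zero-point offset must therefore be an element-wise addition (`tf.add`), not the reduction `tf.sum`.

Below is a minimal plain-JavaScript sketch of the same arithmetic. The helper names `quantize` and `dequantize` are hypothetical. The default clamp to the int8 range [-128, 127] is an assumption: the real range depends on the tensor's dtype, so pass `0, 255` for uint8.

```javascript
// Affine quantization: q = round(real / scale) + zeroPoint, clamped to [qmin, qmax].
// The int8 defaults are an ASSUMPTION; use the target tensor's integer range.
function quantize(values, scale, zeroPoint, qmin = -128, qmax = 127) {
  return Int32Array.from(values, (x) => {
    const q = Math.round(x / scale) + zeroPoint;
    return Math.min(qmax, Math.max(qmin, q));
  });
}

// Inverse mapping: real = scale * (q - zeroPoint).
function dequantize(qs, scale, zeroPoint) {
  return Float32Array.from(qs, (q) => scale * (q - zeroPoint));
}

const q = quantize([0, 0.5, 1], 0.5, 10);
console.log(Array.from(q)); // [ 10, 11, 12 ]
console.log(Array.from(dequantize(q, 0.5, 10))); // [ 0, 0.5, 1 ]
```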