
Access the SignatureDef #49

@nestarz


Hello!

Do you know how I can access the SignatureDef, so that I can feed the outputs back in as inputs, as the MoViNet streaming architecture requires? I think I need the real names of the outputs to map them to the inputs.

Here is a related question: https://stackoverflow.com/questions/68180174/outputs-of-tensorflow-lite-model-look-like-statefulpartitionedcalln-what-is-no
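The output names in that question look like `StatefulPartitionedCall:N`. A minimal sketch of how logical keys could be recovered from them, under two assumptions: that `N` is the position of the output's key in the SavedModel signature's output dict after flattening in sorted-key order (the order `tf.nest.flatten` uses for dicts), and that the output keys are known in advance (for the MoViNet stream models, presumably `logits` plus the same state keys as the inputs). `recoverOutputKeys` is a hypothetical helper, not part of `tfjs-tflite-node`.

```js
// Hypothetical helper (not part of tfjs-tflite-node). ASSUMPTION: an output
// tensor named "StatefulPartitionedCall:N" is the N-th entry of the
// signature's output keys in sorted order (how tf.nest.flatten orders dicts).
function recoverOutputKeys(outputNames, signatureOutputKeys) {
  // JS default sort compares UTF-16 code units, which matches Python's
  // sorted() for ASCII key names.
  const sortedKeys = [...signatureOutputKeys].sort();
  return outputNames.map((name) => {
    const match = /:(\d+)$/.exec(name);
    if (!match) throw new Error(`Unexpected output name: ${name}`);
    const n = Number(match[1]);
    if (n >= sortedKeys.length) {
      throw new Error(`Output index ${n} out of range for ${sortedKeys.length} keys`);
    }
    return sortedKeys[n];
  });
}

// Toy example with made-up key names.
console.log(
  recoverOutputKeys(
    ["StatefulPartitionedCall:1", "StatefulPartitionedCall:0", "StatefulPartitionedCall:2"],
    ["state_b", "logits", "state_a"]
  )
);
// → [ 'state_a', 'logits', 'state_b' ]
```

If the sorted-order assumption does not hold for a given model, the recovered keys will be silently wrong, so it is worth checking the result against the output shapes.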

Is there any API similar to `interpreter.get_signature_runner()`?
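In Python, `get_signature_runner()` returns outputs keyed by signature name, and (assuming the MoViNet stream signature follows its TF Lite example) the state outputs use the same keys as the state inputs, so they can be passed straight back in. Without that API, once every input and output has a logical key, the feedback can be wired by key instead of by position. A minimal sketch; `buildFeedbackMap` is a hypothetical helper, and the key arrays are assumed to be in the order returned by `getInputs()` and `getOutputs()`.

```js
// Hypothetical helper: for each input position, return the position of the
// output with the same logical key, or -1 if no output feeds that input
// (e.g. "image", which comes from the video instead).
function buildFeedbackMap(inputKeys, outputKeys) {
  const outputPos = new Map(outputKeys.map((key, i) => [key, i]));
  return inputKeys.map((key) => (outputPos.has(key) ? outputPos.get(key) : -1));
}

// Toy example with made-up key names; "logits" has no matching input.
console.log(buildFeedbackMap(["image", "state_a", "state_b"], ["state_b", "logits", "state_a"]));
// → [ -1, 2, 0 ]
```

In a streaming loop, input `i` would then be filled from output `feedbackMap[i]` whenever that value is not `-1`.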

```js
import { resolve } from "path";
import "@tensorflow/tfjs-backend-cpu";
import * as tf from "@tensorflow/tfjs-core";
import * as tflite from "tfjs-tflite-node";

const model = resolve(
  "static/lite-model_movinet_a1_stream_kinetics-600_classification_tflite_float16_2.tflite"
);

// Return the value of the first entry whose key occurs in `d` (e.g. "float32" -> "float32").
const find = (d, obj) => Object.entries(obj).find(([k]) => d.includes(k))[1];

const tfliteModel = await tflite.loadTFLiteModel(model);

// Input tensors are named like "serving_default_<key>:0"; keep only <key>.
const getName = (name) => name.slice("serving_default_".length, -":0".length);

const inputs = tfliteModel.modelRunner.getInputs().map((input) => ({
  input,
  shape: input.shape,
  quantization: input.quantization,
  dtype: find(input.dataType, { int: "int32", float: "float32", bool: "bool" }),
  name: getName(input.name),
}));

console.log(inputs.map(({ name }) => name));

// Quantize a float tensor for the named input: q = round(x / scale) + zeroPoint.
// Float inputs, "frame_count" and unquantized inputs (scale 0) pass through.
const quantizedScale = (name, state) => {
  const { dtype, quantization: [scale, zeroPoint] = [0, 0] } = inputs.find(
    ({ name: v }) => v === name
  );
  return name.includes("frame_count") || dtype === "float32" || scale === 0
    ? state
    : tf.cast(tf.add(tf.round(tf.div(state, scale)), zeroPoint), dtype);
};

// Initial (zero) states as flat typed arrays, one per input, in input order.
const initStates = inputs.map(({ name, shape, dtype }) =>
  quantizedScale(name, tf.zeros(shape, dtype)).dataSync()
);

// Dummy video [batch, frames, height, width, channels], split into one-frame clips.
const video = tf.ones([1, 5, 172, 172, 3]);
const clips = tf.split(video, video.shape[1], 1);

let states = initStates;
for (const clip of clips) {
  // TypedArray.set() needs array-like data, not a tf.Tensor.
  const frame = quantizedScale("image", clip).dataSync();
  inputs.forEach(({ name, input }, i) =>
    input.data().set(name === "image" ? frame : states[i])
  );
  tfliteModel.modelRunner.infer();
  // Copy each output buffer, since the next infer() may overwrite it.
  // NOTE: this still assumes output i feeds input i, which is exactly what is
  // unknown here (outputs are named "StatefulPartitionedCall:N").
  states = tfliteModel.modelRunner.getOutputs().map((output) => output.data().slice());
}
```
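If the output keys cannot be recovered at all, a common workaround (a heuristic, not something the model guarantees) is to pair each state input with the only output of identical shape and dtype. A sketch, assuming each entry has the `name`, `shape` and `dataType` fields that the runner's tensor info objects expose in the code above; `matchByShape` is hypothetical and refuses to guess when a shape/dtype pair is ambiguous.

```js
// Heuristic (not guaranteed by the model): pair each state input with the
// single output whose dtype and shape are identical. `matchByShape` is a
// hypothetical helper; entries are { name, shape, dataType } objects.
function matchByShape(inputs, outputs, skipNames = ["image"]) {
  const shapeKey = ({ shape, dataType }) => `${dataType}[${Array.from(shape).join(",")}]`;
  // Group output positions by their dtype/shape key.
  const byShape = new Map();
  outputs.forEach((output, i) => {
    const key = shapeKey(output);
    byShape.set(key, [...(byShape.get(key) ?? []), i]);
  });
  return inputs.map((input) => {
    if (skipNames.includes(input.name)) return -1; // not fed back
    const candidates = byShape.get(shapeKey(input)) ?? [];
    if (candidates.length !== 1) {
      throw new Error(`${candidates.length} outputs match input ${input.name}`);
    }
    return candidates[0];
  });
}

// Toy example with made-up names and shapes.
console.log(
  matchByShape(
    [
      { name: "image", shape: [1, 1, 172, 172, 3], dataType: "float32" },
      { name: "state_a", shape: [1, 2], dataType: "float32" },
      { name: "state_b", shape: [1, 3], dataType: "int32" },
    ],
    [
      { shape: [1, 3], dataType: "int32" },
      { shape: [1, 600], dataType: "float32" },
      { shape: [1, 2], dataType: "float32" },
    ]
  )
);
// → [ -1, 2, 0 ]
```

Because several MoViNet states may share a shape, this can fail where a name-based mapping would not; the explicit error is preferable to silently feeding a state into the wrong input.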
