diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 054094d..210671e 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -294,12 +294,24 @@ def init_inference(self, frame=None, **kwargs): graph = finalize_graph(graph_def) output_nodes = get_output_nodes(graph) output_nodes = [on.replace("DLC/", "") for on in output_nodes] - converter = tf.lite.TFLiteConverter.from_frozen_graph( - model_file, - ["Placeholder"], - output_nodes, - input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, - ) + + tf_version_2 = tf.__version__[0] == '2' + + if tf_version_2: + converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( + model_file, + ["Placeholder"], + output_nodes, + input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, + ) + else: + converter = tf.lite.TFLiteConverter.from_frozen_graph( + model_file, + ["Placeholder"], + output_nodes, + input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, + ) + try: tflite_model = converter.convert() except Exception: