tensorflow 常用代码片段

2023-12-13 03:31:55

tensorflow 常用代码片段

  • 加载pb文件
    def load_graph(frozen_graph_filename):
    with tf.io.gfile.GFile(frozen_graph_filename, “rb”) as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
    return graph

  • 写graph_def 到文件

with tf.gfile.Gfile('simplified.pb', 'wb') as fid:
    fid.write(graph_def.SerializeToString())
  • 获取图的输入输出节点名字
def analyze_inputs_outputs(graph):
    ops = graph.get_operations()
    outputs_set = set(ops)
    inputs = []
    for op in ops:
        if len(op.inputs) == 0 and op.type != 'Const':
            inputs.append(op)
        else:
            for input_tensor in op.inputs:
                if input_tensor.op in outputs_set:
                    outputs_set.remove(input_tensor.op)
    outputs = list(outputs_set)
    return (inputs, outputs)
  • 简化模型,删除训练节点
graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [out_name])
graph_def = tf.graph_util.remove_training_nodes(graph_def)

文章来源:https://blog.csdn.net/lyyiangang/article/details/134882889
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。