분명히 모델 초기화 코드를 제대로 패키징 하고 다른 파이썬 파일에서 import 하면 제대로 작동하는 반면, Flask API 에서 적용하려 하면 에러가 나는 경우가 있다

 

This could mean that the variable was uninitialized.
Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 100, 6), dtype=float32) is not an element of this graph.

내가 직면했던 에러는 이 두 가지 에러였다

 

말 그대로 텐서플로의 세션 또는 그래프가 제대로 초기화되지 않아 발생하는 에러이다

 

하지만 왜 Flask API 에 가져오려하면 이런 에러가 발생하는 것인가?

 

Flask 는 여러 스레드를 사용한다. 이런 에러가 발생하는 이유는 tensorflow 모델이 로드되지 않은 채로 스레드에서 사용되기 때문이다

내가 찾은 해결 방법은 tensorflow 가 graph 와 session 을 global 변수로 사용하도록 하는 방법이다

 

 

flask_API.py

import InferenceClass

class InferenceServer:
    def __init__(self):
        self.tensorflow_obj = InferenceClass(self.model_path)
        self.tensorflow_obj.model_init()

        self.app = self.create_flask_app()

    def create_flask_app(self):
        app = Flask(__name__)

        @app.route('/image', methods=['POST'])
        def mask_rcnn_inference():
			
            ...

            inference_image, inference_time = self.tensorflow_obj.inference(img)
            
            ...


        return app

 

 

InferenceClass.model_init()

    def model_init(self):
        self.model = modellib.MaskRCNN(mode="inference", model_dir=self.MODEL_DIR,
                                  config=self.config)
                                  
        self.model.load_weights(self.weights_path, by_name=True)

  

 

 

InferenceClass.inference()

def inference(self, image):
	...
	
	results = self.model.detect([image], verbose=1)

	...

 

 

위 와 같은 형태의 코드에서 Flask 코드에선 global 변수로 session 과 model 을 넣어주면 되고, Flask 에서 import 하는 코드엔 global 로 선언된 graph 를 기반으로 inference 코드가 돌게 하면 된다

 

 

 

 

 

flask_API.py

global model  # 추가됨
global session  # 추가됨


class InferenceServer:
    def __init__(self):
        global model  # 추가됨
        global session  # 추가됨

        session = tf.Session()  # 추가됨
        keras.backend.set_session(session)  # 추가됨

        model = InferenceClass(self.model_path)  # 클래스 변수가 아닌 global 변수 사용
        model.model_init()  # 클래스 변수가 아닌 global 변수 사용
        self.app = self.create_flask_app()

    def create_flask_app(self):
        app = Flask(__name__)

        @app.route('/image', methods=['POST'])
        def mask_rcnn_inference():
            with session.as_default():  # 추가됨
				
                ...
                
                inference_image, inference_time = model.inference(img)  # 클래스 변수가 아닌 global 변수 사용
                
                ...
        return app

 

 

InferenceClass.model_init()

    def model_init(self):
        global graph  # 추가됨
        self.model = modellib.MaskRCNN(mode="inference", model_dir=self.MODEL_DIR,
                                  config=self.config)
        print(self.weights_path)
        self.model.load_weights(self.weights_path, by_name=True)

        graph = tf.get_default_graph()  # 추가됨

 

 

InferenceClass.inference()

def inference(self, image):
	with graph.as_default():  # 추가됨
    
		...
	
		results = self.model.detect([image], verbose=1)

		...
블로그 이미지

우송송

,