[Tensorflow] Python Flask 에서 tensorflow 모델사용 시 에러가 나는 경우
Deep Learning/Tensorflow 2020. 7. 20. 11:51분명히 모델 초기화 코드를 제대로 패키징 하고 다른 파이썬 파일에서 import 하면 제대로 작동하는 반면, Flask API 에서 적용하려 하면 에러가 나는 경우가 있다
This could mean that the variable was uninitialized.
Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 100, 6), dtype=float32) is not an element of this graph.
내가 직면했던 에러는 이 두 가지 에러였다
말 그대로 텐서플로의 세션 또는 그래프가 제대로 초기화되지 않아 발생하는 에러이다
하지만 왜 Flask API 에 가져오려하면 이런 에러가 발생하는 것인가?
Flask 는 여러 스레드를 사용한다. 이런 에러가 발생하는 이유는 tensorflow 모델이 로드되지 않은 채로 스레드에서 사용되기 때문이다
내가 찾은 해결 방법은 tensorflow 가 graph 와 session 을 global 변수로 사용하도록 하는 방법이다
flask_API.py
import InferenceClass
class InferenceServer:
def __init__(self):
self.tensorflow_obj = InferenceClass(self.model_path)
self.tensorflow_obj.model_init()
self.app = self.create_flask_app()
def create_flask_app(self):
app = Flask(__name__)
@app.route('/image', methods=['POST'])
def mask_rcnn_inference():
...
inference_image, inference_time = self.tensorflow_obj.inference(img)
...
return app
InferenceClass.model_init()
def model_init(self):
self.model = modellib.MaskRCNN(mode="inference", model_dir=self.MODEL_DIR,
config=self.config)
self.model.load_weights(self.weights_path, by_name=True)
InferenceClass.inference()
def inference(self, image):
...
results = self.model.detect([image], verbose=1)
...
위 와 같은 형태의 코드에서 Flask 코드에선 global 변수로 session 과 model 을 넣어주면 되고, Flask 에서 import 하는 코드엔 global 로 선언된 graph 를 기반으로 inference 코드가 돌게 하면 된다
flask_API.py
global model # 추가됨
global session # 추가됨
class InferenceServer:
def __init__(self):
global model # 추가됨
global session # 추가됨
session = tf.Session() # 추가됨
keras.backend.set_session(session) # 추가됨
model = InferenceClass(self.model_path) # 클래스 변수가 아닌 global 변수 사용
model.model_init() # 클래스 변수가 아닌 global 변수 사용
self.app = self.create_flask_app()
def create_flask_app(self):
app = Flask(__name__)
@app.route('/image', methods=['POST'])
def mask_rcnn_inference():
with session.as_default(): # 추가됨
...
inference_image, inference_time = model.inference(img) # 클래스 변수가 아닌 global 변수 사용
...
return app
InferenceClass.model_init()
def model_init(self):
global graph # 추가됨
self.model = modellib.MaskRCNN(mode="inference", model_dir=self.MODEL_DIR,
config=self.config)
print(self.weights_path)
self.model.load_weights(self.weights_path, by_name=True)
graph = tf.get_default_graph() # 추가됨
InferenceClass.inference()
def inference(self, image):
with graph.as_default(): # 추가됨
...
results = self.model.detect([image], verbose=1)
...
'Deep Learning > Tensorflow' 카테고리의 다른 글
[tensorflow] Checkpoint load 과정에서 No device assignments 오류가 날 때 해결방법 (0) | 2020.04.16 |
---|---|
[tensorflow] serialized_options error 해결 방법 (1) | 2019.10.29 |
[tensorflow] ModuleNotFoundError: No module named 'object_detection' 해결 방법 (0) | 2019.10.28 |