Getting rid of the TensorFlow retracing warning

_tera_ 2023. 8. 9. 17:57

Warning message: it says compute is being wasted on retracing, so use tf.function with the reduce_retracing=True option.

WARNING:tensorflow:5 out of the last 5 calls to <function loss_fn at 0x7f0070ef0e18> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
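The "passing python objects instead of tensors" part of the warning can be reproduced in a few lines. This is only a minimal sketch: add_one is a made-up function (not the loss_fn from the log), and trace_log is a plain Python list that gets appended to only while tf.function is tracing, so its length shows how many traces happened.

```python
import tensorflow as tf

trace_log = []  # Python side effects inside a tf.function run only during tracing


@tf.function
def add_one(x):
    trace_log.append("traced")  # executed once per trace, not once per call
    return x + 1


# Python ints: every new value is baked in as a constant, so every call retraces.
for i in range(3):
    add_one(i)
python_arg_traces = len(trace_log)

# Tensors with the same dtype and shape: one trace is reused for all values.
trace_log.clear()
for i in range(3):
    add_one(tf.constant(i))
tensor_arg_traces = len(trace_log)

print(python_arg_traces, tensor_arg_traces)
```

So if a function keeps getting called with changing Python numbers, wrapping them in tf.constant (or tf.convert_to_tensor) before the call is one way to cut down on retracing.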

 

If you declare tf.function(reduce_retracing=True) and put @tf.function on a def as a decorator, that def gets compiled into a TensorFlow graph, so it behaves like part of a tf model.
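Roughly what that looks like, as a sketch rather than the actual code from this project: scaled_sum is a hypothetical function, and I'm assuming a TensorFlow version recent enough to have the reduce_retracing argument (the warning text itself mentions it). The list trace_log only grows while tracing happens.

```python
import tensorflow as tf

trace_log = []  # one entry per trace


@tf.function(reduce_retracing=True)
def scaled_sum(x):
    trace_log.append("traced")  # runs only while tracing
    return tf.reduce_sum(x) * 2.0


# Same function, inputs of different lengths.
lengths = (2, 3, 4, 5)
results = [float(scaled_sum(tf.ones([n]))) for n in lengths]

print(results)
print("calls:", len(lengths), "traces:", len(trace_log))
```

With reduce_retracing=True, TensorFlow is allowed to relax the argument shape after it sees a second shape, so the number of traces should end up lower than the number of calls.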

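As I understand it, the option tells TensorFlow not to retrace when the def runs repeatedly, for example through loops or if statements. But I'm implementing the inference side, not training, so what I wrote is a def that loads a model and predicts with predict().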

model.predict() sits at a higher level than tf.function, so to use tf.function, loading the model and calling model.predict() have to happen outside that def.

I also read that "tf-function can only compute the input for a fixed size." Since my models all have different input shapes, I don't think I can use tf.function here.
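For reference, the fixed-size restriction can be loosened per function: tf.function accepts an input_signature, and a None dimension in a tf.TensorSpec means "any size". A minimal sketch under my own assumptions (row_means is a hypothetical function, and this does not settle whether it is practical for a whole set of models with different inputs):

```python
import tensorflow as tf

trace_log = []  # one entry per trace


# Both dimensions are left unknown, so one trace covers any 2-D float32 input.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def row_means(batch):
    trace_log.append("traced")  # runs only while tracing
    return tf.reduce_mean(batch, axis=1)


shapes = [(2, 3), (5, 7), (1, 4)]
outputs = [row_means(tf.ones(shape)) for shape in shapes]

print([tuple(o.shape) for o in outputs], "traces:", len(trace_log))
```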

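It was getting complicated, so I looked around some more,

and apparently the warning goes away if you use model() instead of model.predict(). Wow, it really did go away. Probably a difference in how the two methods work internally.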

And inference got faster too.
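The swap itself is small. A sketch with a tiny stand-in model (in my case the model comes from a saved file; the layer sizes and input here are made up just to have something runnable):

```python
import numpy as np
import tensorflow as tf

# Stand-in for a loaded model; layer sizes are arbitrary.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])

x = np.random.rand(3, 4).astype("float32")

# Before: predict() returns a NumPy array.
y_predict = model.predict(x, verbose=0)

# After: calling the model directly returns a tf.Tensor.
# training=False keeps layers such as Dropout/BatchNorm in inference mode.
y_call = model(x, training=False)

print(type(y_predict).__name__, y_predict.shape)
print(type(y_call).__name__, y_call.shape)
```

If the rest of the code expects NumPy arrays, y_call.numpy() converts the tensor back.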

I asked GPT about the difference:

  1. Usage Context:
    • model.predict(): This is a method specifically designed for obtaining predictions from the model. It's optimized for batch processing and can be used with large datasets. This method performs inference in a stateless manner, ensuring that internal states (like those in RNNs) do not persist across multiple calls.
    • model(): This directly calls the model's __call__ method, treating the model as a callable function. This is more of a low-level approach, often used during training or when you want to build custom training loops or architectures.
  2. Input Handling:
    • model.predict(): Handles batches of data well and can process the data more efficiently when predicting on large datasets. It also returns predictions in a numpy array format.
    • model(): Requires you to manage the batch dimensions and other input shapes explicitly. The output is a TensorFlow tensor.
  3. Stateful Models:
    • model.predict(): Resets the states of stateful models between calls. This means if you're working with stateful RNNs, the state does not persist across predict calls.
    • model(): Does not automatically reset the states. If you're working with stateful RNNs, the state will persist, which might be desired in certain scenarios.
  4. Performance and Overhead:
    • model.predict(): Contains some overhead because it's designed for safety and general-purpose predictions. As a result, for a very small number of predictions, it might be slower than directly calling the model.
    • model(): Might be faster for a one-off prediction since it doesn't have the overhead of predict(). However, it's not optimized for large batches or datasets.
  5. Custom Models and Subclassing:
    • If you're building custom models by subclassing tf.keras.Model, you'll often override the call method, which is what gets invoked when you use model().

In general, for most typical use-cases where you just want to get predictions from a trained model, model.predict() is recommended due to its ease of use and built-in optimizations. However, if you're working on building custom architectures, training loops, or need more fine-grained control, directly calling the model using model() can be useful.

 

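What I'm building is a system that runs real-time inference with a large number of models, so I had added threads. My guess is that the warning means predict() keeps getting retraced inside those threads and that this is causing what looks like overhead.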

The dataset and the models are lightweight, so for now running them with model() should be enough.
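A rough sketch of that setup, with everything simplified: the model names, the feature counts, and the build_model / infer helpers are all hypothetical, the models are built in place instead of being loaded from disk, and the threads are left out to keep it short. The point is only that models with different input shapes can each be called directly.

```python
import numpy as np
import tensorflow as tf


def build_model(n_features):
    """Hypothetical stand-in for loading one small trained model."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(n_features,)),
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


# Several models, each with its own input width (made-up names and sizes).
feature_counts = {"model_a": 4, "model_b": 10, "model_c": 32}
models = {name: build_model(n) for name, n in feature_counts.items()}


def infer(name, features):
    """Run one sample through one model by calling the model directly."""
    batch = np.asarray(features, dtype=np.float32)[np.newaxis, :]  # add batch dim
    return models[name](batch, training=False).numpy()[0]


outputs = {name: infer(name, np.random.rand(n)) for name, n in feature_counts.items()}
print({name: out.shape for name, out in outputs.items()})
```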