boxes_utils.py 657 B

12345678910111213141516171819202122232425262728
  1. import numpy as np
  2. def assert_and_normalize_shape(x, length):
  3. """
  4. Args:
  5. x: ndarray
  6. length: int
  7. """
  8. if x.ndim == 0:
  9. return x
  10. elif x.ndim == 1:
  11. if len(x) == 1:
  12. return x
  13. elif len(x) == length:
  14. return x
  15. else:
  16. raise ValueError('Incompatible shape!')
  17. elif x.ndim == 2:
  18. if x.shape == (1, 1):
  19. return np.squeeze(x, axis=-1)
  20. elif x.shape == (length, 1):
  21. return np.squeeze(x, axis=-1)
  22. else:
  23. raise ValueError('Incompatible shape!')
  24. else:
  25. raise ValueError('Incompatible ndim!')