openPMD-api
Numpy.hpp
1 /* Copyright 2018-2021 Axel Huebl
2  *
3  * This file is part of openPMD-api.
4  *
5  * openPMD-api is free software: you can redistribute it and/or modify
6  * it under the terms of either the GNU General Public License or
7  * the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * openPMD-api is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License and the GNU Lesser General Public License
15  * for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * and the GNU Lesser General Public License along with openPMD-api.
19  * If not, see <http://www.gnu.org/licenses/>.
20  */
21 #pragma once
22 
23 #include "openPMD/Datatype.hpp"
24 
25 #include <pybind11/numpy.h>
26 #include <pybind11/pybind11.h>
27 #include <pybind11/stl.h>
28 
29 #include <exception>
30 #include <string>
31 #include <type_traits>
32 
33 namespace openPMD
34 {
35 inline Datatype dtype_from_numpy(pybind11::dtype const dt)
36 {
37  // ref: https://docs.scipy.org/doc/numpy/user/basics.types.html
38  // ref: https://github.com/numpy/numpy/issues/10678#issuecomment-369363551
39  if (dt.char_() == pybind11::dtype("b").char_())
40  if constexpr (std::is_signed_v<char>)
41  {
42  return Datatype::CHAR;
43  }
44  else
45  {
46  return Datatype::SCHAR;
47  }
48  else if (dt.char_() == pybind11::dtype("B").char_())
49  if constexpr (std::is_unsigned_v<char>)
50  {
51  return Datatype::CHAR;
52  }
53  else
54  {
55  return Datatype::UCHAR;
56  }
57  else if (dt.char_() == pybind11::dtype("short").char_())
58  return Datatype::SHORT;
59  else if (dt.char_() == pybind11::dtype("intc").char_())
60  return Datatype::INT;
61  else if (dt.char_() == pybind11::dtype("int_").char_())
62  return Datatype::LONG;
63  else if (dt.char_() == pybind11::dtype("longlong").char_())
64  return Datatype::LONGLONG;
65  else if (dt.char_() == pybind11::dtype("ushort").char_())
66  return Datatype::USHORT;
67  else if (dt.char_() == pybind11::dtype("uintc").char_())
68  return Datatype::UINT;
69  else if (dt.char_() == pybind11::dtype("uint").char_())
70  return Datatype::ULONG;
71  else if (dt.char_() == pybind11::dtype("ulonglong").char_())
72  return Datatype::ULONGLONG;
73  else if (dt.char_() == pybind11::dtype("clongdouble").char_())
74  return Datatype::CLONG_DOUBLE;
75  else if (dt.char_() == pybind11::dtype("cdouble").char_())
76  return Datatype::CDOUBLE;
77  else if (dt.char_() == pybind11::dtype("csingle").char_())
78  return Datatype::CFLOAT;
79  else if (dt.char_() == pybind11::dtype("longdouble").char_())
80  return Datatype::LONG_DOUBLE;
81  else if (dt.char_() == pybind11::dtype("double").char_())
82  return Datatype::DOUBLE;
83  else if (dt.char_() == pybind11::dtype("single").char_())
84  return Datatype::FLOAT;
85  else if (dt.char_() == pybind11::dtype("bool").char_())
86  return Datatype::BOOL;
87  else
88  {
89  pybind11::print(dt);
90  throw std::runtime_error(
91  std::string("Datatype '") + dt.char_() +
92  std::string("' not known in 'dtype_from_numpy'!")); // _s.format(dt)
93  }
94 }
95 
98 inline Datatype dtype_from_bufferformat(std::string const &fmt)
99 {
100  using DT = Datatype;
101 
102  // refs:
103  // https://docs.scipy.org/doc/numpy-1.15.0/reference/arrays.interface.html
104  // https://docs.python.org/3/library/struct.html#format-characters
105  // std::cout << " scalar type '" << fmt << "'" << std::endl;
106  // typestring: encoding + type + number of bytes
107  if (fmt.find("?") != std::string::npos)
108  return DT::BOOL;
109  else if (fmt.find("b") != std::string::npos)
110  return DT::CHAR;
111  else if (fmt.find("h") != std::string::npos)
112  return DT::SHORT;
113  else if (fmt.find("i") != std::string::npos)
114  return DT::INT;
115  else if (fmt.find("l") != std::string::npos)
116  return DT::LONG;
117  else if (fmt.find("q") != std::string::npos)
118  return DT::LONGLONG;
119  else if (fmt.find("B") != std::string::npos)
120  return DT::UCHAR;
121  else if (fmt.find("H") != std::string::npos)
122  return DT::USHORT;
123  else if (fmt.find("I") != std::string::npos)
124  return DT::UINT;
125  else if (fmt.find("L") != std::string::npos)
126  return DT::ULONG;
127  else if (fmt.find("Q") != std::string::npos)
128  return DT::ULONGLONG;
129  else if (fmt.find("Zf") != std::string::npos)
130  return DT::CFLOAT;
131  else if (fmt.find("Zd") != std::string::npos)
132  return DT::CDOUBLE;
133  else if (fmt.find("Zg") != std::string::npos)
134  return DT::CLONG_DOUBLE;
135  else if (fmt.find("f") != std::string::npos)
136  return DT::FLOAT;
137  else if (fmt.find("d") != std::string::npos)
138  return DT::DOUBLE;
139  else if (fmt.find("g") != std::string::npos)
140  return DT::LONG_DOUBLE;
141  else
142  throw std::runtime_error(
143  "dtype_from_bufferformat: Unknown "
144  "Python type '" +
145  fmt + "'");
146 }
147 
148 inline pybind11::dtype dtype_to_numpy(Datatype const dt)
149 {
150  using DT = Datatype;
151  switch (dt)
152  {
153  case DT::CHAR:
154  case DT::VEC_CHAR:
155  case DT::SCHAR:
156  case DT::VEC_SCHAR:
157  case DT::STRING:
158  case DT::VEC_STRING:
159  return pybind11::dtype("b");
160  break;
161  case DT::UCHAR:
162  case DT::VEC_UCHAR:
163  return pybind11::dtype("B");
164  break;
165  // case DT::SCHAR:
166  // case DT::VEC_SCHAR:
167  // pybind11::dtype("b");
168  // break;
169  case DT::SHORT:
170  case DT::VEC_SHORT:
171  return pybind11::dtype("short");
172  break;
173  case DT::INT:
174  case DT::VEC_INT:
175  return pybind11::dtype("intc");
176  break;
177  case DT::LONG:
178  case DT::VEC_LONG:
179  return pybind11::dtype("int_");
180  break;
181  case DT::LONGLONG:
182  case DT::VEC_LONGLONG:
183  return pybind11::dtype("longlong");
184  break;
185  case DT::USHORT:
186  case DT::VEC_USHORT:
187  return pybind11::dtype("ushort");
188  break;
189  case DT::UINT:
190  case DT::VEC_UINT:
191  return pybind11::dtype("uintc");
192  break;
193  case DT::ULONG:
194  case DT::VEC_ULONG:
195  return pybind11::dtype("uint");
196  break;
197  case DT::ULONGLONG:
198  case DT::VEC_ULONGLONG:
199  return pybind11::dtype("ulonglong");
200  break;
201  case DT::FLOAT:
202  case DT::VEC_FLOAT:
203  return pybind11::dtype("single");
204  break;
205  case DT::DOUBLE:
206  case DT::VEC_DOUBLE:
207  case DT::ARR_DBL_7:
208  return pybind11::dtype("double");
209  break;
210  case DT::LONG_DOUBLE:
211  case DT::VEC_LONG_DOUBLE:
212  return pybind11::dtype("longdouble");
213  break;
214  case DT::CFLOAT:
215  case DT::VEC_CFLOAT:
216  return pybind11::dtype("csingle");
217  break;
218  case DT::CDOUBLE:
219  case DT::VEC_CDOUBLE:
220  return pybind11::dtype("cdouble");
221  break;
222  case DT::CLONG_DOUBLE:
223  case DT::VEC_CLONG_DOUBLE:
224  return pybind11::dtype("clongdouble");
225  break;
226  case DT::BOOL:
227  return pybind11::dtype("bool"); // also "?"
228  break;
229  case DT::UNDEFINED:
230  default:
231  throw std::runtime_error(
232  "dtype_to_numpy: Invalid Datatype '{...}'!"); // _s.format(dt)
233  break;
234  }
235 }
236 } // namespace openPMD
Datatype dtype_from_bufferformat(std::string const &fmt)
Return openPMD::Datatype from py::buffer_info::format.
Definition: Numpy.hpp:98
Datatype
Concrete datatype of an object available at runtime.
Definition: Datatype.hpp:45
Public definitions of openPMD-api.