openPMD-api
Numpy.hpp
1 /* Copyright 2018-2021 Axel Huebl
2  *
3  * This file is part of openPMD-api.
4  *
5  * openPMD-api is free software: you can redistribute it and/or modify
6  * it under the terms of either the GNU General Public License or
7  * the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * openPMD-api is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License and the GNU Lesser General Public License
15  * for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * and the GNU Lesser General Public License along with openPMD-api.
19  * If not, see <http://www.gnu.org/licenses/>.
20  */
21 #pragma once
22 
23 #include "openPMD/Datatype.hpp"
24 #include "openPMD/binding/python/Variant.hpp"
25 
26 #include <pybind11/numpy.h>
27 #include <pybind11/pybind11.h>
28 #include <pybind11/stl.h>
29 
30 #include <exception>
31 #include <string>
32 
33 namespace openPMD
34 {
35 inline Datatype dtype_from_numpy(pybind11::dtype const dt)
36 {
37  // ref: https://docs.scipy.org/doc/numpy/user/basics.types.html
38  // ref: https://github.com/numpy/numpy/issues/10678#issuecomment-369363551
39  if (dt.is(pybind11::dtype("b")))
40  return Datatype::CHAR;
41  else if (dt.is(pybind11::dtype("B")))
42  return Datatype::UCHAR;
43  else if (dt.is(pybind11::dtype("short")))
44  return Datatype::SHORT;
45  else if (dt.is(pybind11::dtype("intc")))
46  return Datatype::INT;
47  else if (dt.is(pybind11::dtype("int_")))
48  return Datatype::LONG;
49  else if (dt.is(pybind11::dtype("longlong")))
50  return Datatype::LONGLONG;
51  else if (dt.is(pybind11::dtype("ushort")))
52  return Datatype::USHORT;
53  else if (dt.is(pybind11::dtype("uintc")))
54  return Datatype::UINT;
55  else if (dt.is(pybind11::dtype("uint")))
56  return Datatype::ULONG;
57  else if (dt.is(pybind11::dtype("ulonglong")))
58  return Datatype::ULONGLONG;
59  else if (dt.is(pybind11::dtype("clongdouble")))
60  return Datatype::CLONG_DOUBLE;
61  else if (dt.is(pybind11::dtype("cdouble")))
62  return Datatype::CDOUBLE;
63  else if (dt.is(pybind11::dtype("csingle")))
64  return Datatype::CFLOAT;
65  else if (dt.is(pybind11::dtype("longdouble")))
66  return Datatype::LONG_DOUBLE;
67  else if (dt.is(pybind11::dtype("double")))
68  return Datatype::DOUBLE;
69  else if (dt.is(pybind11::dtype("single")))
70  return Datatype::FLOAT;
71  else if (dt.is(pybind11::dtype("bool")))
72  return Datatype::BOOL;
73  else
74  {
75  pybind11::print(dt);
76  throw std::runtime_error(
77  "Datatype '...' not known in 'dtype_from_numpy'!"); // _s.format(dt)
78  }
79 }
80 
83 inline Datatype dtype_from_bufferformat(std::string const &fmt)
84 {
85  using DT = Datatype;
86 
87  // refs:
88  // https://docs.scipy.org/doc/numpy-1.15.0/reference/arrays.interface.html
89  // https://docs.python.org/3/library/struct.html#format-characters
90  // std::cout << " scalar type '" << fmt << "'" << std::endl;
91  // typestring: encoding + type + number of bytes
92  if (fmt.find("?") != std::string::npos)
93  return DT::BOOL;
94  else if (fmt.find("b") != std::string::npos)
95  return DT::CHAR;
96  else if (fmt.find("h") != std::string::npos)
97  return DT::SHORT;
98  else if (fmt.find("i") != std::string::npos)
99  return DT::INT;
100  else if (fmt.find("l") != std::string::npos)
101  return DT::LONG;
102  else if (fmt.find("q") != std::string::npos)
103  return DT::LONGLONG;
104  else if (fmt.find("B") != std::string::npos)
105  return DT::UCHAR;
106  else if (fmt.find("H") != std::string::npos)
107  return DT::USHORT;
108  else if (fmt.find("I") != std::string::npos)
109  return DT::UINT;
110  else if (fmt.find("L") != std::string::npos)
111  return DT::ULONG;
112  else if (fmt.find("Q") != std::string::npos)
113  return DT::ULONGLONG;
114  else if (fmt.find("Zf") != std::string::npos)
115  return DT::CFLOAT;
116  else if (fmt.find("Zd") != std::string::npos)
117  return DT::CDOUBLE;
118  else if (fmt.find("Zg") != std::string::npos)
119  return DT::CLONG_DOUBLE;
120  else if (fmt.find("f") != std::string::npos)
121  return DT::FLOAT;
122  else if (fmt.find("d") != std::string::npos)
123  return DT::DOUBLE;
124  else if (fmt.find("g") != std::string::npos)
125  return DT::LONG_DOUBLE;
126  else
127  throw std::runtime_error(
128  "dtype_from_bufferformat: Unknown "
129  "Python type '" +
130  fmt + "'");
131 }
132 
133 inline pybind11::dtype dtype_to_numpy(Datatype const dt)
134 {
135  using DT = Datatype;
136  switch (dt)
137  {
138  case DT::CHAR:
139  case DT::VEC_CHAR:
140  case DT::STRING:
141  case DT::VEC_STRING:
142  return pybind11::dtype("b");
143  break;
144  case DT::UCHAR:
145  case DT::VEC_UCHAR:
146  return pybind11::dtype("B");
147  break;
148  // case DT::SCHAR:
149  // case DT::VEC_SCHAR:
150  // pybind11::dtype("b");
151  // break;
152  case DT::SHORT:
153  case DT::VEC_SHORT:
154  return pybind11::dtype("short");
155  break;
156  case DT::INT:
157  case DT::VEC_INT:
158  return pybind11::dtype("intc");
159  break;
160  case DT::LONG:
161  case DT::VEC_LONG:
162  return pybind11::dtype("int_");
163  break;
164  case DT::LONGLONG:
165  case DT::VEC_LONGLONG:
166  return pybind11::dtype("longlong");
167  break;
168  case DT::USHORT:
169  case DT::VEC_USHORT:
170  return pybind11::dtype("ushort");
171  break;
172  case DT::UINT:
173  case DT::VEC_UINT:
174  return pybind11::dtype("uintc");
175  break;
176  case DT::ULONG:
177  case DT::VEC_ULONG:
178  return pybind11::dtype("uint");
179  break;
180  case DT::ULONGLONG:
181  case DT::VEC_ULONGLONG:
182  return pybind11::dtype("ulonglong");
183  break;
184  case DT::FLOAT:
185  case DT::VEC_FLOAT:
186  return pybind11::dtype("single");
187  break;
188  case DT::DOUBLE:
189  case DT::VEC_DOUBLE:
190  case DT::ARR_DBL_7:
191  return pybind11::dtype("double");
192  break;
193  case DT::LONG_DOUBLE:
194  case DT::VEC_LONG_DOUBLE:
195  return pybind11::dtype("longdouble");
196  break;
197  case DT::CFLOAT:
198  case DT::VEC_CFLOAT:
199  return pybind11::dtype("csingle");
200  break;
201  case DT::CDOUBLE:
202  case DT::VEC_CDOUBLE:
203  return pybind11::dtype("cdouble");
204  break;
205  case DT::CLONG_DOUBLE:
206  case DT::VEC_CLONG_DOUBLE:
207  return pybind11::dtype("clongdouble");
208  break;
209  case DT::BOOL:
210  return pybind11::dtype("bool"); // also "?"
211  break;
212  case DT::DATATYPE:
213  case DT::UNDEFINED:
214  default:
215  throw std::runtime_error(
216  "dtype_to_numpy: Invalid Datatype '{...}'!"); // _s.format(dt)
217  break;
218  }
219 }
220 } // namespace openPMD
Datatype dtype_from_bufferformat(std::string const &fmt)
Return openPMD::Datatype from py::buffer_info::format.
Definition: Numpy.hpp:83
Datatype
Concrete datatype of an object available at runtime.
Definition: Datatype.hpp:45
Public definitions of openPMD-api.
Definition: Date.cpp:28