/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// NOTE: the include paths below are inferred from the symbols used in this
// file; the exact locations may differ in the original source tree.
#include "mobile_ssd_client.h"

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <glog/logging.h>
#include "absl/memory/memory.h"
#include "absl/types/span.h"

namespace lstm_object_detection {
namespace tflite {
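
// Stores the client options. The base implementation does no model loading;
// any actual setup is presumably done by subclasses.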
bool MobileSSDClient::InitializeClient(const protos::ClientOptions& options) {
  options_ = options;
  return true;
}
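
// Detects objects in a raw image buffer. The image must match the model's
// expected input size and depth; any row padding is stripped before the
// pixels are handed to the single-image Detect() overload below.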
bool MobileSSDClient::Detect(const uint8_t* pixels, int width, int height,
                             int bytes_per_pixel, int bytes_per_row,
                             protos::DetectionResults* detections) {
  SetInputDims(width, height);
  // Grayscale input images are only compatible with grayscale models, and
  // color input images are only compatible with color models.
  CHECK((bytes_per_pixel == 1 && input_depth_ == 1) ||
        (bytes_per_pixel >= 3 && input_depth_ >= 3));
  if (HasPadding(width, height, bytes_per_pixel, bytes_per_row)) {
    std::vector<uint8_t> unpadded_pixels =
        RemovePadding(pixels, width, height, bytes_per_pixel, bytes_per_row);
    return Detect(unpadded_pixels.data(), detections);
  } else {
    return Detect(pixels, detections);
  }
}
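
// Single-image convenience overload: wraps the output pointer in a
// one-element span and forwards to BatchDetect() with batch size 1.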
bool MobileSSDClient::Detect(const uint8_t* pixels,
                             protos::DetectionResults* detections) {
  return BatchDetect(pixels, 1, absl::MakeSpan(&detections, 1));
}
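
// Runs inference on `batch_size` images packed contiguously in `pixels` and
// writes one DetectionResults proto per image. Fails if the span size does
// not match the batch size, if the batch size cannot be set, if inference
// fails, or if the model requires anchor-based post-processing, which the
// base class does not implement.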
bool MobileSSDClient::BatchDetect(
    const uint8_t* pixels, int batch_size,
    absl::Span<protos::DetectionResults*> detections) {
  if (detections.size() != batch_size) {
    LOG(ERROR) << "Batch size does not match output cardinality.";
    return false;
  }
  if (batch_size != batch_size_) {
    if (!SetBatchSize(batch_size)) {
      LOG(ERROR) << "Couldn't set batch size.";
      return false;
    }
  }
  if (!Inference(pixels)) {
    LOG(ERROR) << "Inference failed.";
    return false;
  }
  for (int batch = 0; batch < batch_size; ++batch) {
    if (RequiresPostProcessing()) {
      LOG(ERROR) << "Post-processing is not supported.";
      return false;
    } else {
      if (!NoPostProcessNoAnchors(detections[batch])) {
        LOG(ERROR) << "NoPostProcessNoAnchors failed.";
        return false;
      }
    }
  }
  return true;
}
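
// Default batch-size setter: records the requested size and reallocates the
// output buffers, but only accepts a batch size of 1. Subclasses that support
// batched inference override this.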
bool MobileSSDClient::SetBatchSize(int batch_size) {
  batch_size_ = batch_size;
  AllocateBuffers();
  if (batch_size != 1) {
    LOG(ERROR) << "Only single-batch inference is supported by default. "
                  "Subclasses that support batched inference (e.g. "
                  "MobileSSDTfLiteClient) should override this method and "
                  "accept the batch sizes they can handle.";
    return false;
  }
  return true;
}
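
// Base-class stub for models whose outputs need neither post-processing nor
// anchors; subclasses are expected to provide the real implementation.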
bool MobileSSDClient::NoPostProcessNoAnchors(
    protos::DetectionResults* detections) {
  LOG(ERROR) << "NoPostProcessNoAnchors is not implemented in the base class.";
  return false;
}
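
// A model requires post-processing when anchors have been loaded; the base
// client only handles models without anchors.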
bool MobileSSDClient::RequiresPostProcessing() const {
  return anchors_.y_size() > 0;
}
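
// Despite its name, this only verifies that the caller's image dimensions
// match the model's fixed input size; it does not resize anything.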
void MobileSSDClient::SetInputDims(int width, int height) {
  CHECK_EQ(width, input_width_);
  CHECK_EQ(height, input_height_);
}
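
// Label-map accessors. Each returns a sentinel value (an empty string or -1)
// when the class index is out of range or the field is missing.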
int MobileSSDClient::GetNumberOfLabels() const { return labelmap_.item_size(); }

std::string MobileSSDClient::GetLabelDisplayName(const int class_index) const {
  if (class_index < 0 || class_index >= GetNumberOfLabels()) {
    return "";
  }
  return labelmap_.item(class_index).display_name();
}

std::string MobileSSDClient::GetLabelName(const int class_index) const {
  if (class_index < 0 || class_index >= GetNumberOfLabels()) {
    return "";
  }
  return labelmap_.item(class_index).name();
}

int MobileSSDClient::GetLabelId(const int class_index) const {
  if (class_index < 0 || class_index >= GetNumberOfLabels() ||
      !labelmap_.item(class_index).has_id()) {
    return -1;
  }
  return labelmap_.item(class_index).id();
}
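
// Fills in human-readable names for each detection by looking up the
// detection's class indices in the label map.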
void MobileSSDClient::SetLabelDisplayNameInResults(
    protos::DetectionResults* detections) {
  for (auto& det : *detections->mutable_detection()) {
    for (const auto& class_index : det.class_index()) {
      det.add_display_name(GetLabelDisplayName(class_index));
    }
  }
}

void MobileSSDClient::SetLabelNameInResults(
    protos::DetectionResults* detections) {
  for (auto& det : *detections->mutable_detection()) {
    for (const auto& class_index : det.class_index()) {
      det.add_class_name(GetLabelName(class_index));
    }
  }
}
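
// Derives per-model constants from the output tensor sizes: each box is
// encoded with 4 coordinates plus 2 values per keypoint, and in
// class-specific (non-agnostic) mode one score column is presumably the
// background class and is excluded from num_classes_.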
void MobileSSDClient::InitParams(const bool agnostic_mode, const bool quantize,
                                 const int num_keypoints) {
  num_keypoints_ = num_keypoints;
  code_size_ = 4 + 2 * num_keypoints;
  num_boxes_ = output_locations_size_ / code_size_;
  if (agnostic_mode) {
    num_classes_ = output_scores_size_ / num_boxes_;
  } else {
    num_classes_ = (output_scores_size_ / num_boxes_) - 1;
  }
  quantize_ = quantize;
  AllocateBuffers();
}
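
// Sizes the output buffers for the current batch size. For quantized models a
// uint8 buffer is created per (batch, output layer) pair, with box locations
// at even indices and class scores at odd indices of the pointer array; float
// models reuse output_locations_ and output_scores_ directly.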
void MobileSSDClient::AllocateBuffers() {
  // Allocate the output vectors.
  output_locations_.resize(output_locations_size_ * batch_size_);
  output_scores_.resize(output_scores_size_ * batch_size_);
  if (quantize_) {
    quantized_output_pointers_ =
        absl::make_unique<std::vector<std::unique_ptr<std::vector<uint8_t>>>>(
            batch_size_ * num_output_layers_ * 2);
    for (int batch = 0; batch < batch_size_; ++batch) {
      for (int i = 0; i < num_output_layers_; ++i) {
        quantized_output_pointers_->at(2 * (i + batch * num_output_layers_)) =
            absl::make_unique<std::vector<uint8_t>>(output_locations_sizes_[i]);
        quantized_output_pointers_->at(2 * (i + batch * num_output_layers_) +
                                       1) =
            absl::make_unique<std::vector<uint8_t>>(output_scores_sizes_[i]);
      }
    }
    quantized_output_pointers_array_.reset(
        new uint8_t*[batch_size_ * num_output_layers_ * 2]);
    for (int i = 0; i < batch_size_ * num_output_layers_ * 2; ++i) {
      quantized_output_pointers_array_[i] =
          quantized_output_pointers_->at(i)->data();
    }
    gemm_context_.set_max_num_threads(1);
  } else {
    output_pointers_[0] = output_locations_.data();
    output_pointers_[1] = output_scores_.data();
  }
}

}  // namespace tflite
}  // namespace lstm_object_detection