/** * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ #define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ #include #include #include #include #include "tensor.h" #include "types.h" #include "ascend_string.h" #include "resource_context.h" #include "ge_error_codes.h" namespace ge { class InferenceContext; using InferenceContextPtr = std::shared_ptr; class ShapeAndTypeImpl; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { public: ShapeAndType(); ~ShapeAndType() = default; ShapeAndType(const Shape &shape, DataType data_type); void SetShape(const Shape &shape); void SetType(DataType data_type); Shape GetShape() const; DataType GetDataType() const; private: std::shared_ptr shape_and_type_impl_; }; struct InnerInferenceContext; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { public: ~InferenceContext() = default; InferenceContext(const InferenceContext &context) = delete; InferenceContext(const InferenceContext &&context) = delete; InferenceContext &operator=(const InferenceContext &context) = delete; InferenceContext &operator=(const InferenceContext &&context) = delete; void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); const std::vector> &GetInputHandleShapesAndTypes() const; const std::vector> &GetOutputHandleShapesAndTypes() const; void SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types); void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); ATTRIBUTED_DEPRECATED(void SetMarks(const std::vector &)) void SetMarks(const std::vector &marks); void SetMarks(const std::vector &marks); ATTRIBUTED_DEPRECATED(void GetMarks(std::vector &) const) const std::vector &GetMarks() const; void GetMarks(std::vector &marks) const; static std::unique_ptr Create(void *resource_context_mgr = nullptr); /** * Get corresponding resource_context by key * For resource op infershape, invoked by op infer_func. * @param key * @return corresponding resource context. Check not null before use it. */ ResourceContext *GetResourceContext(const ge::AscendString &key); /** * Set corresponding resource_context by key. For node which will write to resource. * For resource op infershape, invoked by write_op infer_func. * @param key * @param resource_context pointer. * @return status */ graphStatus SetResourceContext(const ge::AscendString &key, ResourceContext *resource_context); /** * Register resource key relied on. For node which will read from resource. * For resource op infershape, invoked by read_op infer_func. * @param key * @return status */ graphStatus RegisterReliedOnResourceKey(const ge::AscendString &key); /** * During infershape of write op, if resource shape changed, use this to tell. * For resource op infershape, invoked by write_op infer_func. * @param key * @return status */ graphStatus AddChangedResourceKey(const ge::AscendString &key); /** * After read_op infershaped, can get resource_keys relied on. * For resource op infershape, invoked by ge infershape framework. * @param keys * @return status */ const std::set& GetReliedOnResourceKeys() const; /** * After infershape of write op, ge can get resource_key which shape changed. * For resource op infershape, invoked by ge infershape framework. * @return keys */ const std::set& GetChangedResourceKeys() const; /** * After handle changed resource shape, should clear changed_keys in context. * For resource op infershape, invoked by ge infershape framework. */ void ClearChangedResourceKeys(); private: explicit InferenceContext(std::unique_ptr &inner_context); std::shared_ptr inner_inference_context_; }; } // namespace ge #endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_