Coverage for pydaconf/main.py: 99.03%

146 statements  

« prev     ^ index     » next       coverage.py v7.6.11, created at 2025-02-16 17:46 +0000

import copy
import logging
import re
import threading
from collections.abc import Callable
from functools import partial
from typing import Generic, TypeAlias, TypeVar, get_args

from pydaconf.plugins.base import PluginBase
from pydaconf.utils.exceptions import ProviderException
from pydaconf.utils.file import load_config_file, load_from_url
from pydaconf.utils.interpolation import (
    has_interpolation_template,
    interpolate_template,
)
from pydaconf.utils.plugins import load_builtin_plugins, load_dynamic_plugins
from pydantic import BaseModel, ValidationError


# Concrete pydantic model type that a PydaConf instance validates its data into.
T = TypeVar("T", bound=BaseModel)

# Node types walked when resolving a raw (loaded but not yet validated) config tree.
ConfigValueType: TypeAlias = list | dict | str | int | bool | None


class PydaConf(Generic[T]):
    """Configuration provider that validates raw config data into a pydantic model ``T``.

    The class must be parametrised when instantiated, e.g. ``PydaConf[MyModel]()``,
    because the model type is recovered from the instance's ``__orig_class__``.
    Load data with ``from_file``, ``from_url`` or ``from_dict``, then read the
    ``config`` property. String values shaped like ``PREFIX:///value`` are resolved
    through the plugin registered for ``PREFIX``; interpolation templates are then
    expanded against the resolved tree.
    """

    # Plugin reference such as 'PREFIX:///value'. Only the start of the string is
    # matched, so any text after the first whitespace following ':///' is ignored.
    # NOTE(review): confirm that dropping such trailing text is intended.
    _PLUGIN_REFERENCE = re.compile(r'(?P<PLUGIN_PREFIX>[^:]+):///(?P<VALUE>[^\s]+)')

    def __init__(self) -> None:
        # Raw config tree as loaded by one of the ``from_*`` methods (unresolved).
        self._raw_config: dict | None = None
        # Validated model, built lazily on first access to ``config``.
        self._config: T | None = None
        # Plugin instances keyed by their prefix (the part before ':///').
        self._plugins: dict[str, PluginBase] = {}
        # Key-path regex pattern -> callbacks invoked with (key, value) after an update.
        self._update_subscribers: dict[str, list[Callable[[str, str], None]]] = {}
        # Serialises updates, which plugins may push from other threads.
        self._update_lock = threading.Lock()
        self.logger = logging.getLogger(__name__)

    def from_file(self, file_path: str) -> None:
        """Load the raw configuration from ``file_path`` (discarding any built config)."""
        self._load_plugins()
        self.logger.debug(f"Load config from file_path '{file_path}'")
        self._set_raw_config(load_config_file(file_path))

    def from_url(self, url: str) -> None:
        """Load the raw configuration from ``url`` (discarding any built config)."""
        self._load_plugins()
        self.logger.debug(f"Load config from url '{url}'")
        self._set_raw_config(load_from_url(url))

    def from_dict(self, dict_data: dict) -> None:
        """Use ``dict_data`` as the raw configuration (discarding any built config).

        ``dict_data`` itself is never modified; resolution works on a deep copy.
        """
        self._load_plugins()
        self.logger.debug("Load config from dict")
        self._set_raw_config(dict_data)

    def _set_raw_config(self, raw_config: dict) -> None:
        """Store freshly loaded raw data and invalidate the cached model.

        Without the reset, a second ``from_*`` call after ``config`` was read
        would be silently ignored because the old model stays cached.
        """
        self._raw_config = raw_config
        self._config = None

    @property
    def config(self) -> T:
        """Return the validated config model, building it on first access.

        Raises:
            ProviderException: if no ``from_*`` method has been called, if the
                generic type is missing or invalid, or if validation fails.
        """
        self.logger.debug("Check if provider is initialized")
        if self._raw_config is None:
            raise ProviderException(
                "PydaConf is not initialized.\n"
                "You need to run one of these methods first from_file('file_path'), "
                "from_dict(dict_data) or from_url(url)."
            )

        if self._config is None:
            # Deep copy: secret injection and interpolation mutate nested dicts and
            # lists in place, and the raw tree (possibly the caller's own dict from
            # from_dict) must not be rewritten with resolved values.
            config_copy = copy.deepcopy(self._raw_config)
            self.logger.debug("Inject credentials to the config")
            self._inject_secrets(config_copy)
            self.logger.debug("Interpolate values")
            self._interpolate_templates(config_copy, config_copy)

            self.logger.debug("Build config object")
            model_type = self._get_generic_type()
            try:
                config: T = model_type(**config_copy)
            except ValidationError as e:
                raise ProviderException('Configuration file validation failed with errors', e.errors()) from e
            self._config = config
            self.logger.debug("Config object was build successfully")

        return self._config

    def register_plugin(self, plugin_class: type[PluginBase]) -> None:
        """Manually register a plugin class under its ``PREFIX``.

        The class is instantiated here; an existing plugin with the same prefix
        is replaced.
        """
        self.logger.debug(f"Register plugin for '{plugin_class}'")
        self._plugins[str(plugin_class.PREFIX)] = plugin_class()

    def on_update(self, key_pattern: str, callback: Callable[[str, str], None]) -> None:
        """Subscribe ``callback`` to updates whose key path matches ``key_pattern``.

        ``key_pattern`` is a regular expression applied with ``re.match`` (anchored
        at the start) to key paths such as ``'.db.password'`` or ``'.users[0].token'``.
        The callback receives ``(key, value)``.
        """
        self.logger.debug(f"Register on_update for key_patter '{key_pattern}'")
        self._update_subscribers.setdefault(key_pattern, []).append(callback)

    def _get_generic_type(self) -> type[T]:
        """Return the pydantic model class this instance was parametrised with.

        Raises:
            ProviderException: if the instance was created without a type argument,
                or the argument is not a ``BaseModel`` subclass.
        """
        self.logger.debug("Check the type of the generic")

        # Relies on ``__orig_class__``, which typing sets on instances created via
        # ``PydaConf[X]()``; it is not a documented API and may change.
        orig_class = getattr(self, '__orig_class__', None)
        if orig_class is None:
            raise ProviderException('PydaConf must be defined as generic Config[MyPydanticType]()')

        generic_type: type[T] | None = next(iter(get_args(orig_class)), None)
        # issubclass() raises TypeError for non-class arguments such as ``list[int]``
        # or a bare TypeVar; report those as the intended ProviderException instead.
        try:
            is_model = generic_type is not None and issubclass(generic_type, BaseModel)
        except TypeError:
            is_model = False
        if not is_model:
            raise ProviderException('Generic type must inherit pydantic BaseModel class Config[MyPydanticType]()')

        return generic_type

    def _load_plugins(self) -> None:
        """Load builtin plugins, then dynamic ones (the latter win on equal prefixes)."""
        self.logger.debug("Load builtin plugins")
        self._load_builtin_plugins()
        self.logger.debug("Load dynamic plugins")
        self._load_dynamic_plugins()

    def _load_builtin_plugins(self) -> None:
        for plugin in load_builtin_plugins():
            self._plugins[plugin.PREFIX] = plugin

    def _load_dynamic_plugins(self) -> None:
        for plugin in load_dynamic_plugins():
            self._plugins[plugin.PREFIX] = plugin

    def _interpolate_templates(self, node: ConfigValueType, config_data: dict) -> None:
        """Expand interpolation templates in the strings of ``node`` in place.

        Templates are resolved against ``config_data`` (the whole config tree).
        Nested dicts and lists are walked recursively; other values are left alone.
        """
        if isinstance(node, list):
            for index, element in enumerate(node):
                if isinstance(element, str):
                    # List strings are always passed through interpolate_template.
                    node[index] = interpolate_template(element, config_data)
                elif isinstance(element, (dict, list)):
                    self._interpolate_templates(element, config_data)

        elif isinstance(node, dict):
            for key, element in node.items():
                if isinstance(element, (dict, list)):
                    self._interpolate_templates(element, config_data)
                elif isinstance(element, str) and has_interpolation_template(element):
                    node[key] = interpolate_template(element, config_data)

    def _match_and_execute_plugin(self, element: str, key: str) -> str:
        """Resolve ``element`` through its plugin if it is a ``PREFIX:///value`` reference.

        Args:
            element: A string value from the config tree.
            key: Key path of ``element``; bound into the update callback given
                to the plugin so it can report a changed value for this entry.

        Returns:
            The plugin's result, or ``element`` unchanged if it is not a reference.

        Raises:
            ProviderException: if no plugin is registered for the prefix.
        """
        match = self._PLUGIN_REFERENCE.match(element)
        if match is None:
            return element

        plugin_prefix = match.group('PLUGIN_PREFIX')
        value = match.group('VALUE')
        # Lookup is done with the upper-cased prefix.
        plugin = self._plugins.get(plugin_prefix.upper())
        if plugin is None:
            raise ProviderException(f"Plugin with prefix '{plugin_prefix}' is not registered")
        return plugin._execute_plugin(value, partial(self._on_update, key))

    def _inject_secrets(self, node: ConfigValueType, key_path: str = "") -> None:
        """Replace plugin references inside ``node`` in place with plugin results.

        ``key_path`` tracks where ``node`` sits in the tree, using the format
        ``'.a.b[0][1]'`` that ``_update_config`` later parses.
        """
        if isinstance(node, list):
            for index, element in enumerate(node):
                element_path = f"{key_path}[{index}]"
                if isinstance(element, str):
                    node[index] = self._match_and_execute_plugin(element, element_path)
                elif isinstance(element, (dict, list)):
                    self._inject_secrets(element, key_path=element_path)

        elif isinstance(node, dict):
            for key, element in node.items():
                element_path = f"{key_path}.{key}"
                if isinstance(element, (dict, list)):
                    self._inject_secrets(element, key_path=element_path)
                elif isinstance(element, str):
                    node[key] = self._match_and_execute_plugin(element, element_path)

    @staticmethod
    def _parse_key_path(key: str) -> list[str | int]:
        """Split a key path into dict keys and list indices.

        ``'.users[0].tags[1]'`` -> ``['users', 0, 'tags', 1]`` and
        ``'.matrix[0][1]'`` -> ``['matrix', 0, 1]``. Dots inside brackets are not
        treated as separators.
        """
        steps: list[str | int] = []
        for segment in re.split(r'\.(?![^\[]*\])', key.lstrip('.')):
            # A segment is a name followed by any number of '[n]' suffixes.
            match = re.fullmatch(r'(?P<name>.*?)(?P<indices>(?:\[\d+\])*)', segment, flags=re.DOTALL)
            assert match is not None  # always matches: ``name`` may absorb the whole segment
            name, indices = match.group('name'), match.group('indices')
            if name or not indices:
                steps.append(name)
            steps.extend(int(index) for index in re.findall(r'\d+', indices))
        return steps

    def _update_config(self, key: str, value: str) -> None:
        """Set the entry at key path ``key`` to ``value`` and rebuild the model.

        The model is dumped to a dict, patched, and re-validated.
        TODO: research updating the pydantic model in place with setattr/getattr.

        Raises:
            ProviderException: if the config model has not been built yet.
        """
        if self._config is None:
            raise ProviderException('PydaConf is not initialized, or you call on_update callback in the plugin run.')

        config_model = self._config.model_dump()
        *parents, last = self._parse_key_path(key)

        current = config_model
        for step in parents:
            current = current[step]
        current[last] = value

        self._config = self._get_generic_type()(**config_model)

    def _on_update(self, key: str, value: str) -> None:
        """Apply a value pushed by a plugin and notify matching subscribers.

        Args:
            key: Key path of the changed entry, as built by ``_inject_secrets``.
            value: The new value.
        """
        # The value may be a resolved credential, so only the key is logged.
        self.logger.debug(f"Call _on_update for '{key}'")

        # The lock guards against race conditions when several threads push updates.
        with self._update_lock:
            self._update_config(key, value)

            # Iterate over a snapshot: on_update() may register subscribers
            # concurrently (or from inside a callback), and mutating a dict while
            # iterating it raises RuntimeError.
            subscriptions = [
                (pattern, list(callbacks))
                for pattern, callbacks in list(self._update_subscribers.items())
            ]
            for pattern, callbacks in subscriptions:
                if re.match(pattern, key):
                    for callback in callbacks:
                        callback(key, value)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.config.model_dump()})'