295 lines
11 KiB
Python
295 lines
11 KiB
Python
"""
|
|
Module is used to infer Django model fields.
|
|
"""
|
|
from jedi._compatibility import Parameter
|
|
from jedi import debug
|
|
from jedi.inference.cache import inference_state_function_cache
|
|
from jedi.inference.base_value import ValueSet, iterator_to_value_set, ValueWrapper
|
|
from jedi.inference.filters import DictFilter, AttributeOverwrite
|
|
from jedi.inference.names import NameWrapper, BaseTreeParamName
|
|
from jedi.inference.compiled.value import EmptyCompiledName
|
|
from jedi.inference.value.instance import TreeInstance
|
|
from jedi.inference.value.klass import ClassMixin
|
|
from jedi.inference.gradual.base import GenericClass
|
|
from jedi.inference.gradual.generics import TupleGenericManager
|
|
from jedi.inference.signature import AbstractSignature
|
|
|
|
|
|
# Django model field class name -> (module name, attribute name) of the Python
# type that a model *instance* attribute of that field evaluates to. A module
# name of ``None`` means the attribute is looked up in the builtins module.
# Lookup is by exact class name, so subclasses of mapped fields (for example
# ``SlugField``, a ``CharField`` subclass) need entries of their own.
mapping = {
    'IntegerField': (None, 'int'),
    'BigIntegerField': (None, 'int'),
    'PositiveIntegerField': (None, 'int'),
    'SmallIntegerField': (None, 'int'),
    'PositiveBigIntegerField': (None, 'int'),
    'PositiveSmallIntegerField': (None, 'int'),
    'AutoField': (None, 'int'),
    'BigAutoField': (None, 'int'),
    'SmallAutoField': (None, 'int'),
    'CharField': (None, 'str'),
    'TextField': (None, 'str'),
    'EmailField': (None, 'str'),
    'GenericIPAddressField': (None, 'str'),
    'URLField': (None, 'str'),
    'SlugField': (None, 'str'),
    'FilePathField': (None, 'str'),
    'FloatField': (None, 'float'),
    'BinaryField': (None, 'bytes'),
    'BooleanField': (None, 'bool'),
    'NullBooleanField': (None, 'bool'),
    'DecimalField': ('decimal', 'Decimal'),
    'TimeField': ('datetime', 'time'),
    'DurationField': ('datetime', 'timedelta'),
    'DateField': ('datetime', 'date'),
    'DateTimeField': ('datetime', 'datetime'),
    'UUIDField': ('uuid', 'UUID'),
}

# QuerySet method names whose values ``tree_name_to_values`` may replace with
# wrappers, so that their signatures list the model's fields as keyword
# parameters.
_FILTER_LIKE_METHODS = ('create', 'filter', 'exclude', 'update', 'get',
                        'get_or_create', 'update_or_create')
|
|
|
|
|
|
@inference_state_function_cache()
def _get_deferred_attributes(inference_state):
    """Return ``django.db.models.query_utils.DeferredAttribute`` executed as
    an annotation.

    Model field attributes accessed on the model *class* infer to this value.
    The decorator caches the result (presumably per inference state).
    """
    query_utils = inference_state.import_module(
        ('django', 'db', 'models', 'query_utils'))
    deferred_attribute = query_utils.py__getattribute__('DeferredAttribute')
    return deferred_attribute.execute_annotation()
|
|
|
|
|
|
def _infer_scalar_field(inference_state, field_name, field_tree_instance, is_instance):
    """Infer a scalar (non-relation) model field.

    :param inference_state: the active inference state.
    :param field_name: unused; kept for call compatibility.
    :param field_tree_instance: an inferred value of the field's class attribute.
    :param is_instance: whether the attribute is accessed on a model instance.
    :returns: ``None`` when the value's class name is not in ``mapping``.
        Otherwise it returns the ``DeferredAttribute`` value for class access.
        For instance access it returns an instance of the mapped Python type,
        or ``None`` if that type cannot be found.
    """
    target = mapping.get(field_tree_instance.py__name__())
    if target is None:
        return None
    module_name, attribute_name = target

    if not is_instance:
        return _get_deferred_attributes(inference_state)

    if module_name is None:
        module = inference_state.builtins_module
    else:
        module = inference_state.import_module((module_name,))

    # Only the first value found for the type name is used.
    first = next(iter(module.py__getattribute__(attribute_name)), None)
    if first is None:
        return None
    return first.execute_with_values()
|
|
|
|
|
|
@iterator_to_value_set
def _get_foreign_key_values(cls, field_tree_instance):
    """Yield the model classes that a relation field points to.

    :param cls: the model class value that declares the field.
    :param field_tree_instance: an inferred value of the field attribute,
        e.g. a ``ForeignKey(...)`` instance.

    The related model is taken from the field's first positional argument or
    from its ``to=`` keyword argument. It may be given in either of two forms:

    * a class, which is yielded directly;
    * a string literal naming the class. The name is looked up in the module
      that defines *cls*. The special string ``'self'`` refers to *cls* itself
      (Django's recursive-relationship convention).

    Values that are not ``TreeInstance`` objects yield nothing.
    """
    if not isinstance(field_tree_instance, TreeInstance):
        return

    # TODO private access..
    lazy_values = None
    arguments = field_tree_instance._arguments.unpack()
    for position, (key, lazy) in enumerate(arguments):
        # ``ForeignKey(to, on_delete, ...)``: the target is either the first
        # positional argument or passed explicitly as ``to=``.
        if (key is None and position == 0) or key == 'to':
            lazy_values = lazy
            break
    if lazy_values is None:
        return

    for value in lazy_values.infer():
        if value.py__name__() == 'str':
            # get_safe_value() raises ValueError when the string's content is
            # not statically known (e.g. a ``str`` instance from an
            # annotation). Ask for a default so inference degrades quietly.
            foreign_key_class_name = value.get_safe_value(default=None)
            if foreign_key_class_name is None:
                continue
            if foreign_key_class_name == 'self':
                yield cls
                continue
            module = cls.get_root_context()
            for v in module.py__getattribute__(foreign_key_class_name):
                if v.is_class():
                    yield v
        elif value.is_class():
            yield value
|
|
|
|
|
|
def _infer_field(cls, field_name, is_instance):
    """Infer what the model attribute *field_name* of *cls* evaluates to.

    :param cls: the Django model class value.
    :param field_name: the name of the class attribute being accessed.
    :param is_instance: ``True`` for access through a model instance,
        ``False`` for access through the model class.

    The first inferred value of the attribute that is recognized wins:

    * scalar fields are handled by ``_infer_scalar_field``;
    * ``ForeignKey`` / ``OneToOneField`` / ``ManyToManyField`` give the
      ``DeferredAttribute`` value on class access. On instance access they
      give the related model classes executed (one-to-one / many-to-one), or
      a ``RelatedManager`` per related model (many-to-many).

    If no value is recognized, a debug message is logged and the plain
    inference result of *field_name* is returned.
    """
    inference_state = cls.inference_state
    inferred = field_name.infer()

    for candidate in inferred:
        scalar = _infer_scalar_field(
            inference_state, field_name, candidate, is_instance)
        if scalar is not None:
            return scalar

        field_class_name = candidate.py__name__()
        many_to_many = field_class_name == 'ManyToManyField'
        is_relation = field_class_name in ('ForeignKey', 'OneToOneField') \
            or many_to_many
        if not is_relation:
            continue

        if not is_instance:
            return _get_deferred_attributes(inference_state)

        related_models = _get_foreign_key_values(cls, candidate)
        if not many_to_many:
            return related_models.execute_with_values()

        managers = [_create_manager_for(model, 'RelatedManager')
                    for model in related_models]
        # _create_manager_for may return None; drop those entries.
        return ValueSet(manager for manager in managers if manager)

    debug.dbg('django plugin: fail to infer `%s` from class `%s`',
              field_name.string_name, cls.py__name__())
    return inferred
|
|
|
|
|
|
class DjangoModelName(NameWrapper):
    """A model attribute name whose inference goes through ``_infer_field``.

    It remembers the model class (*cls*) and whether the attribute is
    accessed on an instance, because both change what the field infers to.
    """

    def __init__(self, cls, name, is_instance):
        super(DjangoModelName, self).__init__(name)
        self._cls = cls
        self._is_instance = is_instance

    def infer(self):
        model_cls, wrapped, on_instance = \
            self._cls, self._wrapped_name, self._is_instance
        return _infer_field(model_cls, wrapped, on_instance)
|
|
|
|
|
|
def _create_manager_for(cls, manager_cls='BaseManager'):
    """Build ``django.db.models.manager.<manager_cls>`` parametrized with *cls*.

    :param cls: the model class value used as the manager's generic argument.
    :param manager_cls: the name of the manager class to look up.
    :returns: the first value produced by executing ``manager_cls[cls]`` as
        an annotation. Returns ``None`` when no class of that name is found
        or the execution yields nothing.
    """
    manager_module = cls.inference_state.import_module(
        ('django', 'db', 'models', 'manager'))

    for manager in manager_module.py__getattribute__(manager_cls):
        if not manager.is_class_mixin():
            continue
        generics = TupleGenericManager((ValueSet([cls]),))
        parametrized = GenericClass(manager, generics)
        for created in parametrized.execute_annotation():
            return created
    return None
|
|
|
|
|
|
def _new_dict_filter(cls, is_instance):
    """Return a ``DictFilter`` with *cls*'s names wrapped as ``DjangoModelName``.

    :param cls: the Django model class value.
    :param is_instance: whether names are looked up on a model instance.

    Metaclass filters and the ``type`` filter are excluded. The filters are
    walked in reverse order, so when a name appears in several filters, the
    entry from the filter that ``get_filters`` returned first is kept.
    Presumably that is the class itself, shadowing its bases — confirm
    against ``ClassMixin.get_filters``.
    """
    filters = list(cls.get_filters(
        is_instance=is_instance,
        include_metaclasses=False,
        include_type_when_class=False,
    ))

    names = {}
    for filter_ in reversed(filters):
        for name in filter_.values():
            names[name.string_name] = DjangoModelName(cls, name, is_instance)

    if is_instance:
        # On an instance, replace ``objects`` with a name that infers to
        # nothing. This is not perfect: "objects" is still offered as a
        # completion, but at least things like `.objects.filter` are no
        # longer inferred. Hiding it from completions as well would be nicer,
        # but is probably not worth the extra work.
        names['objects'] = EmptyCompiledName(cls.inference_state, 'objects')
    return DictFilter(names)
|
|
|
|
|
|
def is_django_model_base(value):
    """Return ``True`` if *value* is ``ModelBase`` from ``django.db.models.base``.

    The module is only consulted when the name already matches.
    """
    if value.py__name__() != 'ModelBase':
        return False
    root_module = value.get_root_context()
    return root_module.py__name__() == 'django.db.models.base'
|
|
|
|
|
|
def get_metaclass_filters(func):
    """Decorate a metaclass-filter hook with Django model support.

    If any of the class's metaclasses is Django's ``ModelBase``, the class's
    names are served by a single filter from ``_new_dict_filter``. Otherwise
    the call is delegated to *func* unchanged.
    """
    def wrapper(cls, metaclasses, is_instance):
        if any(is_django_model_base(meta) for meta in metaclasses):
            return [_new_dict_filter(cls, is_instance)]
        return func(cls, metaclasses, is_instance)
    return wrapper
|
|
|
|
|
|
def tree_name_to_values(func):
    """Decorate a tree-name inference hook with Django-specific replacements.

    The values returned by *func* are post-processed in three cases:

    * a name in ``_FILTER_LIKE_METHODS`` that resolves to a ``_BaseQuerySet``
      method from the ``django.db.models.query`` module is replaced by one
      ``QuerySetMethodWrapper`` per model in the first generic of the value
      the name is looked up on. This lets keyword completion on e.g.
      ``User.objects.filter`` offer the model's fields;
    * ``BaseManager`` in ``django.db.models.manager`` is wrapped in
      ``ManagerWrapper``;
    * ``Field`` in ``django.db.models.fields`` is wrapped in ``FieldWrapper``.

    In every other case *func*'s result is returned untouched.
    """
    def wrapper(inference_state, context, tree_name):
        values = func(inference_state, context, tree_name)
        name = tree_name.value

        if name in _FILTER_LIKE_METHODS:
            for value in values:
                is_queryset_method = (
                    value.get_qualified_names() == ('_BaseQuerySet', name)
                    and value.parent_context.is_module()
                    and value.parent_context.py__name__() == 'django.db.models.query'
                )
                if not is_queryset_method:
                    continue
                generics = context.get_value().get_generics()
                if len(generics) >= 1:
                    return ValueSet(QuerySetMethodWrapper(value, model)
                                    for model in generics[0])

        elif name == 'BaseManager' and context.is_module() \
                and context.py__name__() == 'django.db.models.manager':
            return ValueSet(ManagerWrapper(v) for v in values)

        elif name == 'Field' and context.is_module() \
                and context.py__name__() == 'django.db.models.fields':
            return ValueSet(FieldWrapper(v) for v in values)

        return values
    return wrapper
|
|
|
|
|
|
def _find_fields(cls):
    """Yield the class-level names of *cls* that infer to ``DeferredAttribute``.

    Those are the model fields this plugin recognizes. A name is yielded once
    for every matching inferred value.
    """
    deferred_path = ('django', 'db', 'models', 'query_utils', 'DeferredAttribute')
    for name in _new_dict_filter(cls, is_instance=False).values():
        for value in name.infer():
            qualified = value.name.get_qualified_names(include_module_names=True)
            if qualified == deferred_path:
                yield name
|
|
|
|
|
|
def _get_signatures(cls):
    """Return a one-element list holding the model signature of *cls*."""
    fields = list(_find_fields(cls))
    signature = DjangoModelSignature(cls, field_names=fields)
    return [signature]
|
|
|
|
|
|
def get_metaclass_signatures(func):
    """Decorate a metaclass-signature hook with Django model support.

    :param func: the original hook, called as ``func(cls, metaclasses)``.
    :returns: a wrapper with the same calling convention.

    If any of *metaclasses* is Django's ``ModelBase``, the class's signatures
    are synthesized from its model fields via ``_get_signatures``. Otherwise
    the call is delegated to *func*.
    """
    def wrapper(cls, metaclasses):
        for metaclass in metaclasses:
            if is_django_model_base(metaclass):
                return _get_signatures(cls)
        # Forward the whole collection, just as get_metaclass_filters does.
        # Passing the loop variable would hand over only the last metaclass,
        # and would be unbound when *metaclasses* is empty.
        return func(cls, metaclasses)
    return wrapper
|
|
|
|
|
|
class ManagerWrapper(ValueWrapper):
    """Wrapper around Django's ``BaseManager`` class value.

    Subscripting it wraps every resulting value in ``GenericManagerWrapper``.
    """

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericManagerWrapper, subscripted))
|
|
|
|
|
|
class GenericManagerWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted manager class, as produced by ``ManagerWrapper``."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # When the manager is reached through a model class, re-parametrize
        # the manager's class with that model and call it again with the
        # manager's original arguments.
        manager_class = calling_instance.class_value
        parametrized = manager_class.with_generics((ValueSet({class_value}),))
        return parametrized.py__call__(calling_instance._arguments)

    def with_generics(self, generics_tuple):
        wrapped = self._wrapped_value
        return wrapped.with_generics(generics_tuple)
|
|
|
|
|
|
class FieldWrapper(ValueWrapper):
    """Wrapper around Django's ``Field`` class value.

    Subscripting it wraps every resulting value in ``GenericFieldWrapper``.
    """

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericFieldWrapper, subscripted))
|
|
|
|
|
|
class GenericFieldWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted ``Field`` class, as produced by ``FieldWrapper``."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # Resolve descriptor access to the field instance itself instead of
        # inferring Field.__get__. This is mostly an optimization: it stops
        # Jedi from aborting inference after too many Field.__get__
        # executions.
        return ValueSet([calling_instance])
|
|
|
|
|
|
class DjangoModelSignature(AbstractSignature):
    """Signature whose parameters are the given model field names.

    :param value: the value the signature belongs to (the model class).
    :param field_names: names to expose, each wrapped as ``DjangoParamName``.
    """

    def __init__(self, value, field_names):
        super(DjangoModelSignature, self).__init__(value)
        self._field_names = field_names

    def get_param_names(self, resolve_stars=False):
        return list(map(DjangoParamName, self._field_names))
|
|
|
|
|
|
class DjangoParamName(BaseTreeParamName):
    """Keyword-only parameter name backed by a model field name.

    Inference is delegated to the wrapped field name.
    """

    def __init__(self, field_name):
        parent_context = field_name.parent_context
        tree_name = field_name.tree_name
        super(DjangoParamName, self).__init__(parent_context, tree_name)
        self._field_name = field_name

    def get_kind(self):
        return Parameter.KEYWORD_ONLY

    def infer(self):
        field = self._field_name
        return field.infer()
|
|
|
|
|
|
class QuerySetMethodWrapper(ValueWrapper):
    """An unbound QuerySet method that remembers the model it is used with.

    Binding it (``py__get__``) wraps every bound method in
    ``QuerySetBoundMethodWrapper`` for the same model.
    """

    def __init__(self, method, model_cls):
        super(QuerySetMethodWrapper, self).__init__(method)
        self._model_cls = model_cls

    def py__get__(self, instance, class_value):
        bound_methods = self._wrapped_value.py__get__(instance, class_value)
        model = self._model_cls
        return ValueSet({QuerySetBoundMethodWrapper(bound, model)
                         for bound in bound_methods})
|
|
|
|
|
|
class QuerySetBoundMethodWrapper(ValueWrapper):
    """A bound QuerySet method whose signatures come from the model.

    ``get_signatures`` returns the model signature of *model_cls* instead of
    the wrapped method's own signatures.
    """

    def __init__(self, method, model_cls):
        super(QuerySetBoundMethodWrapper, self).__init__(method)
        self._model_cls = model_cls

    def get_signatures(self):
        model = self._model_cls
        return _get_signatures(model)
|