refactor: contains method in the base class by jupyterjazz · Pull Request #1701 · docarray/docarray · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions docarray/index/abstract.py
15 changes: 5 additions & 10 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,16 +669,11 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]:
def _refresh(self, index_name: str):
    """
    Refresh the given index through the client's indices API.

    :param index_name: Name of the index to refresh.
    """
    indices_api = self._client.indices
    indices_api.refresh(index=index_name)

def __contains__(self, item: BaseDoc) -> bool:
    """
    Check whether a document is present in the index.

    :param item: The document to look up.
        It must be an instance of BaseDoc or its subclass.
    :return: False when the document's id is empty, otherwise the ``found``
        flag reported for that id by the multi-get lookup.
    :raises TypeError: If ``item`` is not an instance of BaseDoc or its subclass.
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    doc_id = item.id
    if len(doc_id) == 0:
        # An empty id cannot match a stored document, so skip the lookup.
        return False

    response = self._client_mget([doc_id])
    return response["docs"][0]["found"]
def _doc_exists(self, doc_id: str) -> bool:
    """
    Check whether a document with the given id is stored in the index.

    :param doc_id: The id of the document to look up. An empty or missing
        (``None``) id is treated as "not present" without querying the backend.
    :return: False for an empty/missing id, otherwise the ``found`` flag
        reported for that id by the multi-get lookup.
    """
    # Truthiness covers both "" and None; `len(None)` would raise TypeError.
    if not doc_id:
        return False
    ret = self._client_mget([doc_id])
    return ret["docs"][0]["found"]

###############################################
# API Wrappers #
Expand Down
17 changes: 5 additions & 12 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,18 +393,11 @@ def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSche
raise KeyError(f'No document with id {doc_ids} found')
return out_docs

def __contains__(self, item: BaseDoc):
    """
    Check whether a document is present in the index.

    The document id is hashed with ``_to_hashed_id`` and looked up in the
    ``docs`` table through the SQLite cursor.

    :param item: The document to look up.
        It must be an instance of BaseDoc or its subclass.
    :return: True if a row with the hashed id exists, False otherwise.
    :raises TypeError: If ``item`` is not an instance of BaseDoc or its subclass.
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    hash_id = self._to_hashed_id(item.id)
    # Bind the id as a parameter instead of interpolating it into the SQL
    # text, and fetch a single marker row rather than loading the stored data.
    self._sqlite_cursor.execute(
        "SELECT 1 FROM docs WHERE doc_id = ? LIMIT 1", (hash_id,)
    )
    return self._sqlite_cursor.fetchone() is not None
def _doc_exists(self, doc_id: str) -> bool:
    """
    Check whether a document with the given id is stored in the index.

    The id is hashed with ``_to_hashed_id`` and looked up in the ``docs``
    table through the SQLite cursor.

    :param doc_id: The id of the document to look up.
    :return: True if a row with the hashed id exists, False otherwise.
    """
    hash_id = self._to_hashed_id(doc_id)
    # Bind the id as a parameter instead of interpolating it into the SQL
    # text, and fetch a single marker row rather than loading the stored data.
    self._sqlite_cursor.execute(
        "SELECT 1 FROM docs WHERE doc_id = ? LIMIT 1", (hash_id,)
    )
    return self._sqlite_cursor.fetchone() is not None

def num_docs(self) -> int:
"""
Expand Down
9 changes: 2 additions & 7 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,8 @@ def _text_search_batched(
) -> _FindResultBatched:
raise NotImplementedError(f'{type(self)} does not support text search.')

def __contains__(self, item: BaseDoc):
    """
    Check whether a document is present in the index.

    :param item: The document to look up.
        It must be an instance of BaseDoc or its subclass.
    :return: True if any stored document has the same id as ``item``.
    :raises TypeError: If ``item`` is not an instance of BaseDoc or its subclass.
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    # Linear scan over the stored documents, stopping at the first id match.
    for stored in self._docs:
        if stored.id == item.id:
            return True
    return False
def _doc_exists(self, doc_id: str) -> bool:
    """
    Check whether a document with the given id is stored in the index.

    :param doc_id: The id of the document to look up.
    :return: True if any stored document has this id, False otherwise.
    """
    # Linear scan over the stored documents, stopping at the first id match.
    for stored in self._docs:
        if stored.id == doc_id:
            return True
    return False

def persist(self, file: Optional[str] = None) -> None:
"""Persist InMemoryExactNNIndex into a binary file."""
Expand Down
25 changes: 10 additions & 15 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,21 +317,16 @@ def num_docs(self) -> int:
"""
return self._client.count(collection_name=self.collection_name).count

def __contains__(self, item: BaseDoc) -> bool:
    """
    Check whether a document is present in the index.

    :param item: The document to look up.
        It must be an instance of BaseDoc or its subclass.
    :return: True if scrolling the collection with a has-id filter for the
        document's converted id returns at least one record.
    :raises TypeError: If ``item`` is not an instance of BaseDoc or its subclass.
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    id_condition = rest.HasIdCondition(has_id=[self._to_qdrant_id(item.id)])
    id_filter = rest.Filter(must=[id_condition])
    # scroll() yields a pair; only the first element (the records) is needed.
    records, _ignored = self._client.scroll(
        collection_name=self.index_name,
        scroll_filter=id_filter,
    )
    return len(records) > 0
def _doc_exists(self, doc_id: str) -> bool:
    """
    Check whether a document with the given id is stored in the index.

    :param doc_id: The id of the document to look up.
    :return: True if scrolling the collection with a has-id filter for the
        converted id returns at least one record, False otherwise.
    """
    id_condition = rest.HasIdCondition(has_id=[self._to_qdrant_id(doc_id)])
    id_filter = rest.Filter(must=[id_condition])
    # scroll() yields a pair; only the first element (the records) is needed.
    records, _ignored = self._client.scroll(
        collection_name=self.index_name,
        scroll_filter=id_filter,
    )
    return len(records) > 0

def _del_items(self, doc_ids: Sequence[str]):
items = self._get_items(doc_ids)
Expand Down
17 changes: 1 addition & 16 deletions docarray/index/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None:
):
self._client.delete(*batch)

def _doc_exists(self, doc_id) -> bool:
def _doc_exists(self, doc_id: str) -> bool:
"""
Checks if a document exists in the index.

Expand Down Expand Up @@ -610,18 +610,3 @@ def _text_search_batched(
scores.append(results.scores)

return _FindResultBatched(documents=docs, scores=scores)

def __contains__(self, item: BaseDoc) -> bool:
    """
    Check whether a given document exists in the index.

    :param item: The document to check.
        It must be an instance of BaseDoc or its subclass.
    :return: True if the document exists in the index, False otherwise.
    :raises TypeError: If ``item`` is not an instance of BaseDoc or its subclass.
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )
    # The id-based existence check does the actual lookup.
    return self._doc_exists(item.id)
31 changes: 13 additions & 18 deletions docarray/index/backends/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,25 +760,20 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
]
return ids

def __contains__(self, item: BaseDoc) -> bool:
if safe_issubclass(type(item), BaseDoc):
result = (
self._client.query.get(self.index_name, ['docarrayid'])
.with_where(
{
"path": ['docarrayid'],
"operator": "Equal",
"valueString": f'{item.id}',
}
)
.do()
)
docs = result["data"]["Get"][self.index_name]
return docs is not None and len(docs) > 0
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
def _doc_exists(self, doc_id: str) -> bool:
    """
    Check whether a document with the given id is stored in the index.

    Queries the ``docarrayid`` property of the index with an ``Equal`` filter
    on the given id and inspects the returned entries.

    :param doc_id: The id of the document to look up.
    :return: True if the query returns at least one entry, False if the
        result is None or empty.
    """
    id_filter = {
        "path": ['docarrayid'],
        "operator": "Equal",
        "valueString": f'{doc_id}',
    }
    query = self._client.query.get(self.index_name, ['docarrayid'])
    result = query.with_where(id_filter).do()

    matches = result["data"]["Get"][self.index_name]
    if matches is None:
        return False
    return len(matches) > 0

class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, document_index):
Expand Down
3 changes: 3 additions & 0 deletions tests/index/base_classes/test_base_doc_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def python_type_to_db_type(self, x):
def num_docs(self):
    """Return a fixed document count of 3 for this dummy index."""
    fixed_count = 3
    return fixed_count

def _doc_exists(self, doc_id: str) -> bool:
    """
    Report that no document exists, whatever id is given.

    :param doc_id: The id of the document to look up (ignored).
    :return: Always False.
    """
    found = False
    return found

_index = _identity
_del_items = _identity
_get_items = _identity
Expand Down
2 changes: 1 addition & 1 deletion tests/index/base_classes/test_configs.py