Commit fc672343 authored by Marko Kollo's avatar Marko Kollo 😄
Browse files

Merge branch 'on_demand_test' into 'master'

Tests for on-deman loading.

See merge request !14
parents f54be4bc aa11932f
Pipeline #6210 passed with stage
in 7 minutes and 46 seconds
import json
import pytest
import regex as re
from texta_mlp.entity_mapper import EntityMapper
from texta_mlp.mlp import MLP
......@@ -236,3 +238,22 @@ def test_parsing_empty_list_in_dictionary(mlp: MLP):
result = mlp.process_docs([{"empty_list_field": []}], doc_paths=["empty_list_field"])
for key in result:
assert "mlp" not in key
def test_that_models_are_loaded_on_demand():
pipeline = MLP(language_codes=["et", "en"], logging_level="info", use_gpu=False)
stanza_pipelines = pipeline._stanza_pipelines
assert len(stanza_pipelines.keys()) == 0
result = pipeline.process(raw_text="Tere, minu nimi on Joonas, kas saaksite öelda, mis kell praegu on?", analyzers=["lemmas"], lang="et")
assert "et" in pipeline._stanza_pipelines
assert len(pipeline._stanza_pipelines.keys()) == 1
result = pipeline.process(raw_text="Hello there, my name is Joonas, how do you do!?", analyzers=["lemmas"])
assert "en" in pipeline._stanza_pipelines
assert len(pipeline._stanza_pipelines) == 2
def test_that_entity_mapper_is_loaded_on_demand():
pipeline = MLP(language_codes=["et"], logging_level="info", use_gpu=False)
assert pipeline._entity_mapper is None
pipeline.process(raw_text="Tere, minu nimi on Joonas, kas saaksite öelda, mis kell praegu on?", analyzers=["entities"], lang="et")
assert isinstance(pipeline._entity_mapper, EntityMapper)
......@@ -81,7 +81,7 @@ class MLP:
self.use_default_lang = use_default_language_code
self.resource_dir = resource_dir
self.__stanza_pipelines = {}
self._stanza_pipelines = {}
self.logging_level = logging_level
self.use_gpu = use_gpu
self.stanza_resource_path = pathlib.Path(self.resource_dir) / "stanza"
......@@ -93,7 +93,7 @@ class MLP:
self.prepare_resources(refresh_data)
self.__entity_mapper = None
self._entity_mapper = None
self.loaded_entity_files = []
self.not_entities = self._load_not_entities()
......@@ -235,15 +235,14 @@ class MLP:
def get_entity_mapper(self):
if self.__entity_mapper is None:
self.__entity_mapper = self._load_entity_mapper()
return self.__entity_mapper
if self._entity_mapper is None:
self._entity_mapper = self._load_entity_mapper()
return self._entity_mapper
def get_stanza_pipeline(self, lang: str):
if lang not in self.__stanza_pipelines:
if lang not in self._stanza_pipelines:
try:
self.__stanza_pipelines[lang] = stanza.Pipeline(
self._stanza_pipelines[lang] = stanza.Pipeline(
lang=lang,
dir=str(self.stanza_resource_path),
processors=self._get_stanza_processors(lang),
......@@ -253,7 +252,7 @@ class MLP:
# This is for CUDA OOM exceptions. Fall back to CPU if needed.
except RuntimeError:
self.__stanza_pipelines[lang] = stanza.Pipeline(
self._stanza_pipelines[lang] = stanza.Pipeline(
lang=lang,
dir=str(self.stanza_resource_path),
processors=self._get_stanza_processors(lang),
......@@ -261,7 +260,7 @@ class MLP:
logging_level=self.logging_level,
)
return self.__stanza_pipelines[lang]
return self._stanza_pipelines[lang]
def _get_stanza_tokens(self, lang: str, raw_text: str):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment