"""An organization search tools using ROR database.
Research Organization Registry (ROR) is a global registry of open persistent identifiers for research organizations.
We use its database through ROR API to retrieve parent-child relationships between institutions
in order to consolidate outcomes under different sub-organizations.
"""
import asyncio
import json
import os
from collections import defaultdict
import logging
from typing import Annotated, Any, Iterable, Literal, MutableMapping, Type, TypedDict, TypeVar
from urllib.parse import quote as quote_url
from aiohttp import ClientSession, ClientTimeout
from cachetools import LRUCache
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.func import entrypoint
from langchain.agents import create_agent
from pydantic import BaseModel, ConfigDict, Field, ValidationError
# Annotation for ROR response fields that this module accepts but never interprets.
Ignore = Any
# Characters that are operators in Elasticsearch query-string syntax; `RORClient._escape` backslash-escapes them.
_ELASTIC_OPERATORS = set('+-=&|><!(){}[]^"~*?:\\/')
# Type variable for helpers that return an instance of a caller-chosen pydantic model.
_Model = TypeVar("_Model", bound=BaseModel)
# Organization types spelled out for readers; the trailing `str` keeps any other type value valid.
OrganizationType = Literal[
    "archive", "company", "education", "facility", "funder",
    "government", "healthcare", "nonprofit", "other",
] | str
class GeoNames(BaseModel):
    """GeoNames details of a location; only the country code is kept, every other key is ignored."""

    model_config = ConfigDict(extra="ignore")

    # ISO 3166-1 alpha-2 code: exactly two upper-case ASCII letters.
    country_code: Annotated[str, Field(pattern=r"^[A-Z]{2}$")]
class Location(BaseModel):
    """One location entry of an organization record."""

    geonames_details: GeoNames
    geonames_id: int
class Name(BaseModel):
    """One name of an organization plus its tags (e.g. ``ror_display`` marks the display name)."""

    # Language of the name; required key, but the value may be null.
    lang: str | None
    types: list[str]
    value: str
class Relationship(BaseModel):
    """A link from an organization to another ROR record; ``type`` is e.g. ``"parent"``."""

    id: str  # ROR ID of the related organization
    label: str  # display label of the related organization
    type: str
class Organization(BaseModel):
    """A ROR v2 organization record; fields this module does not use are typed as ``Ignore``."""

    id: str
    status: Literal["active", "inactive", "withdrawn"]
    types: list[OrganizationType]
    admin: Ignore = None
    domains: list[str]
    established: Ignore = None
    external_ids: Ignore = None
    links: Ignore = None
    locations: list[Location]
    names: list[Name]
    relationships: list[Relationship]

    def __str__(self):
        state = "" if self.is_active else " ( 🗑️ not active)"
        types = ", ".join(self.types) or "?"
        return f"{self.ror_name!r} ({self.id}{state} | types={types})"

    @property
    def first_country_code(self) -> str | None:
        """Return the ISO 3166-1 alpha-2 code of the first location, or ``None`` when there is none."""
        if not self.locations:
            return None
        return self.locations[0].geonames_details.country_code

    @property
    def is_active(self) -> bool:
        """Whether this record's status is ``"active"``."""
        return self.status == "active"

    @property
    def ror_name(self) -> str:
        """Return the name tagged with ``ror_display``.

        Raises:
            ValueError: if not exactly one name carries the ``ror_display`` tag.
        """
        display_names = [name.value for name in self.names if "ror_display" in name.types]
        if len(display_names) != 1:
            raise ValueError(f"Expect one name with 'ror_display' tag, but got {len(display_names)}.")
        return display_names[0]

    @property
    def parent(self) -> Relationship | None:
        """Return the first relationship of type ``"parent"``, or ``None`` if there is none."""
        return next((rel for rel in self.relationships if rel.type == "parent"), None)

    @property
    def simplified_dump(self) -> dict:
        """Make a compact dict for LLM prompts: ROR ID, display name, other names and first country code."""
        ror_name = self.ror_name
        return {
            "ror_id": self.id,
            "ror_name": ror_name,
            "aliases": [name.value for name in self.names if name.value != ror_name],
            "country_code": self.first_country_code,
        }
class Match(BaseModel):
    """One candidate from ROR affiliation matching; also used for merged parent/root candidates."""

    score: Annotated[float, Field(ge=0)]
    matching_type: Literal["EXACT", "FUZZY", "PARTIAL"] | str
    chosen: bool
    organization: Organization
    substring: Ignore = None

    def __str__(self):
        return f"({self.score:.2f}, {'chosen' if self.chosen else ' '}) {self.organization}"

    @classmethod
    def merge_organization(cls: Type[_Model], parent: Organization, children: "list[Match]") -> _Model:
        """Build a match for ``parent`` that carries the combined weight of ``children``.

        The score is the sum of the children's scores and ``chosen`` is true if any child was chosen.
        ``matching_type`` is the lexicographically smallest child type (e.g. ``EXACT`` < ``FUZZY``).
        Without children, the result is an ``EXACT``, chosen match with score 0.
        """
        if not children:
            return cls(score=0, matching_type="EXACT", chosen=True, organization=parent)
        # ``cls`` (not a hard-coded ``Match``) so subclasses get instances of their own type.
        return cls(
            score=sum(child.score for child in children),
            matching_type=min(child.matching_type for child in children),
            chosen=any(child.chosen for child in children),
            organization=parent,
        )

    def merge(self, other: "Match") -> "Match":
        """Merge ``other`` into this record in place (add scores, OR ``chosen``) and return ``self``."""
        self.score += other.score
        self.chosen |= other.chosen
        return self
class RORMatchResponse(BaseModel):
    """Body returned by ``/v2/organizations?affiliation=...`` (affiliation matching)."""

    number_of_results: int
    items: list[Match]
# Module-wide cache of ROR records keyed by ROR ID, shared by every RORClient instance.
_ror_cache: MutableMapping[str, Organization] = LRUCache(
    maxsize=1024,  # bounded, least-recently-used entries are evicted first
)
def _get_api_base():
    """Return ``$ROR_API_BASE`` when set to a non-empty value, else the public ROR API root."""
    configured = os.environ.get("ROR_API_BASE")
    if configured:
        return configured
    return "https://api.ror.org/"
def _get_verify_env() -> bool:
    """Return whether the ROR API's TLS certificate should be verified.

    Verification is ON by default. It is turned off only when the ``ROR_VERIFY_SSL`` environment
    variable is ``0`` or ``false`` (case-insensitive, surrounding whitespace ignored).
    The result is passed to aiohttp as ``ssl=...``, where ``False`` disables verification.
    """
    value = os.environ.get("ROR_VERIFY_SSL")
    if value is None:
        return True
    # The original used ``in`` here, which disabled verification by default and enabled it on "false".
    return value.strip().lower() not in ("0", "false")
class RORQueryResponse(BaseModel):
    """A list of organization records; ``meta`` is accepted but not interpreted."""

    items: list[Organization]
    meta: Ignore
class RORClient(BaseModel):
    """Async client for the ROR v2 REST API.

    Besides single lookups (``fetch_one``) and raw affiliation matching (``match_request``), it can
    follow ``parent`` relationships upward (``match`` / ``match_one_or_origin``) so that matches on
    sub-organizations are consolidated under their root organization.
    """

    class RateLimit(Exception):
        """Raised when the ROR API answers with HTTP 429 (rate limited)."""

        _HAS_KEY_MSG = "ROR rate limit!"
        _NO_KEY_MSG = _HAS_KEY_MSG + " Add an ROR Client Key to lift rate limits."

        def __init__(self, has_key: bool):
            super().__init__(self._HAS_KEY_MSG if has_key else self._NO_KEY_MSG)

    # Passed to aiohttp as ``ssl=...``; the default comes from ``_get_verify_env``.
    verify_ssl: Annotated[bool, Field(default_factory=_get_verify_env)]
    # Sent as the ``Client-Id`` header when set.
    client_id: str | None = None
    # Attempts per HTTP request; an HTTP 429 (rate limit) is never retried.
    max_retry_per_request: int = 3
    # API root URL; the default comes from ``_get_api_base``.
    api_base: Annotated[str, Field(default_factory=_get_api_base)]

    @property
    def _headers(self) -> dict[str, str]:
        """Extra request headers: ``Client-Id`` when a client ID is configured."""
        return {"Client-Id": self.client_id} if self.client_id else {}

    def __str__(self):
        return f"<ROR client to {self.api_base!r} with{'' if self.client_id else 'out'} client ID>"

    @staticmethod
    def _escape(url: str, **kwargs: str) -> str:
        """Fill the ``str.format`` placeholders of ``url`` with escaped ``kwargs``.

        Each value first gets Elasticsearch query-string operators backslash-escaped, then is fully
        percent-encoded (no character is treated as safe).
        """
        escaped_args = {}
        for key, value in kwargs.items():
            es_escaped = "".join(f"\\{char}" if char in _ELASTIC_OPERATORS else char for char in value)
            escaped_args[key] = quote_url(es_escaped, safe="")
        return url.format(**escaped_args)

    @staticmethod
    def _format_organizations(title: str, orgs: list[Organization]) -> str:
        """Render ``orgs`` as a titled bullet list for logging."""
        if not orgs:
            return f"{title}: []\n"
        return f"{title}:\n- " + "\n- ".join(str(o) for o in orgs) + "\n"

    @staticmethod
    def _format_organizations_map(mapping: Iterable[tuple[list[Any], str]]) -> str:
        """Render ``(group, target)`` pairs as aligned "items ---> target" lines for logging."""
        lines: list[str] = []
        for group, map_to in mapping:
            labels = [str(item) for item in group]
            if not labels:
                # ``max()`` below would fail on an empty group.
                lines.append(f"[] ---> {map_to}")
                continue
            width = max(len(label) for label in labels)
            lines.append(f"{labels[0]:<{width}} ---> {map_to}")
            lines.extend(f"{label:<{width}} -|" for label in labels[1:-1])
            if len(labels) > 1:
                lines.append(f"{labels[-1]:<{width}} -/")
        return "".join(line + "\n" for line in lines) or "[] (all resolved)"

    @staticmethod
    def _merge_roots(existing: dict[str, Match], new: dict[str, Match]) -> dict[str, Match]:
        """Merge ``new`` into ``existing`` in place (scores added on shared IDs) and return ``existing``."""
        for ror_id, match in new.items():
            if ror_id in existing:
                existing[ror_id].merge(match)
            else:
                existing[ror_id] = match
        return existing

    @classmethod
    def _resolve_parents(cls, grouped_children: dict[str, list[Match]],
                         parents: dict[str, Organization | BaseException],
                         existing_roots: dict[str, Match],
                         root_follow: Iterable[str]) -> tuple[dict[str, Match], list[Match]]:
        """Decide, for each fetched parent, whether its children are folded into it.

        Children stay roots when their parent failed to fetch, has no type in ``root_follow``, or is
        inactive. Otherwise they are merged into one ``Match`` for the parent. That match is a root
        if the parent has no parent of its own, or a "fork" to resolve in the next round.

        Returns:
            ``(existing_roots, forks)``: the root mapping (updated in place) and the unresolved parents.
        """
        follow = tuple(root_follow)  # materialize once: it is tested per parent and shown in logs
        statistic_for_log: list[tuple[list[Match], str]] = []
        may_new_roots: list[Match] = []
        forks: list[Match] = []
        for parent_id, parent_or_exc in parents.items():
            children = grouped_children[parent_id]
            if isinstance(parent_or_exc, BaseException):
                statistic_for_log.append((children, f"❌ {parent_id} ({type(parent_or_exc).__name__})"))
                may_new_roots.extend(children)
            elif not any(t in follow for t in parent_or_exc.types):
                statistic_for_log.append(
                    (children, f"↩️ {parent_id} {parent_or_exc.types} not in any of {follow}")
                )
                may_new_roots.extend(children)
            elif not parent_or_exc.is_active:
                statistic_for_log.append((children, f"↩️ {parent_id} not active"))
                may_new_roots.extend(children)
            else:
                parent = Match.merge_organization(parent=parent_or_exc, children=children)
                statistic_for_log.append((children, f"✅ {parent}"))
                (forks if parent_or_exc.parent else may_new_roots).append(parent)
        logging.info(f"Resolved parent relationships:\n{cls._format_organizations_map(statistic_for_log)}")
        for match in may_new_roots:
            if match.organization.id in existing_roots:
                existing_roots[match.organization.id].merge(match)
            else:
                existing_roots[match.organization.id] = match
        return existing_roots, forks

    async def fetch_one(self, session: ClientSession, ror_id: str) -> Organization:
        """Fetch one record from ROR; ``ror_id`` may be a bare ID or a full ``https://ror.org/...`` URL."""
        id_str = ror_id.rstrip("/").split("/")[-1]
        url = self._escape("/v2/organizations/{id}", id=id_str)
        ret = await self.__request_with_retry(session, "GET", url, out_model=Organization,
                                              usage_for_log=f"Fetch {ror_id}")
        logging.info(f"Fetch {ror_id} ends with record: {ret}")
        return ret

    async def match(self, organization_name: str,
                    find_root=True, root_follow: Iterable[str] = frozenset(["education", "company"]),
                    follow_not_chosen=False,
                    min_follow_score: float | None = None) -> tuple[list[Match], list[Match]]:
        """Match ``organization_name`` against ROR and optionally resolve matches to their root organizations.

        Args:
            organization_name: free-text affiliation to match.
            find_root: when false, skip parent resolution.
            root_follow: only parents having one of these types absorb their children.
            follow_not_chosen: also follow matches ROR did not mark as chosen.
            min_follow_score: with ``follow_not_chosen``, minimum score for a not-chosen match.

        Returns:
            ``(first_match, resolved)``: the active raw matches, and either that same list
            (``find_root`` false) or the consolidated root matches.
        """
        async with self._create_session() as session:
            first_match = await self.match_request(session, organization_name)
            for match in first_match:
                _ror_cache[match.organization.id] = match.organization
            if not find_root:
                return first_match, first_match
            return first_match, await self._find_root_nodes(first_match, organization_name, root_follow,
                                                            follow_not_chosen, min_follow_score, session)

    async def match_one_or_origin(
            self, organization_name: str,
            find_root=True, root_follow: Iterable[str] = frozenset(["education", "company"]),
            follow_not_chosen=False, min_follow_score: float | None = None,
            llm: BaseChatModel | None = None) -> str:
        """Resolve ``organization_name`` to one ROR display name, or return it unchanged.

        The rules are applied in order:
        1. A single resolved match wins.
        2. Among several (or zero) matches, one whose display name equals ``organization_name``
           verbatim wins.
        3. Otherwise ``llm`` (if given) picks one.
        Errors raised while matching, or by the LLM fallback, yield ``organization_name``.
        """
        try:
            origin_match, final_match = await self.match(organization_name, find_root, root_follow,
                                                         follow_not_chosen, min_follow_score)
        except Exception as e:
            logging.error(f"Matching {organization_name!r} failed with Exception and returns its origin name: {e}",
                          exc_info=True)
            return organization_name
        if len(final_match) == 1:
            return final_match[0].organization.ror_name
        # Return the exact-name match itself rather than an arbitrary first root.
        for candidate in final_match:
            if candidate.organization.ror_name == organization_name:
                return candidate.organization.ror_name
        if llm:
            try:
                return await self._match_by_llm(origin_match, organization_name, llm, root_follow)
            except Exception as e:
                logging.error(f"Matching {organization_name!r} by LLM failed with Exception and returns its "
                              f"origin name: {e}", exc_info=True)
                return organization_name
        logging.warning(f"Matching {organization_name!r} with {len(final_match)} results (expected to be 1) "
                        "and returns its origin name.")
        return organization_name

    async def match_request(self, session: ClientSession, name: str) -> list[Match]:
        """Query ROR affiliation matching for ``name``; return only matches on active records."""
        url = self._escape("/v2/organizations?affiliation={name}", name=name)
        response = await self.__request_with_retry(session, "GET", url, out_model=RORMatchResponse,
                                                   usage_for_log=f"Match {name!r}")
        all_records = response.items
        log_str = f"Match {name!r} got {len(all_records)} results:"
        if all_records:
            log_str += "".join(f"\n- {match}" for match in all_records)
        else:
            log_str += " []"
        logging.info(log_str)
        return [match for match in all_records if match.organization.is_active]

    def _create_session(self) -> ClientSession:
        """Create an aiohttp session; callers use it as ``async with`` so it gets closed."""
        # No ``base_url``: ``__request_with_retry`` always sends absolute URLs built from ``api_base``.
        return ClientSession(timeout=ClientTimeout(connect=10, sock_read=20), trust_env=True)

    def _extract_parents(self, children: list[Match], query: str, depth: int,
                         follow_not_chosen=False, min_follow_score: float | None = None
                         ) -> tuple[dict[str, list[Match]], dict[str, Match]]:
        """Group ``children`` by parent ROR ID.

        Not-chosen matches are dropped unless ``follow_not_chosen`` is set (and, when
        ``min_follow_score`` is given, their score reaches it). Matches without a parent are roots.

        Returns:
            ``(groups, roots)``: ``{parent_id: [child matches]}`` and ``{ror_id: root match}``.
        """
        groups: dict[str, list[Match]] = defaultdict(list)
        parents: dict[str, Relationship] = {}
        root_nodes: list[Match] = []
        dropped: list[Match] = []
        for item in children:
            if not item.chosen:
                if (not follow_not_chosen) or (min_follow_score is not None and item.score < min_follow_score):
                    dropped.append(item)
                    continue
            parent = item.organization.parent
            if not parent:
                root_nodes.append(item)
                continue
            groups[parent.id].append(item)
            parents[parent.id] = parent
        log_str = f"Query {query!r} and resolving parent relationship for the {depth} time.\n"
        if not dropped:
            log_str += "Dropped: []\n"
        else:
            log_str += "Dropped:\n- " + "\n- ".join(f"({o.score}) {o.organization}" for o in dropped) + "\n"
        if root_nodes:
            log_str += "Root nodes:\n- " + "\n- ".join(str(match.organization) for match in root_nodes) + "\n"
        log_str += "Relationships:\n"
        mapping: list[tuple[list, str]] = []
        for parent in sorted(parents.values(), key=lambda p: p.label):
            orgs = [match.organization for match in groups[parent.id]]
            # ✅ = already cached, ⬇️ = will be downloaded.
            map_to = f"{'⬇️' if parent.id not in _ror_cache else '✅'}{parent.id} ({parent.label!r})"
            mapping.append((orgs, map_to))
        log_str += self._format_organizations_map(mapping)
        logging.info(log_str)
        return groups, {match.organization.id: match for match in root_nodes}

    async def _fetch_records(self, session: ClientSession,
                             ror_ids: Iterable[str]) -> dict[str, Organization | BaseException]:
        """Return ``{ror_id: record}``, serving from ``_ror_cache`` and fetching misses concurrently.

        A failed fetch maps to its exception instead of a record; a ``RateLimit`` error is raised.
        """
        wanted = set(ror_ids)
        existing: dict[str, Organization | BaseException] = {}
        for id_ in wanted:
            cached = _ror_cache.get(id_)
            if cached is not None:
                existing[id_] = cached
        if existing:
            logging.info(f"These ROR items are cached: {list(existing)}")
        miss_ids = [id_ for id_ in wanted if id_ not in existing]
        if not miss_ids:
            return existing
        records = await asyncio.gather(*(self.fetch_one(session, id_) for id_ in miss_ids), return_exceptions=True)
        for id_, record in zip(miss_ids, records):
            if isinstance(record, self.RateLimit):
                raise record
            if isinstance(record, Organization):
                _ror_cache[record.id] = record
            existing[id_] = record
        return existing

    async def _find_root_nodes(self, first_match: list[Match], organization_name: str, root_follow: Iterable[str],
                               follow_not_chosen: bool, min_follow_score: float | None,
                               session: ClientSession) -> list[Match]:
        """Walk parent relationships upward from ``first_match`` until only roots remain.

        Returns:
            The consolidated root matches (children merged into a parent have their scores summed).
        """
        # Bad remote data with a parent cycle would otherwise loop forever; with every record
        # cached, the loop never awaits real I/O and would block the event loop.
        max_depth = 32
        relations, roots = self._extract_parents(
            first_match, query=organization_name, depth=1, follow_not_chosen=follow_not_chosen,
            min_follow_score=min_follow_score
        )
        parents = await self._fetch_records(session, relations)
        roots, forks = self._resolve_parents(relations, parents, roots, root_follow)
        depth = 2
        while forks:
            if depth > max_depth:
                logging.warning(f"Query {organization_name!r}: parent chain deeper than {max_depth} levels "
                                f"(possible cycle); keeping {len(forks)} unresolved organizations as roots.")
                roots = self._merge_roots(roots, {fork.organization.id: fork for fork in forks})
                break
            relations, new_roots = self._extract_parents(
                forks, query=organization_name, depth=depth,
                follow_not_chosen=follow_not_chosen, min_follow_score=min_follow_score
            )
            roots = self._merge_roots(roots, new_roots)
            parents = await self._fetch_records(session, relations)
            roots, forks = self._resolve_parents(relations, parents, roots, root_follow)
            depth += 1
        return list(roots.values())

    async def _match_by_llm(self, first_match: list[Match], org_name: str, llm: BaseChatModel,
                            root_follow: Iterable[str]) -> str:
        """Let ``llm`` pick a ROR record for ``org_name`` and return the display name of its root.

        Returns ``org_name`` unchanged when the LLM workflow gives up (``RORException``).
        Langfuse tracing is attached when the ``langfuse`` package is importable.
        """
        callbacks: list = []
        try:
            from langfuse.langchain import CallbackHandler
        except ImportError:
            logging.warning("langfuse is not installed; matching by LLM without tracing.")
        else:
            callbacks.append(CallbackHandler())
        async with self._create_session() as session:
            inputs = _MatchByLLMInput(first_match=first_match, org_name=org_name)
            input_as_configs = RunnableConfig(configurable=dict(llm=llm, session=session, client=self))
            try:
                org: Organization = await (
                    _match_by_llm
                    .with_config(run_name="match_ROR_by_LLM", callbacks=callbacks)
                    .ainvoke(inputs, config=input_as_configs)
                )
            except RORException:
                return org_name
            as_match = Match.merge_organization(org, [])
            as_match.chosen = True
            as_match.score = 1.
            roots = await self._find_root_nodes([as_match], org_name, root_follow, follow_not_chosen=True,
                                                min_follow_score=0., session=session)
        if not roots:
            return org.ror_name
        return roots[0].organization.ror_name

    async def __request_with_retry(self, session: ClientSession, method: str, path_with_query: str,
                                   out_model: Type[_Model], usage_for_log: str) -> _Model:
        """Send one API request with retries and validate the JSON body as ``out_model``.

        Failed attempts are logged; success is not.

        Raises:
            RORClient.RateLimit: immediately on HTTP 429 (never retried).
            Exception: the last error after ``max_retry_per_request`` failed attempts.
        """
        # Avoid a double slash between a trailing-slash ``api_base`` and a leading-slash path.
        url = self.api_base.rstrip("/") + "/" + path_with_query.lstrip("/")
        last_exception: Exception = RuntimeError(f"Unknown exception when {usage_for_log} from ROR.")
        for retry_count in range(1, self.max_retry_per_request + 1):
            try:
                # ``async with`` releases the connection back to the pool, also on errors.
                async with session.request(method, url, headers=self._headers, ssl=self.verify_ssl) as response:
                    if response.status == 429:
                        raise RORClient.RateLimit(bool(self.client_id))
                    response.raise_for_status()
                    payload = await response.json()
                return out_model.model_validate(payload)
            except RORClient.RateLimit:
                raise
            except Exception as e:
                last_exception = e
                logging.error(f"Failed to {usage_for_log} from ROR for the {retry_count} time with "
                              f"{type(e).__name__}: {e}", exc_info=True)
        logging.error(f"Failed to {usage_for_log} for too many times ({self.max_retry_per_request})!"
                      " Aborted with last exception.")
        raise last_exception
class RORException(RuntimeError):
    """Marker for a failure that inner code has already logged and handled.

    ``RORClient._match_by_llm`` catches it and falls back to the original organization name.
    """
class _LLMSelectResult(BaseModel):
    """JSON object the LLM must answer with; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Selected ROR ID. Missing, null or empty means "no match"; null was previously rejected
    # because the field was typed as plain ``str``.
    ror_id: str | None = None
class _MatchByLLMInput(TypedDict):
    """Input of the ``_match_by_llm`` workflow."""

    first_match: list[Match]  # candidates shown to the LLM as reference records
    org_name: str  # the organization name to resolve
@entrypoint()
async def _match_by_llm(inputs: _MatchByLLMInput, config: RunnableConfig) -> Organization:
    """Let an LLM agent choose the ROR record for ``inputs["org_name"]``, retrying unusable answers.

    The agent sees ``inputs["first_match"]`` as reference records and may call the ``ror_search`` tool.
    ``config["configurable"]`` must provide ``llm``, ``session`` and ``client``. An optional
    ``max_retry`` (default 3) bounds the number of attempts.

    Returns:
        The full ROR record fetched for the ROR ID the LLM selected.

    Raises:
        RORException: when no attempt yields a usable ROR ID.
    """
    first_match = inputs["first_match"]
    org_name = inputs["org_name"]
    configurable = config["configurable"]
    llm: BaseChatModel = configurable["llm"]
    session: ClientSession = configurable["session"]
    client: RORClient = configurable["client"]
    max_retry: int = configurable.get("max_retry", 3)
    # The ``ror_search`` tool reads the client and its open session from this config.
    sub_config = RunnableConfig(configurable=_ToolConf(ror_client=client, ror_session=session))
    agent = create_agent(llm, tools=[_ror_search], system_prompt=_MATCH_ONE_ROR_SYS_PROMPT_TEXT)
    references = json.dumps([match.organization.simplified_dump for match in first_match], ensure_ascii=False, indent=2)
    for _ in range(max_retry):
        try:
            out_msgs = await agent.ainvoke(
                {"messages": [
                    {
                        "role": "user",
                        "content": f"""## Reference Records\n\n{references}\n\n## Target Organization\n\n{org_name}"""
                    }
                ]}, config=sub_config)
            out_text = out_msgs["messages"][-1].content
            # Tolerate prose around the answer: keep the outermost {...} span.
            left = out_text.find("{")
            right = out_text.rfind("}")
            if left == -1 or right == -1:
                logging.error(f"LLM生成的{out_text!r}不包含完整的json对象")
                # Previously raised and immediately re-caught as an "unknown" error; just retry.
                continue
            json_text = out_text[left:right + 1]
            try:
                out: _LLMSelectResult = _LLMSelectResult.model_validate_json(json_text)
            except ValidationError as e:
                logging.error(f"查询机构信息时发生异常:{e}。LLM完整输出为:{out_text!r},其中识别到的json内容为{json_text!r}")
                continue
            if not out.ror_id:
                logging.warning(f"LLM match {org_name} returns nothing.")
                continue
            return await client.fetch_one(session, out.ror_id)
        except Exception as e:
            logging.error(f"Matching {org_name!r} failed with unknown {type(e).__name__}: {e}", exc_info=True)
    logging.warning(f"Try matching {org_name!r} by LLM failed for too many times ({max_retry}), returns origin.")
    raise RORException(f"Failed for too many times ({max_retry})")
class _ToolConf(TypedDict):
    """``configurable`` entries that the ``ror_search`` tool reads from its RunnableConfig."""

    ror_client: RORClient
    ror_session: ClientSession
@tool("ror_search", parse_docstring=True, error_on_invalid_docstring=True)
async def _ror_search(org_name: str, config: RunnableConfig) -> str:
    """Search `org_name` to match recorded organization name in ROR database.

    Args:
        org_name: str, the target organization name.

    Returns:
        Matched results with their recorded name, aliases, country (in ISO3166-1 Alpha2 code) and ROR ID in a list.
    """
    # NOTE: the docstring above is parsed at import time into the tool description. The Google-style
    # sections must be separated by blank lines, or ``error_on_invalid_docstring`` raises.
    configurable = config["configurable"]
    client: RORClient = configurable["ror_client"]
    session: ClientSession = configurable["ror_session"]
    matches = await client.match_request(session, org_name)
    return json.dumps([match.organization.simplified_dump for match in matches], ensure_ascii=False, indent=2)
# System prompt for the LLM matcher. It is passed verbatim as ``system_prompt`` (never run through
# ``str.format``), so JSON braces are written singly; doubled braces would reach the model literally.
_MATCH_ONE_ROR_SYS_PROMPT_TEXT = """## Role

You are an Academic Affiliation Retrieval Expert.
Your task is to find an organization record that represents the same organization as the name \
provided by the user (or a parent organization of the one the user is inquiring about) based on \
"ror_name" and "aliases", and return its ROR ID.

## Task

1. Check whether any organization in the references list of ROR organizations (based on "ror_name" and "aliases") \
matches the target organization that the user is inquiring about. If so, return its ROR ID directly.
2. If none of the existing references meet the wanted organization, call tool "ror_search" for a search, and perform \
further matching based on the search results.
3. If tool call fails, or if no matching organization record is found after more than 3 tool calls, stop and return \
an empty JSON.

## Notice

If the target organization is a multinational entity and there are existing records of its branches in other \
regions, you can still regard the record as a successful match and return its ROR ID.
If there are multiple branches of the organization in the records, you only need to output ROR ID of any one of them.
However, if there is a headquarters, you should directly output the ROR ID of the headquarters.

## Output Format

Return your answer strictly following this JSON structure:
{
  "ror_id": ""
}

---

## Example 1

### Input
target: "Huawei Cloud"
references: [
  {"ror_name": "Huawei Technologies (Poland)", "ror_id": "https://ror.org/007a2ta87"},
  {"ror_name": "Huawei Technologies (Sweden)", "ror_id": "https://ror.org/0500fyd17"}
]
### Output ("Huawei Cloud" is a subsidiary of "Huawei Technologies" and has a record of being its Polish branch)
{
  "ror_id": "https://ror.org/007a2ta87"
}

## Example 2

### Input
target: "HUAWEI"
references: []
### Tool Output on "Huawei"
references: [
  {"ror_name": "Huawei Technologies (Poland)", "ror_id": "https://ror.org/007a2ta87"},
  {"ror_name": "Huawei Technologies (China)", "ror_id": "https://ror.org/00cmhce21"}
]
### Output ("China" is headquarters of "Huawei Technologies")
{
  "ror_id": "https://ror.org/00cmhce21"
}
"""